@dataclasses.dataclass
class ZHead(hyperparams.Config):
  """Hyperparameters of the Mesh R-CNN Z (depth prediction) head."""
  # Number of fully connected layers placed before the prediction layer.
  num_fc: int = 2
  # Number of units in each of those fully connected layers.
  fc_dim: int = 1024
  # Whether one value is predicted for all classes instead of one per class.
  cls_agnostic: bool = False
  # Number of prediction classes.
  num_classes: int = 9
@parameterized.named_parameters(
    {
        'testcase_name': 'pix3d_params',
        'num_fc': 2,
        'fc_dim': 1024,
        'cls_agnostic': False,
        'num_classes': 9,
    },
    {
        # Small head that exercises the class-agnostic (single output) branch.
        'testcase_name': 'cls_agnostic',
        'num_fc': 1,
        'fc_dim': 64,
        'cls_agnostic': True,
        'num_classes': 9,
    },
)
class ZHeadTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the Mesh R-CNN Z head.

  The parameter sets above are applied to every `test*` method of the class.
  """

  # (batch_size, height, width, channels) of the dummy feature map.
  _INPUT_SHAPE = (64, 14, 14, 256)

  def _make_head(self, num_fc, fc_dim, cls_agnostic, num_classes):
    """Returns a freshly constructed (not yet built) Z head."""
    return z_head.ZHead(num_fc, fc_dim, cls_agnostic, num_classes)

  def test_output_shape(self,
                        num_fc: int,
                        fc_dim: int,
                        cls_agnostic: bool,
                        num_classes: int):
    """Checks that the Z head output has shape (batch, num_outputs)."""
    head = self._make_head(num_fc, fc_dim, cls_agnostic, num_classes)

    output = head(tf.zeros(self._INPUT_SHAPE))

    # A class-agnostic head regresses a single value per input; otherwise
    # it predicts one value per class.
    expected_outputs = 1 if cls_agnostic else num_classes
    self.assertEqual(output.shape.as_list(),
                     [self._INPUT_SHAPE[0], expected_outputs])

  def test_serialize_deserialize(self,
                                 num_fc: int,
                                 fc_dim: int,
                                 cls_agnostic: bool,
                                 num_classes: int):
    """Checks that a head rebuilt from its config reports the same config."""
    head = self._make_head(num_fc, fc_dim, cls_agnostic, num_classes)
    # Run a forward pass so that the config is taken from a built layer.
    _ = head(tf.zeros(self._INPUT_SHAPE))

    serialized = head.get_config()
    deserialized = z_head.ZHead.from_config(serialized)

    self.assertEqual(head.get_config(), deserialized.get_config())

  def test_gradient_pass_through(self,
                                 num_fc: int,
                                 fc_dim: int,
                                 cls_agnostic: bool,
                                 num_classes: int):
    """Checks that every trainable variable receives a gradient."""
    head = self._make_head(num_fc, fc_dim, cls_agnostic, num_classes)

    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()

    init = tf.random_normal_initializer()
    test_input = init(shape=self._INPUT_SHAPE, dtype=tf.float32)

    # The first call builds the layer and gives the shape of the target.
    output_shape = head(test_input).shape
    target = init(shape=output_shape, dtype=tf.float32)

    with tf.GradientTape() as tape:
      y_hat = head(test_input)
      # Keras losses are called as loss(y_true, y_pred).
      grad_loss = loss(target, y_hat)
    grads = tape.gradient(grad_loss, head.trainable_variables)

    # Check before applying: a missing gradient is the failure under test.
    for grad in grads:
      self.assertIsNotNone(grad)

    optimizer.apply_gradients(zip(grads, head.trainable_variables))
class ZHead(tf.keras.layers.Layer):
  """Depth prediction Z head for the Mesh R-CNN model.

  The input features are flattened, passed through `num_fc` fully connected
  ReLU layers of width `fc_dim`, and projected by a final linear layer to the
  predicted z values.
  """

  def __init__(self,
               num_fc: int,
               fc_dim: int,
               cls_agnostic: bool,
               num_classes: int,
               **kwargs):
    """Initializes the Z head.

    Args:
      num_fc: number of fully connected layers.
      fc_dim: dimension of the fully connected layers.
      cls_agnostic: if True, a single value is predicted per input regardless
        of class; otherwise `num_classes` values are predicted.
      num_classes: number of prediction classes.
      **kwargs: forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)

    self._num_fc = num_fc
    self._fc_dim = fc_dim
    self._cls_agnostic = cls_agnostic
    self._num_classes = num_classes

  def build(self, input_shape: tf.TensorShape) -> None:
    """Creates the flatten, fully connected and prediction sub-layers."""
    self.flatten = tf.keras.layers.Flatten()

    self.fcs = [
        tf.keras.layers.Dense(
            self._fc_dim,
            activation='relu',
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=1., mode='fan_in', distribution='truncated_normal'))
        for _ in range(self._num_fc)
    ]

    # One output per class, unless the head is class agnostic.
    num_z_reg_classes = 1 if self._cls_agnostic else self._num_classes
    self.z_pred = tf.keras.layers.Dense(
        num_z_reg_classes,
        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
        bias_initializer='zeros')

  def call(self, features):
    """Runs the forward pass of the Z head.

    Args:
      features: input feature tensor; everything but the leading (batch)
        dimension is flattened.

    Returns:
      Tensor whose last dimension is 1 if `cls_agnostic`, else `num_classes`.
    """
    out = self.flatten(features)
    for layer in self.fcs:
      out = layer(out)
    return self.z_pred(out)

  def get_config(self):
    """Returns the config dict of the ZHead layer.

    The base layer config (e.g. name, dtype, trainable) is included so that
    `from_config(get_config())` reproduces those settings as well.
    """
    config = super().get_config()
    config.update(
        num_fc=self._num_fc,
        fc_dim=self._fc_dim,
        cls_agnostic=self._cls_agnostic,
        num_classes=self._num_classes,
    )
    return config

  @classmethod
  def from_config(cls, config):
    """Creates a Z head from a config dict of constructor arguments."""
    return cls(**config)
@dataclass
class ZHeadCFG(Config):
  """Holds the weights of the Z head read from a converted checkpoint.

  `weights_dict` must contain the keys 'z_fc1', 'z_fc2' and 'z_pred', each
  mapping to a dict with 'weight' and 'bias' entries. The remaining fields
  are filled in from it by `__post_init__`.

  NOTE(review): the checkpoint comes from PyTorch, where linear layers store
  kernels as (out, in) while Keras Dense kernels are (in, out) - confirm that
  the arrays are already in Keras layout before they reach this class.
  """
  weights_dict: Dict = field(repr=False, default=None)
  weights: List = field(repr=False, default=None)

  fc1_weights: np.ndarray = field(repr=False, default=None)
  fc1_bias: np.ndarray = field(repr=False, default=None)
  fc2_weights: np.ndarray = field(repr=False, default=None)
  fc2_bias: np.ndarray = field(repr=False, default=None)
  pred_weights: np.ndarray = field(repr=False, default=None)
  pred_bias: np.ndarray = field(repr=False, default=None)

  def __post_init__(self):
    """Extracts the per-layer arrays and the flat weight list."""
    if self.weights_dict is None:
      # Fail with a clear message instead of "NoneType is not subscriptable".
      raise ValueError('ZHeadCFG requires a weights_dict.')

    fc1 = self.weights_dict['z_fc1']
    fc2 = self.weights_dict['z_fc2']
    pred = self.weights_dict['z_pred']

    self.fc1_weights, self.fc1_bias = fc1['weight'], fc1['bias']
    self.fc2_weights, self.fc2_bias = fc2['weight'], fc2['bias']
    self.pred_weights, self.pred_bias = pred['weight'], pred['bias']

    # Order: kernel then bias for fc1, fc2 and the prediction layer.
    self.weights = [
        self.fc1_weights,
        self.fc1_bias,
        self.fc2_weights,
        self.fc2_bias,
        self.pred_weights,
        self.pred_bias,
    ]

  def get_weights(self):
    """Returns the flat list of weights assembled in `__post_init__`."""
    return self.weights
# --- utils/weight_utils/config_data.py ---
@dataclass
class ZHeadConfigData():
  """Maps a z head configuration name to its weight config."""
  weights_dict: Dict = field(repr=False, default=None)

  def get_cfg_list(self, name):
    """Returns the weight config for the named z head configuration.

    Args:
      name: String, name of the z head configuration.

    Returns:
      A single `ZHeadCFG` when `name` is "pix3d", otherwise an empty list.
    """
    if name == "pix3d":
      return ZHeadCFG(weights_dict=self.weights_dict)
    return []


# --- utils/weight_utils/load_weights.py ---
def get_zhead_layer_cfgs(weights_dict, zhead_name):
  """ Fetches the config class for the z head.

  Args:
    weights_dict: Dictionary that stores the z head weights.
    zhead_name: String, indicating the desired z head configuration.
  Returns:
    The value of `ZHeadConfigData.get_cfg_list`: a z head config object for
    a known configuration name, or an empty list otherwise.
  """
  print("Fetching z head config classes for {}\n".format(zhead_name))
  return ZHeadConfigData(weights_dict).get_cfg_list(zhead_name)


def load_weights_zhead(zhead, weights_dict, zhead_name):
  """ Loads the weights defined in the weights_dict into the z head.

  This function first fetches the necessary config class(es) for the z head
  and then loads each of them into the given layer.

  Args:
    zhead: the z head layer to load the weights into.
    weights_dict: Dictionary that stores the z head model weights.
    zhead_name: String, indicating the desired z head configuration.
  Returns:
    Number of weights loaded in; 0 when no config exists for `zhead_name`.
  """
  print("Loading z head weights\n")
  cfgs = get_zhead_layer_cfgs(weights_dict, zhead_name)

  # A known name yields a single config object and an unknown one yields an
  # empty list; normalise both to a sequence.
  if not isinstance(cfgs, (list, tuple)):
    cfgs = [cfgs]
  if not cfgs:
    print("No z head config found for {}\n".format(zhead_name))

  n_weights_total = 0
  for cfg in cfgs:
    n_weights_total += cfg.load_weights(zhead)
  return n_weights_total
def test_load_zhead():
  """Loads the z head weights of the checkpoint at PTH_PATH into a ZHead."""
  weights_dict, _ = pth_to_dict(PTH_PATH)
  zhead_weights = weights_dict['roi_heads']['z_head']

  # Show the checkpoint structure that the z head loader relies on.
  print(weights_dict.keys())
  print(weights_dict['roi_heads'].keys())
  print(zhead_weights.keys())
  print(zhead_weights['z_pred'].keys())
  print(zhead_weights['z_pred']['weight'].shape)

  input_specs = dict(
      num_fc=2,
      fc_dim=1024,
      cls_agnostic=False,
      num_classes=9,
  )

  zhead = ZHead.from_config(input_specs)

  # ZHead creates its sub-layers in build(), and their variables only exist
  # after a first call, so run a dummy forward pass before loading.
  # NOTE(review): (1, 14, 14, 256) is the feature shape used by the ZHead
  # unit test - confirm it matches the fc1 input size of the checkpoint.
  zhead(tf.zeros((1, 14, 14, 256)))

  load_weights_zhead(zhead, zhead_weights, 'pix3d')