forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 7
Mesh rcnn zhead #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Mesh rcnn zhead #19
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
a26dcd9
Initial commit
johnzylstra 606c86e
Some weight loading files
johnzylstra 8ce1db2
Diff Tested
johnzylstra 30d1fd3
Config and minor fixes
johnzylstra 5b7d0ec
Cosmetic
johnzylstra f5a1e26
Aligning Function Arguments
johnzylstra a2df15e
test_z_head main
johnzylstra 9189cc7
Merge branch 'mesh_rcnn' into mesh_rcnn_zhead
johnzylstra 158fbba
Z head test params
johnzylstra File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
93 changes: 93 additions & 0 deletions
93
official/vision/beta/projects/mesh_rcnn/modeling/heads/test_z_head.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Tests for Mesh R-CNN Heads.""" | ||
|
|
||
| from absl.testing import parameterized | ||
| import tensorflow as tf | ||
|
|
||
| import z_head | ||
|
|
||
@parameterized.named_parameters(
    {'testcase_name': 'pix3d_params',
     'num_fc': 2,
     'fc_dim': 1024,
     'cls_agnostic': False,
     'num_classes': 9},
    # Small class-agnostic config so the single-output (batch, 1) path is
    # exercised as well; kept small to keep the test cheap.
    {'testcase_name': 'cls_agnostic_small',
     'num_fc': 1,
     'fc_dim': 32,
     'cls_agnostic': True,
     'num_classes': 9},
)
class ZHeadTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the Mesh R-CNN Z head."""

  # (batch_size, height, width, channels) of the RoI features fed to the head.
  _INPUT_SHAPE = (64, 14, 14, 256)

  def test_output_shape(self,
                        num_fc: int,
                        fc_dim: int,
                        cls_agnostic: bool,
                        num_classes: int):
    """Checks the output is (batch, 1) if class-agnostic else (batch, C)."""
    batch_size = self._INPUT_SHAPE[0]
    head = z_head.ZHead(num_fc, fc_dim, cls_agnostic, num_classes)

    # Random (non-zero) input so the check does not silently depend on the
    # output values happening to be zero.
    output = head(tf.random.normal(self._INPUT_SHAPE))

    expected_width = 1 if cls_agnostic else num_classes
    self.assertEqual(output.shape.as_list(), [batch_size, expected_width])

  def test_serialize_deserialize(self,
                                 num_fc: int,
                                 fc_dim: int,
                                 cls_agnostic: bool,
                                 num_classes: int):
    """Checks that get_config/from_config round-trips the layer config."""
    head = z_head.ZHead(num_fc, fc_dim, cls_agnostic, num_classes)
    # Call once so the layer is built before serializing it.
    _ = head(tf.zeros(self._INPUT_SHAPE))

    serialized = head.get_config()
    deserialized = z_head.ZHead.from_config(serialized)

    self.assertEqual(head.get_config(), deserialized.get_config())

  def test_gradient_pass_through(self,
                                 num_fc: int,
                                 fc_dim: int,
                                 cls_agnostic: bool,
                                 num_classes: int):
    """Checks every trainable variable receives a (non-None) gradient."""
    head = z_head.ZHead(num_fc, fc_dim, cls_agnostic, num_classes)
    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()

    test_input = tf.random.normal(self._INPUT_SHAPE)
    # A first call builds the head, creating its variables and fixing the
    # output shape used for the regression target.
    target = tf.random.normal(head(test_input).shape)

    with tf.GradientTape() as tape:
      y_hat = head(test_input)
      # Keras losses take (y_true, y_pred), in that order.
      grad_loss = loss(target, y_hat)
    grads = tape.gradient(grad_loss, head.trainable_variables)

    # Guard against a vacuous pass when the head has no variables at all.
    self.assertNotEmpty(head.trainable_variables)
    for var, grad in zip(head.trainable_variables, grads):
      self.assertIsNotNone(grad, msg=f'No gradient for {var.name}')

    # The update step must also succeed with the computed gradients.
    optimizer.apply_gradients(zip(grads, head.trainable_variables))
|
|
||
if __name__ == '__main__':
  # Run every test case in this module through TensorFlow's test runner.
  tf.test.main()
81 changes: 81 additions & 0 deletions
81
official/vision/beta/projects/mesh_rcnn/modeling/heads/z_head.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| # Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Mesh R-CNN Heads.""" | ||
|
|
||
| import tensorflow as tf | ||
|
|
||
class ZHead(tf.keras.layers.Layer):
  """Depth (Z) prediction head for the Mesh R-CNN model.

  Flattens each example's features, applies `num_fc` fully connected ReLU
  layers of width `fc_dim`, and ends with a linear Dense layer that outputs
  one depth value per class (or a single value when `cls_agnostic` is True).
  """

  def __init__(self,
               num_fc: int,
               fc_dim: int,
               cls_agnostic: bool,
               num_classes: int,
               **kwargs):
    """Initializes the Z head.

    Args:
      num_fc: Number of fully connected (Dense + ReLU) layers applied before
        the final prediction layer.
      fc_dim: Number of units in each of those fully connected layers.
      cls_agnostic: If True, predict a single depth value per example that is
        shared by all classes; otherwise predict one value per class.
      num_classes: Number of object classes; sets the output width when
        `cls_agnostic` is False.
      **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`,
        `trainable`, `dtype`).
    """
    super().__init__(**kwargs)
    self._num_fc = num_fc
    self._fc_dim = fc_dim
    self._cls_agnostic = cls_agnostic
    self._num_classes = num_classes

  def build(self, input_shape: tf.TensorShape) -> None:
    """Creates the flatten, hidden Dense, and prediction sub-layers."""
    self.flatten = tf.keras.layers.Flatten()

    self.fcs = []
    for _ in range(self._num_fc):
      # A new initializer instance per layer, rather than sharing one
      # unseeded instance across layers (discouraged by newer Keras).
      fc_init = tf.keras.initializers.VarianceScaling(
          scale=1., mode='fan_in', distribution='truncated_normal')
      self.fcs.append(
          tf.keras.layers.Dense(self._fc_dim,
                                activation='relu',
                                kernel_initializer=fc_init))

    # Output width: 1 shared depth value, or one per class.
    num_z_reg_classes = 1 if self._cls_agnostic else self._num_classes
    pred_init = tf.keras.initializers.RandomNormal(stddev=0.001)
    self.z_pred = tf.keras.layers.Dense(num_z_reg_classes,
                                        kernel_initializer=pred_init,
                                        bias_initializer='zeros')
    super().build(input_shape)

  def call(self, features):
    """Runs the forward pass of the Z head.

    Args:
      features: Batched feature tensor; every axis after the first (batch)
        axis is flattened before the fully connected layers.

    Returns:
      Tensor of shape (batch, 1) if `cls_agnostic`, else
      (batch, num_classes).
    """
    out = self.flatten(features)
    for layer in self.fcs:
      out = layer(out)
    return self.z_pred(out)

  def get_config(self):
    """Returns the layer config, including the base-layer fields.

    Merging the base config keeps `name`, `trainable` and `dtype` across a
    `from_config` round-trip instead of silently resetting them.
    """
    config = super().get_config()
    config.update(
        num_fc=self._num_fc,
        fc_dim=self._fc_dim,
        cls_agnostic=self._cls_agnostic,
        num_classes=self._num_classes,
    )
    return config

  @classmethod
  def from_config(cls, config):
    """Creates a ZHead from a dict produced by `get_config`."""
    return cls(**config)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.