这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
from official.modeling import hyperparams # type: ignore


@dataclasses.dataclass
class ZHead(hyperparams.Config):
  """Parameterization for the Mesh R-CNN Z Head.

  These fields presumably mirror the constructor arguments of
  `modeling.heads.z_head.ZHead` (TODO confirm against the model builder).

  Attributes:
    num_fc: Number of hidden fully connected layers in the head.
    fc_dim: Width (units) of each hidden fully connected layer.
    cls_agnostic: Whether to predict a single depth value per RoI instead of
      one per class.
    num_classes: Number of object classes; the prediction width when
      `cls_agnostic` is False.
  """
  num_fc: int = 2
  fc_dim: int = 1024
  cls_agnostic: bool = False
  num_classes: int = 9

@dataclasses.dataclass
class VoxelHead(hyperparams.Config):
"""Parameterization for the Mesh R-CNN Voxel Branch Prediction Head."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Mesh R-CNN Heads."""

from absl.testing import parameterized
import tensorflow as tf

import z_head

@parameterized.named_parameters(
    {'testcase_name': 'pix3d_params',
     'num_fc': 2,
     'fc_dim': 1024,
     'cls_agnostic': False,
     'num_classes': 9},
    {'testcase_name': 'cls_agnostic_params',
     'num_fc': 2,
     'fc_dim': 1024,
     'cls_agnostic': True,
     'num_classes': 9},
)
class ZHeadTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the Mesh R-CNN Z head."""

  # (batch_size, height, width, channels) of the dummy feature batch that
  # every test feeds to the head.
  _INPUT_SHAPE = (64, 14, 14, 256)

  def test_output_shape(self,
                        num_fc: int,
                        fc_dim: int,
                        cls_agnostic: bool,
                        num_classes: int):
    """Check that the Z head output has shape (batch, num_z_reg_classes)."""
    head = z_head.ZHead(num_fc, fc_dim, cls_agnostic, num_classes)

    output = head(tf.zeros(self._INPUT_SHAPE))

    # A class-agnostic head predicts one depth value per input instead of
    # one per class.
    expected_width = 1 if cls_agnostic else num_classes
    self.assertEqual(output.shape.as_list(),
                     [self._INPUT_SHAPE[0], expected_width])

  def test_serialize_deserialize(self,
                                 num_fc: int,
                                 fc_dim: int,
                                 cls_agnostic: bool,
                                 num_classes: int):
    """Check that a ZHead rebuilt from its config reports the same config."""
    head = z_head.ZHead(num_fc, fc_dim, cls_agnostic, num_classes)
    _ = head(tf.zeros(self._INPUT_SHAPE))

    serialized = head.get_config()
    deserialized = z_head.ZHead.from_config(serialized)

    # get_config returns plain dicts, so compare them as dicts.
    self.assertEqual(head.get_config(), deserialized.get_config())

  def test_gradient_pass_through(self,
                                 num_fc: int,
                                 fc_dim: int,
                                 cls_agnostic: bool,
                                 num_classes: int):
    """Check that every trainable variable receives a gradient."""
    head = z_head.ZHead(num_fc, fc_dim, cls_agnostic, num_classes)

    loss = tf.keras.losses.MeanSquaredError()
    optimizer = tf.keras.optimizers.SGD()

    init = tf.random_normal_initializer()
    test_input = tf.Variable(
        initial_value=init(shape=self._INPUT_SHAPE, dtype=tf.float32))

    output_shape = head(test_input).shape
    test_output = tf.Variable(
        initial_value=init(shape=output_shape, dtype=tf.float32))

    with tf.GradientTape() as tape:
      y_hat = head(test_input)
      # Keras losses take (y_true, y_pred).
      grad_loss = loss(test_output, y_hat)
    grad = tape.gradient(grad_loss, head.trainable_variables)

    # Verify gradients before applying them so a missing gradient is reported
    # by this assertion rather than by the optimizer.
    self.assertNotEmpty(head.trainable_variables)
    for g in grad:
      self.assertIsNotNone(g)
    optimizer.apply_gradients(zip(grad, head.trainable_variables))

if __name__ == "__main__":
  # Hand control to TensorFlow's absl-based test runner.
  tf.test.main()
81 changes: 81 additions & 0 deletions official/vision/beta/projects/mesh_rcnn/modeling/heads/z_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mesh R-CNN Heads."""

import tensorflow as tf

class ZHead(tf.keras.layers.Layer):
  """Depth (Z) prediction head for the Mesh R-CNN model.

  Flattens the input features, applies `num_fc` ReLU Dense layers of width
  `fc_dim`, and finishes with a linear Dense layer that predicts one depth
  value per class (or a single value when `cls_agnostic` is True).
  """

  def __init__(self,
               num_fc: int,
               fc_dim: int,
               cls_agnostic: bool,
               num_classes: int,
               **kwargs):
    """Initializes the Z head.

    Args:
      num_fc: Number of hidden fully connected layers.
      fc_dim: Number of units in each hidden fully connected layer.
      cls_agnostic: If True, predict a single depth value per input instead
        of one per class.
      num_classes: Number of object classes; the output width when
        `cls_agnostic` is False.
      **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`, `dtype`).
    """
    super(ZHead, self).__init__(**kwargs)

    self._num_fc = num_fc
    self._fc_dim = fc_dim
    self._cls_agnostic = cls_agnostic
    self._num_classes = num_classes

  def build(self, input_shape: tf.TensorShape) -> None:
    """Creates the flatten, hidden FC and prediction sublayers."""
    self.flatten = tf.keras.layers.Flatten()

    self.fcs = []
    for _ in range(self._num_fc):
      # One initializer instance per layer rather than a single shared one.
      fc_init = tf.keras.initializers.VarianceScaling(
          scale=1., mode='fan_in', distribution='truncated_normal')
      self.fcs.append(
          tf.keras.layers.Dense(self._fc_dim,
                                activation='relu',
                                kernel_initializer=fc_init))

    num_z_reg_classes = 1 if self._cls_agnostic else self._num_classes
    pred_init = tf.keras.initializers.RandomNormal(stddev=0.001)
    self.z_pred = tf.keras.layers.Dense(num_z_reg_classes,
                                        kernel_initializer=pred_init,
                                        bias_initializer='zeros')
    super(ZHead, self).build(input_shape)

  def call(self, features):
    """Runs the forward pass of the Z head.

    Args:
      features: Input tensor; every dimension after the batch dimension is
        flattened before the fully connected layers.

    Returns:
      A tensor of shape (batch, 1) if `cls_agnostic` else
      (batch, num_classes).
    """
    out = self.flatten(features)
    for layer in self.fcs:
      out = layer(out)
    return self.z_pred(out)

  def get_config(self):
    """Returns the layer config, including the base Layer fields.

    Merging the base config keeps `name`, `dtype` and `trainable` so that
    `from_config(get_config())` reproduces the layer faithfully.
    """
    config = super(ZHead, self).get_config()
    config.update(
        num_fc=self._num_fc,
        fc_dim=self._fc_dim,
        cls_agnostic=self._cls_agnostic,
        num_classes=self._num_classes,
    )
    return config

  @classmethod
  def from_config(cls, config):
    """Creates a ZHead from a dict produced by `get_config`."""
    return cls(**config)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC
from dataclasses import dataclass, field
from typing import Dict
from typing import Dict, List

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -30,6 +30,38 @@ def load_weights(self, layer: tf.keras.layers.Layer) -> int:
n_weights += tf.size(w)
return n_weights

@dataclass
class ZHeadCFG(Config):
  """Holds the pretrained Z head weights in the order the head consumes them.

  `weights_dict` must contain 'z_fc1', 'z_fc2' and 'z_pred' entries, each a
  dict with 'weight' and 'bias'. After construction, `weights` is
  [fc1 weight, fc1 bias, fc2 weight, fc2 bias, pred weight, pred bias].

  NOTE(review): if these arrays come straight from the PyTorch checkpoint,
  nn.Linear stores weights as (out_features, in_features) while Keras Dense
  kernels are (in_features, out_features). z_fc1 was also presumably trained
  on channels-first flattened RoI features, whereas the Keras ZHead flattens
  its (presumably channels-last) input. Confirm the loader transposes and
  reorders accordingly.
  """
  weights_dict: Dict = field(repr=False, default=None)
  # Flat weight list returned by get_weights(); filled in __post_init__.
  weights: List = field(repr=False, default=None)

  # Per-layer arrays extracted from weights_dict in __post_init__.
  fc1_weights: np.ndarray = field(repr=False, default=None)
  fc1_bias: np.ndarray = field(repr=False, default=None)
  fc2_weights: np.ndarray = field(repr=False, default=None)
  fc2_bias: np.ndarray = field(repr=False, default=None)
  pred_weights: np.ndarray = field(repr=False, default=None)
  pred_bias: np.ndarray = field(repr=False, default=None)

  def __post_init__(self):
    """Extracts the per-layer arrays from `weights_dict`.

    Raises:
      ValueError: If `weights_dict` was not provided.
    """
    if self.weights_dict is None:
      raise ValueError(
          "ZHeadCFG needs a weights_dict with 'z_fc1', 'z_fc2' and 'z_pred' "
          "entries.")

    self.fc1_weights = self.weights_dict['z_fc1']['weight']
    self.fc1_bias = self.weights_dict['z_fc1']['bias']
    self.fc2_weights = self.weights_dict['z_fc2']['weight']
    self.fc2_bias = self.weights_dict['z_fc2']['bias']
    self.pred_weights = self.weights_dict['z_pred']['weight']
    self.pred_bias = self.weights_dict['z_pred']['bias']

    # Order matches the head's variable creation order: fc1, fc2, z_pred.
    self.weights = [
        self.fc1_weights,
        self.fc1_bias,
        self.fc2_weights,
        self.fc2_bias,
        self.pred_weights,
        self.pred_bias,
    ]

  def get_weights(self):
    """Returns the flat list of Z head weight arrays."""
    return self.weights

@dataclass
class meshRefinementStageCFG(Config):
weights_dict: Dict = field(repr=False, default=None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@
from typing import Dict

from official.vision.beta.projects.mesh_rcnn.utils.weight_utils.config_classes import \
meshRefinementStageCFG
meshRefinementStageCFG, ZHeadCFG

@dataclass
class ZHeadConfigData:
  """Maps a z head configuration name to its weight config."""
  weights_dict: Dict = field(repr=False, default=None)

  def get_cfg_list(self, name):
    """Returns the weight config for the named z head configuration.

    Args:
      name: Configuration name; only "pix3d" is recognized.

    Returns:
      A ZHeadCFG built from `weights_dict` for "pix3d", otherwise an empty
      list.
    """
    if name != "pix3d":
      return []
    return ZHeadCFG(weights_dict=self.weights_dict)

@dataclass
class MeshHeadConfigData():
Expand All @@ -17,6 +26,5 @@ def get_cfg_list(self, name):
meshRefinementStageCFG(weights_dict=self.weights_dict['stages']['1']),
meshRefinementStageCFG(weights_dict=self.weights_dict['stages']['2']),
]

else:
return []
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
import tensorflow as tf
from torch import load

from official.vision.beta.projects.mesh_rcnn.modeling.heads.z_head import \
ZHead
from official.vision.beta.projects.mesh_rcnn.modeling.layers.nn_blocks import \
MeshRefinementStage
from official.vision.beta.projects.mesh_rcnn.utils.weight_utils.config_data import \
MeshHeadConfigData
MeshHeadConfigData, ZHeadConfigData


def pth_to_dict(pth_path):
Expand Down Expand Up @@ -52,6 +54,39 @@ def pth_to_dict(pth_path):
return weights_dict, n_read


def get_zhead_layer_cfgs(weights_dict, zhead_name):
  """Fetches the weight config for the z head.

  Args:
    weights_dict: Dictionary that stores the z head weights.
    zhead_name: String, indicating the desired z head configuration.

  Returns:
    Whatever `ZHeadConfigData.get_cfg_list` returns for `zhead_name`: a
    single ZHeadCFG for a supported configuration ("pix3d"), otherwise an
    empty list.
  """
  print("Fetching z head config classes for {}\n".format(zhead_name))
  return ZHeadConfigData(weights_dict).get_cfg_list(zhead_name)


def load_weights_zhead(zhead, weights_dict, zhead_name):
  """Loads the weights defined in the weights_dict into the z head.

  Fetches the weight config for `zhead_name` and uses it to assign the
  pretrained weights to `zhead`.

  Args:
    zhead: The ZHead layer to load weights into. NOTE(review): it presumably
      must already be built (called once) so its sublayers own variables.
    weights_dict: Dictionary that stores the z head model weights.
    zhead_name: String, indicating the desired z head configuration.

  Returns:
    Number of weights loaded in, as returned by the config's load_weights.

  Raises:
    ValueError: If `zhead_name` is not a supported z head configuration.
  """
  print("Loading z head weights\n")
  cfgs = get_zhead_layer_cfgs(weights_dict, zhead_name)
  # Unsupported names come back as an empty list rather than a config object;
  # fail with a clear message instead of an AttributeError on `list`.
  if isinstance(cfgs, list) and not cfgs:
    raise ValueError(
        "Unsupported z head configuration: {}".format(zhead_name))
  n_weights = cfgs.load_weights(zhead)
  print("{} Weights have been loaded for the z head\n".format(n_weights))
  return n_weights


def get_mesh_head_layer_cfgs(weights_dict, mesh_head_name):
""" Fetches the config classes for the mesh head.
This function generates a list of config classes corresponding to
Expand All @@ -62,7 +97,6 @@ def get_mesh_head_layer_cfgs(weights_dict, mesh_head_name):
Returns:
A list containing the config classes of the mesh head building block.
"""

print("Fetching mesh head config classes for {}\n".format(mesh_head_name))
cfgs = MeshHeadConfigData(weights_dict).get_cfg_list(mesh_head_name)
return cfgs
Expand Down Expand Up @@ -97,3 +131,4 @@ def load_weights_mesh_head(mesh_head, weights_dict, mesh_head_name):
print("{} Weights have been loaded for {} / {} layers\n".format(
n_weights_total, loaded_layers, i))
return n_weights_total

Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

from official.vision.beta.projects.mesh_rcnn.modeling.heads.mesh_head import \
MeshHead
from official.vision.beta.projects.mesh_rcnn.modeling.heads.z_head import \
ZHead
from official.vision.beta.projects.mesh_rcnn.ops.cubify import cubify
from official.vision.beta.projects.mesh_rcnn.ops.mesh_ops import \
compute_mesh_shape
from official.vision.beta.projects.mesh_rcnn.ops.visualize_mesh import \
visualize_mesh
from official.vision.beta.projects.mesh_rcnn.utils.weight_utils.load_weights import (
load_weights_mesh_head, pth_to_dict)
load_weights_mesh_head, load_weights_zhead, pth_to_dict)

PTH_PATH = r"C:\ML\Weights\meshrcnn_R50.pth"
BACKBONE_FEATURES = [
Expand All @@ -34,6 +37,26 @@ def print_layer_names(layers_dict, offset=0):
print(" " * offset + k)
print_layer_names(layers_dict[k], offset+2)

def test_load_zhead():
  """Loads the pix3d Z head weights from PTH_PATH into a new ZHead."""
  weights_dict, _ = pth_to_dict(PTH_PATH)
  zhead_weights = weights_dict['roi_heads']['z_head']

  input_specs = dict(
      num_fc=2,
      fc_dim=1024,
      cls_agnostic=False,
      num_classes=9,
  )
  zhead = ZHead.from_config(input_specs)

  # ZHead creates its Dense sublayers in build() and their variables on the
  # first call, so run a dummy forward pass before assigning weights.
  # NOTE(review): (14, 14, 256) mirrors the unit test's feature shape; confirm
  # it matches the input size expected by the checkpoint's z_fc1.
  zhead(tf.zeros((1, 14, 14, 256)))

  load_weights_zhead(zhead, zhead_weights, 'pix3d')

def test_load_mesh_refinement_branch():
weights_dict, n_read = pth_to_dict(PTH_PATH)

Expand Down Expand Up @@ -107,3 +130,4 @@ def test_load_mesh_refinement_branch():

if __name__ == "__main__":
  # Run the mesh refinement check first, then the z head check; a failure in
  # one stops the rest.
  for run_check in (test_load_mesh_refinement_branch, test_load_zhead):
    run_check()