From c0bf0eb27953ef8e194f36f92039d3c51febbd58 Mon Sep 17 00:00:00 2001 From: Nidarshan Siddegowda Date: Thu, 1 Dec 2022 21:53:36 -0500 Subject: [PATCH 1/3] init commit to create_optimizer --- .../projects/mesh_rcnn/tasks/mesh_rcnn.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py b/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py index 019bf252f50..8e7a89581c8 100644 --- a/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py +++ b/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py @@ -4,6 +4,9 @@ from official.core import task_factory from official.vision.beta.projects.mesh_rcnn.configs import mesh_rcnn as exp_cfg from official.vision.beta.projects.mesh_rcnn.modeling import factory +from official.modeling.optimization.optimizer_factory import OptimizerFactory +from official.modeling.optimization import ema_optimizer +from official.modeling import performance import tensorflow as tf @@ -106,7 +109,37 @@ def create_optimizer(self, Returns: A tf.optimizers.Optimizer object. 
""" - return + opt_factory = optimizer_factory.OptimizerFactory(optimizer_config) + # pylint: disable=protected-access + ema = opt_factory._use_ema + opt_factory._use_ema = False + + opt_type = opt_factory._optimizer_type + ''' + if opt_type == 'sgd_torch': + optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) + optimizer.set_bias_lr( + opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr)) + optimizer.search_and_set_variable_groups(self._model.trainable_variables) + ''' + optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) + opt_factory._use_ema = ema + + if ema: + logging.info('EMA is enabled.') + optimizer = ema_optimizer.ExponentialMovingAverage(optimizer, **self._ema_config.as_dict()) + #optimizer = opt_factory.add_ema(optimizer) + + # pylint: enable=protected-access + + if runtime_config and runtime_config.loss_scale: + use_float16 = runtime_config.mixed_precision_dtype == 'float16' + optimizer = performance.configure_optimizer( + optimizer, + use_float16=use_float16, + loss_scale=runtime_config.loss_scale) + + return optimizer From 64e170c2b0fa634e749e3aedeab01e98b0fb2a77 Mon Sep 17 00:00:00 2001 From: Nidarshan Siddegowda Date: Sat, 10 Dec 2022 14:48:42 -0500 Subject: [PATCH 2/3] added adam optimizer --- .../projects/mesh_rcnn/configs/mesh_rcnn.py | 49 +++ .../mesh_rcnn/optimization/__init__.py | 22 ++ .../optimization/configs/__init__.py | 14 + .../configs/optimization_config.py | 56 ++++ .../optimization/configs/optimizer_config.py | 63 ++++ .../optimization/optimizer_factory.py | 99 ++++++ .../mesh_rcnn/optimization/sgd_torch.py | 313 ++++++++++++++++++ .../projects/mesh_rcnn/tasks/mesh_rcnn.py | 25 +- 8 files changed, 627 insertions(+), 14 deletions(-) create mode 100755 official/vision/beta/projects/mesh_rcnn/optimization/__init__.py create mode 100755 official/vision/beta/projects/mesh_rcnn/optimization/configs/__init__.py create mode 100755 
official/vision/beta/projects/mesh_rcnn/optimization/configs/optimization_config.py create mode 100755 official/vision/beta/projects/mesh_rcnn/optimization/configs/optimizer_config.py create mode 100755 official/vision/beta/projects/mesh_rcnn/optimization/optimizer_factory.py create mode 100644 official/vision/beta/projects/mesh_rcnn/optimization/sgd_torch.py diff --git a/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py b/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py index 9ae5f0973dc..40b003d890c 100644 --- a/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py +++ b/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py @@ -16,6 +16,12 @@ import dataclasses from official.modeling import hyperparams # type: ignore +from official.core import config_definitions as cfg +from official.core import exp_factory +from official.modeling import hyperparams +from official.vision.beta.projects.mesh_rcnn import optimization +from official.vision.beta.projects.mesh_rcnn.tasks import mesh_rcnn +from official.vision.configs import common @dataclasses.dataclass class ZHead(hyperparams.Config): @@ -54,3 +60,46 @@ class MeshLosses(hyperparams.Config): edge_weight: float = 0.1 true_num_samples: int = 5000 pred_num_samples: int = 5000 + +@exp_factory.register_config_factory('mesh_training') +def mesh_training() -> cfg.ExperimentConfig: + """COCO object detection with YOLOv3 and v4.""" + train_batch_size = 256 + eval_batch_size = 8 + train_epochs = 300 + steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size + validation_interval = 5 + + max_num_instances = 200 + config = cfg.ExperimentConfig( + trainer=cfg.TrainerConfig( + train_steps=train_epochs * steps_per_epoch, + validation_steps=COCO_VAL_EXAMPLES // eval_batch_size, + validation_interval=validation_interval * steps_per_epoch, + steps_per_loop=steps_per_epoch, + summary_interval=steps_per_epoch, + checkpoint_interval=steps_per_epoch, + optimizer_config=optimization.OptimizationConfig({ 
+ 'ema': { + 'average_decay': 0.9998, + 'trainable_weights_only': False, + 'dynamic_decay': True, + }, + 'optimizer': { + 'type': 'adam', + 'sgd_torch': { + 'learning_rate' : 0.001, + 'beta_1' : 0.9, + 'beta_2' : 0.999, + 'epsilon' : 1e-07 + } + }, + 'learning_rate': {}, + 'warmup': {} + })), + restrictions=[ + 'task.train_data.is_training != None', + 'task.validation_data.is_training != None' + ]) + + return config \ No newline at end of file diff --git a/official/vision/beta/projects/mesh_rcnn/optimization/__init__.py b/official/vision/beta/projects/mesh_rcnn/optimization/__init__.py new file mode 100755 index 00000000000..635503a495c --- /dev/null +++ b/official/vision/beta/projects/mesh_rcnn/optimization/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Optimization package definition.""" + +# pylint: disable=wildcard-import +from official.modeling.optimization.configs.learning_rate_config import * +from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage +from official.vision.beta.projects.mesh_rcnn.optimization.configs.optimization_config import * +from official.vision.beta.projects.mesh_rcnn.optimization.configs.optimizer_config import * +from official.vision.beta.projects.mesh_rcnn.optimization.optimizer_factory import OptimizerFactory as MeshOptimizerFactory diff --git a/official/vision/beta/projects/mesh_rcnn/optimization/configs/__init__.py b/official/vision/beta/projects/mesh_rcnn/optimization/configs/__init__.py new file mode 100755 index 00000000000..310bfb28f0c --- /dev/null +++ b/official/vision/beta/projects/mesh_rcnn/optimization/configs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/official/vision/beta/projects/mesh_rcnn/optimization/configs/optimization_config.py b/official/vision/beta/projects/mesh_rcnn/optimization/configs/optimization_config.py new file mode 100755 index 00000000000..597a389e8ac --- /dev/null +++ b/official/vision/beta/projects/mesh_rcnn/optimization/configs/optimization_config.py @@ -0,0 +1,56 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataclasses for optimization configs. + +This file defines the dataclass for optimization configs (OptimizationConfig). +It also has two helper functions get_optimizer_config, and get_lr_config from +an OptimizationConfig class. +""" +import dataclasses +from typing import Optional + +from official.modeling.optimization.configs import optimization_config as optimization_cfg +from official.vision.beta.projects.mesh_rcnn.optimization.configs import optimizer_config as opt_cfg + + +@dataclasses.dataclass +class OptimizerConfig(optimization_cfg.OptimizerConfig): + """Configuration for optimizer. + + Attributes: + type: 'str', type of optimizer to be used, one of the fields below. + sgd: sgd optimizer config. + adam: adam optimizer config. + adamw: adam with weight decay. + lamb: lamb optimizer. + rmsprop: rmsprop optimizer. + """ + type: Optional[str] = None + sgd_torch: opt_cfg.SGDTorchConfig = opt_cfg.SGDTorchConfig() + + +@dataclasses.dataclass +class OptimizationConfig(optimization_cfg.OptimizationConfig): + """Configuration for optimizer and learning rate schedule. + + Attributes: + optimizer: optimizer oneof config. + ema: optional exponential moving average optimizer config, if specified, ema + optimizer will be used. + learning_rate: learning rate oneof config. + warmup: warmup oneof config.
+ """ + type: Optional[str] = None + optimizer: OptimizerConfig = OptimizerConfig() diff --git a/official/vision/beta/projects/mesh_rcnn/optimization/configs/optimizer_config.py b/official/vision/beta/projects/mesh_rcnn/optimization/configs/optimizer_config.py new file mode 100755 index 00000000000..46c9609649c --- /dev/null +++ b/official/vision/beta/projects/mesh_rcnn/optimization/configs/optimizer_config.py @@ -0,0 +1,63 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataclasses for optimizer configs.""" +import dataclasses +from typing import List, Optional + +from official.modeling.hyperparams import base_config +from official.modeling.optimization.configs import optimizer_config + + +@dataclasses.dataclass +class BaseOptimizerConfig(base_config.Config): + """Base optimizer config. + + Attributes: + clipnorm: float >= 0 or None. If not None, Gradients will be clipped when + their L2 norm exceeds this value. + clipvalue: float >= 0 or None. If not None, Gradients will be clipped when + their absolute value exceeds this value. + global_clipnorm: float >= 0 or None. 
If not None, gradient of all weights is + clipped so that their global norm is no higher than this value + """ + clipnorm: Optional[float] = None + clipvalue: Optional[float] = None + global_clipnorm: Optional[float] = None + + +@dataclasses.dataclass +class SGDTorchConfig(optimizer_config.BaseOptimizerConfig): + """Configuration for SGD optimizer. + + The attributes for this class matches the arguments of tf.keras.optimizer.SGD. + + Attributes: + name: name of the optimizer. + decay: decay rate for SGD optimizer. + nesterov: nesterov for SGD optimizer. + momentum_start: momentum starting point for SGD optimizer. + momentum: momentum for SGD optimizer. + """ + name: str = "SGD" + decay: float = 0.0 + nesterov: bool = False + momentum_start: float = 0.0 + momentum: float = 0.9 + warmup_steps: int = 0 + weight_decay: float = 0.0 + weight_keys: Optional[List[str]] = dataclasses.field( + default_factory=lambda: ["kernel", "weight"]) + bias_keys: Optional[List[str]] = dataclasses.field( + default_factory=lambda: ["bias", "beta"]) diff --git a/official/vision/beta/projects/mesh_rcnn/optimization/optimizer_factory.py b/official/vision/beta/projects/mesh_rcnn/optimization/optimizer_factory.py new file mode 100755 index 00000000000..4780aaf4cda --- /dev/null +++ b/official/vision/beta/projects/mesh_rcnn/optimization/optimizer_factory.py @@ -0,0 +1,99 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Optimizer factory class.""" + +import gin + +from official.modeling.optimization import ema_optimizer +from official.modeling.optimization import optimizer_factory +from official.vision.beta.projects.mesh_rcnn.optimization import sgd_torch + +optimizer_factory.OPTIMIZERS_CLS.update({ + 'sgd_torch': sgd_torch.SGDTorch, +}) + +OPTIMIZERS_CLS = optimizer_factory.OPTIMIZERS_CLS +LR_CLS = optimizer_factory.LR_CLS +WARMUP_CLS = optimizer_factory.WARMUP_CLS + + +class OptimizerFactory(optimizer_factory.OptimizerFactory): + """Optimizer factory class. + + This class builds learning rate and optimizer based on an optimization config. + To use this class, you need to do the following: + (1) Define optimization config, this includes optimizer, and learning rate + schedule. + (2) Initialize the class using the optimization config. + (3) Build learning rate. + (4) Build optimizer. + + This is a typical example for using this class: + params = { + 'optimizer': { + 'type': 'sgd', + 'sgd': {'momentum': 0.9} + }, + 'learning_rate': { + 'type': 'stepwise', + 'stepwise': {'boundaries': [10000, 20000], + 'values': [0.1, 0.01, 0.001]} + }, + 'warmup': { + 'type': 'linear', + 'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01} + } + } + opt_config = OptimizationConfig(params) + opt_factory = OptimizerFactory(opt_config) + lr = opt_factory.build_learning_rate() + optimizer = opt_factory.build_optimizer(lr) + """ + + def get_bias_lr_schedule(self, bias_lr): + """Build learning rate. + + Builds learning rate from config. Learning rate schedule is built according + to the learning rate config. If learning rate type is constant, + lr_config.learning_rate is returned. + + Args: + bias_lr: learning rate config. + + Returns: + tf.keras.optimizers.schedules.LearningRateSchedule instance. If + learning rate type is constant, lr_config.learning_rate is returned.
+ """ + if self._lr_type == 'constant': + lr = self._lr_config.learning_rate + else: + lr = LR_CLS[self._lr_type](**self._lr_config.as_dict()) + + if self._warmup_config: + if self._warmup_type != 'linear': + raise ValueError('Smart Bias is only supported currently with a' + 'linear warm up.') + warm_up_cfg = self._warmup_config.as_dict() + warm_up_cfg['warmup_learning_rate'] = bias_lr + lr = WARMUP_CLS['linear'](lr, **warm_up_cfg) + return lr + + @gin.configurable + def add_ema(self, optimizer): + """Add EMA to the optimizer independently of the build optimizer method.""" + if self._use_ema: + optimizer = ema_optimizer.ExponentialMovingAverage( + optimizer, **self._ema_config.as_dict()) + return optimizer diff --git a/official/vision/beta/projects/mesh_rcnn/optimization/sgd_torch.py b/official/vision/beta/projects/mesh_rcnn/optimization/sgd_torch.py new file mode 100644 index 00000000000..5f372a2c5b6 --- /dev/null +++ b/official/vision/beta/projects/mesh_rcnn/optimization/sgd_torch.py @@ -0,0 +1,313 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SGD PyTorch optimizer.""" +import re + +from absl import logging +import tensorflow as tf + +LearningRateSchedule = tf.keras.optimizers.schedules.LearningRateSchedule + + +def _var_key(var): + """Key for representing a primary variable, for looking up slots. + + In graph mode the name is derived from the var shared name. 
+ In eager mode the name is derived from the var unique id. + If distribution strategy exists, get the primary variable first. + Args: + var: the variable. + + Returns: + the unique name of the variable. + """ + + # pylint: disable=protected-access + # Get the distributed variable if it exists. + if hasattr(var, "_distributed_container"): + var = var._distributed_container() + if var._in_graph_mode: + return var._shared_name + return var._unique_id + + +class SGDTorch(tf.keras.optimizers.Optimizer): + """Optimizer that simulates the SGD module used in pytorch. + + + For details on the differences between the original SGD implementation and the + one in pytorch: + https://pytorch.org/docs/stable/generated/torch.optim.SGD.html. + This optimizer also allows for the usage of a momentum warmup alongside a + learning rate warm up, though using this is not required. + + Example of usage for training: + ```python + opt = SGDTorch(learning_rate, weight_decay = 0.0001) + l2_regularization = None + + # iterate all model.trainable_variables and split the variables by key + # into the weights, biases, and others. + optimizer.search_and_set_variable_groups(model.trainable_variables) + + # if the learning rate schedule on the biases are different. if lr is not set + # the default schedule used for weights will be used on the biases.
+ opt.set_other_lr() + ``` + """ + + _HAS_AGGREGATE_GRAD = True + + def __init__(self, + weight_decay=0.0, + learning_rate=0.01, + momentum=0.0, + momentum_start=0.0, + warmup_steps=1000, + nesterov=False, + name="SGD", + weight_keys=("kernel", "weight"), + bias_keys=("bias", "beta"), + **kwargs): + super(SGDTorch, self).__init__(name, **kwargs) + + # Create Hyper Params for each group of the LR + self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) + self._set_hyper("bias_learning_rate", kwargs.get("lr", learning_rate)) + self._set_hyper("other_learning_rate", kwargs.get("lr", learning_rate)) + + # SGD decay param + self._set_hyper("decay", self._initial_decay) + + # Weight decay param + self._weight_decay = weight_decay != 0.0 + self._set_hyper("weight_decay", weight_decay) + + # Enable Momentum + self._momentum = False + if isinstance(momentum, tf.Tensor) or callable(momentum) or momentum > 0: + self._momentum = True + if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1): + raise ValueError("`momentum` must be between [0, 1].") + self._set_hyper("momentum", momentum) + self._set_hyper("momentum_start", momentum_start) + self._set_hyper("warmup_steps", tf.cast(warmup_steps, tf.int32)) + + # Enable Nesterov Momentum + self.nesterov = nesterov + + # weights, biases, other + self._weight_keys = weight_keys + self._bias_keys = bias_keys + self._variables_set = False + self._wset = set() + self._bset = set() + self._oset = set() + + logging.info("Pytorch SGD simulation: ") + logging.info("Weight Decay: %f", weight_decay) + + def set_bias_lr(self, lr): + self._set_hyper("bias_learning_rate", lr) + + def set_other_lr(self, lr): + self._set_hyper("other_learning_rate", lr) + + def _search(self, var, keys): + """Search all all keys for matches. Return True on match.""" + if keys is not None: + # variable group is not ignored so search for the keys. 
+ for r in keys: + if re.search(r, var.name) is not None: + return True + return False + + def search_and_set_variable_groups(self, variables): + """Search all variables for matches at each group.""" + weights = [] + biases = [] + others = [] + + for var in variables: + + if self._search(var, self._weight_keys): + # search for weights + weights.append(var) + elif self._search(var, self._bias_keys): + # search for biases + biases.append(var) + else: + # if all searches fail, add to other group + others.append(var) + + self._set_variable_groups(weights, biases, others) + return weights, biases, others + + def _set_variable_groups(self, weights, biases, others): + """Sets the variables to be used in each group.""" + + if self._variables_set: + logging.warning("_set_variable_groups has been called again indicating" + "that the variable groups have already been set, they" + "will be updated.") + self._wset.update(set([_var_key(w) for w in weights])) + self._bset.update(set([_var_key(b) for b in biases])) + self._oset.update(set([_var_key(o) for o in others])) + self._variables_set = True + return + + def _get_variable_group(self, var, coefficients): + if self._variables_set: + # check which groups hold which variables, preset. + if _var_key(var) in self._wset: + return True, False, False + elif _var_key(var) in self._bset: + return False, True, False + else: + # search the variables at run time. + if self._search(var, self._weight_keys): + return True, False, False + elif self._search(var, self._bias_keys): + return False, True, False + return False, False, True + + def _create_slots(self, var_list): + """Create a momentum variable for each variable.""" + if self._momentum: + for var in var_list: + # check if trainable to support GPU EMA.
+ if var.trainable: + self.add_slot(var, "momentum") + + def _get_momentum(self, iteration): + """Get the momentum value.""" + momentum = self._get_hyper("momentum") + momentum_start = self._get_hyper("momentum_start") + momentum_warm_up_steps = tf.cast( + self._get_hyper("warmup_steps"), iteration.dtype) + value = tf.cond( + (iteration - momentum_warm_up_steps) <= 0, + true_fn=lambda: (momentum_start + # pylint: disable=g-long-lambda + (tf.cast(iteration, momentum.dtype) * + (momentum - momentum_start) / tf.cast( + momentum_warm_up_steps, momentum.dtype))), + false_fn=lambda: momentum) + return value + + def _prepare_local(self, var_device, var_dtype, apply_state): + super(SGDTorch, self)._prepare_local(var_device, var_dtype, apply_state) # pytype: disable=attribute-error + weight_decay = self._get_hyper("weight_decay") + apply_state[(var_device, + var_dtype)]["weight_decay"] = tf.cast(weight_decay, var_dtype) + + if self._momentum: + momentum = self._get_momentum(self.iterations) + momentum = tf.cast(momentum, var_dtype) + apply_state[(var_device, + var_dtype)]["momentum"] = tf.identity(momentum) + + bias_lr = self._get_hyper("bias_learning_rate") + if isinstance(bias_lr, LearningRateSchedule): + bias_lr = bias_lr(self.iterations) + bias_lr = tf.cast(bias_lr, var_dtype) + apply_state[(var_device, + var_dtype)]["bias_lr_t"] = tf.identity(bias_lr) + + other_lr = self._get_hyper("other_learning_rate") + if isinstance(other_lr, LearningRateSchedule): + other_lr = other_lr(self.iterations) + other_lr = tf.cast(other_lr, var_dtype) + apply_state[(var_device, + var_dtype)]["other_lr_t"] = tf.identity(other_lr) + + return apply_state[(var_device, var_dtype)] + + def _apply(self, grad, var, weight_decay, momentum, lr): + """Uses Pytorch Optimizer with Weight decay SGDW.""" + dparams = grad + groups = [] + + # do not update non-trainable weights + if not var.trainable: + return tf.group(*groups) + + if self._weight_decay: + dparams += (weight_decay * var) + + if 
self._momentum: + momentum_var = self.get_slot(var, "momentum") + momentum_update = momentum_var.assign( + momentum * momentum_var + dparams, use_locking=self._use_locking) + groups.append(momentum_update) + + if self.nesterov: + dparams += (momentum * momentum_update) + else: + dparams = momentum_update + + weight_update = var.assign_add(-lr * dparams, use_locking=self._use_locking) + groups.append(weight_update) + return tf.group(*groups) + + def _run_sgd(self, grad, var, apply_state=None): + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) or + self._fallback_apply_state(var_device, var_dtype)) + + weights, bias, others = self._get_variable_group(var, coefficients) + weight_decay = tf.zeros_like(coefficients["weight_decay"]) + lr = coefficients["lr_t"] + if weights: + weight_decay = coefficients["weight_decay"] + lr = coefficients["lr_t"] + elif bias: + weight_decay = tf.zeros_like(coefficients["weight_decay"]) + lr = coefficients["bias_lr_t"] + elif others: + weight_decay = tf.zeros_like(coefficients["weight_decay"]) + lr = coefficients["other_lr_t"] + momentum = coefficients["momentum"] + + return self._apply(grad, var, weight_decay, momentum, lr) + + def _resource_apply_dense(self, grad, var, apply_state=None): + return self._run_sgd(grad, var, apply_state=apply_state) + + def _resource_apply_sparse(self, grad, var, indices, apply_state=None): + # This method is only needed for momentum optimization. 
+ holder = tf.tensor_scatter_nd_add( + tf.zeros_like(var), tf.expand_dims(indices, axis=-1), grad) + return self._run_sgd(holder, var, apply_state=apply_state) + + def get_config(self): + config = super(SGDTorch, self).get_config() + config.update({ + "learning_rate": self._serialize_hyperparameter("learning_rate"), + "decay": self._initial_decay, + "momentum": self._serialize_hyperparameter("momentum"), + "momentum_start": self._serialize_hyperparameter("momentum_start"), + "weight_decay": self._serialize_hyperparameter("weight_decay"), + "warmup_steps": self._serialize_hyperparameter("warmup_steps"), + "nesterov": self.nesterov, + }) + return config + + @property + def learning_rate(self): + return self._optimizer._get_hyper("learning_rate") # pylint: disable=protected-access diff --git a/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py b/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py index 8e7a89581c8..90bcb2576a8 100644 --- a/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py +++ b/official/vision/beta/projects/mesh_rcnn/tasks/mesh_rcnn.py @@ -4,7 +4,7 @@ from official.core import task_factory from official.vision.beta.projects.mesh_rcnn.configs import mesh_rcnn as exp_cfg from official.vision.beta.projects.mesh_rcnn.modeling import factory -from official.modeling.optimization.optimizer_factory import OptimizerFactory +from official.vision.beta.projects.mesh_rcnn import optimization from official.modeling.optimization import ema_optimizer from official.modeling import performance @@ -12,16 +12,15 @@ @task_factory.register_task_cls(exp_cfg.MeshRCNNTask) class MeshRCNNTask(base_task.Task): - """A single-replica view of training procedure. + """A single-replica view of training procedure. - MeshRCNN task provides artifacts for training/evalution procedures, including - loading/iterating over Datasets, initializing the model, calculating the loss, - post-processing, and customized metrics with reduction. 
- """ + MeshRCNN task provides artifacts for training/evalution procedures, including + loading/iterating over Datasets, initializing the model, calculating the loss, + post-processing, and customized metrics with reduction. + """ def __init__(self, params, logging_dir: Optional[str] = None): super().__init__(params, logging_dir) return - def build_model(self): """Build Mesh R-CNN model.""" @@ -47,7 +46,7 @@ def build_model(self): def build_inputs(self, params, input_context=None): """Build input dataset.""" return - + def build_metrics(self, training=True): """Build metrics.""" return @@ -88,7 +87,7 @@ def validation_step(self, inputs, model, metrics=None): A dictionary of logs. """ return - + def aggregate_logs(self, state=None, step_outputs=None): """Get Metric Results.""" return @@ -98,8 +97,8 @@ def reduce_aggregated_logs(self, aggregated_logs, global_step=None): return def create_optimizer(self, - optimizer_config: OptimizationConfig, - runtime_config: Optional[RuntimeConfig] = None): + optimizer_config: OptimizationConfig, + runtime_config: Optional[RuntimeConfig] = None): """Creates an TF optimizer from configurations. Args: @@ -109,19 +108,17 @@ def create_optimizer(self, Returns: A tf.optimizers.Optimizer object. 
""" - opt_factory = optimizer_factory.OptimizerFactory(optimizer_config) + opt_factory = optimization.MeshOptimizerFactory(optimizer_config) # pylint: disable=protected-access ema = opt_factory._use_ema opt_factory._use_ema = False opt_type = opt_factory._optimizer_type - ''' if opt_type == 'sgd_torch': optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) optimizer.set_bias_lr( opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr)) optimizer.search_and_set_variable_groups(self._model.trainable_variables) - ''' optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate()) opt_factory._use_ema = ema From b9022cd5d71686a56627860db6de5aa1b9a0c0e8 Mon Sep 17 00:00:00 2001 From: Nidarshan Siddegowda <34317175+nidarshans@users.noreply.github.com> Date: Sat, 10 Dec 2022 17:38:55 -0500 Subject: [PATCH 3/3] Update mesh_rcnn.py --- official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py b/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py index 40b003d890c..be2d13c763c 100644 --- a/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py +++ b/official/vision/beta/projects/mesh_rcnn/configs/mesh_rcnn.py @@ -87,7 +87,7 @@ def mesh_training() -> cfg.ExperimentConfig: }, 'optimizer': { 'type': 'adam', - 'sgd_torch': { + 'adam': { 'learning_rate' : 0.001, 'beta_1' : 0.9, 'beta_2' : 0.999, @@ -102,4 +102,4 @@ def mesh_training() -> cfg.ExperimentConfig: 'task.validation_data.is_training != None' ]) - return config \ No newline at end of file + return config