
RFC: lingvo.jax exception flag mechanism #333

@drpngx


It's nice to be able to have runtime asserts, but we can't have them in a pure (jit-compiled) function.

On x86, when the processor encounters a floating-point error, it sets the corresponding floating-point exception flag and keeps going. The flag stays set, so you can check it at the end of the computation.

I prototyped a similar mechanism in my code. An error flag is raised during the computation. At the end of the computation, we check whether the error flag has been raised.
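
A minimal sketch of the idea in plain JAX (no lingvo), assuming a hypothetical `step` function with two made-up checks. The flag is a boolean that each check ORs into. It is returned next to the result, and the host looks at it only after the compiled call has finished.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
  # Sticky flag: starts False and can only be set, never cleared.
  errflag = jnp.array(False)
  y = jnp.log(x)
  # Check 1: log is only defined for positive inputs.
  errflag = errflag | ~jnp.all(x > 0)
  z = y / jnp.sum(y)
  # Check 2: the normalizer must be non-zero.
  errflag = errflag | (jnp.sum(y) == 0)
  # The computation keeps going; the flag is returned, not raised.
  return z, errflag


z, errflag = step(jnp.array([1.0, 2.0, -3.0]))
# Back in Python land, outside the compiled function.
if bool(errflag):
  print('Exception flag raised')
```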

Unlike asserts, this method:

  • Can be enabled in prod on the same graph. You could place the checks on a device of your choosing.
  • Does not know which error came first: if there is a cascade of errors, you cannot tell which one triggered the others. I don't think you can know unless you annotate the graph or provide a JAX op that timestamps execution.
  • Cannot interrupt errors that cause infinite loops, since the flag is only checked after the computation finishes.

The second issue is the annoying one. My assumption is that errors should generally happen in the order in which their checks are constructed.
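
Under that assumption, the host can treat the lowest-index flag that is set as the likely root cause, because flags are appended in construction order. A small numpy helper for this might look like the following (the name `first_raised` is hypothetical, and construction order is only a proxy for execution order):

```python
import numpy as np


def first_raised(errflag):
  """Returns the index of the earliest-constructed check that fired, or None."""
  raised = np.flatnonzero(np.asarray(errflag))
  return int(raised[0]) if raised.size else None


# Flags are listed in the order the checks were constructed.
print(first_raised([False, True, True]))  # 1
print(first_raised([False, False]))       # None
```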

I've implemented it in my model.py but it should go in layers.

model.py:
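```python
import traceback
from typing import Any, List

import jax.numpy as jnp
from lingvo.jax import base_model

JTensor = jnp.ndarray  # Stand-in for lingvo.jax's JTensor type alias.


def Check(ok_cond: JTensor, errflag: JTensor, errinfo: List[Any]) -> JTensor:
  """Appends one flag to errflag; True means the check failed."""
  # ok_cond must be a scalar; shapes are static, so this runs at trace time.
  assert jnp.shape(ok_cond) == (), 'ok_cond must be a scalar'
  # Also runs at trace time: records the Python stack of the check site so
  # the host can report where flag k was constructed.
  errinfo.append(traceback.extract_stack())
  err = jnp.logical_not(ok_cond)
  return jnp.concatenate([errflag, err[jnp.newaxis]])


class Model(base_model.BaseModel):

  def __init__(...):
    self.errinfo = []  # Python-side info, one entry per check.
    self.errflag = jnp.zeros(shape=[0], dtype=bool)  # Grows by one per check.

  def compute_loss(...):
    # ok_cond from the prototype: True when some element of z is <= 0
    # (use jnp.all(z > 0) for an all-positive invariant).
    invariant = jnp.prod((z > 0).astype(int)) == 0
    self.errflag = Check(invariant, self.errflag, self.errinfo)
```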
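
Storing a traced value on `self` works while everything happens inside a single trace, but the stale tracer can leak into the next trace if the step is retraced. A sketch of a variant that threads the flags through return values instead, assuming a hypothetical `loss_fn` and a lowercase `check` that mirrors `Check`. `ERRINFO` is still a Python list filled at trace time, so it grows on every retrace.

```python
import traceback
from typing import Any, List

import jax
import jax.numpy as jnp


def check(ok_cond, errflag, errinfo: List[Any]):
  """Appends one flag (True means the check failed) and records the check site."""
  assert jnp.shape(ok_cond) == (), 'ok_cond must be a scalar'
  # Runs at trace time, so this captures the Python stack of the check site.
  errinfo.append(traceback.extract_stack())
  err = jnp.logical_not(ok_cond)
  return jnp.concatenate([errflag, err[jnp.newaxis]])


ERRINFO: List[Any] = []


@jax.jit
def loss_fn(z):
  # The flag vector is created inside the trace and returned, never stored.
  errflag = jnp.zeros(shape=[0], dtype=bool)
  errflag = check(jnp.all(z > 0), errflag, ERRINFO)
  loss = jnp.sum(jnp.log(z))
  errflag = check(jnp.isfinite(loss), errflag, ERRINFO)
  return loss, errflag


loss, errflag = loss_fn(jnp.array([1.0, -2.0]))
for k, flag in enumerate(errflag):
  if flag:
    print('ERR[%d] at\n%s' % (k, ''.join(ERRINFO[k].format())))
```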

Then, in lingvo/jax/train.py:

```python
import jax
import numpy as np
from absl import logging


def SafePmap(errsrc, step, **pmap_kwargs):
  """Wraps jax.pmap(step) and raises on the host if any error flag is set.

  errsrc is jax_task.model, which holds errflag and errinfo.
  """

  def StepWithErrflag(*args, **kwargs):
    # The step will construct errflag and errinfo while it is traced.
    # This is pure and can be compiled.
    ret = list(step(*args, **kwargs))
    ret += [errsrc.errflag]
    return ret

  compiled_step = jax.pmap(StepWithErrflag, **pmap_kwargs)

  def RunCheckCompileStep(*args, **kwargs):
    ret = compiled_step(*args, **kwargs)
    # This runs in Python land, so we can access Python objects.
    errinfo = errsrc.errinfo
    # pmap outputs have a leading device axis; a flag counts if any replica set it.
    errflag = np.asarray(ret[-1]).any(axis=0)
    if errflag.any():
      logging.info('==== ERRORS found: %s', np.sum(errflag))
      for k, flag in enumerate(errflag):
        if flag:
          # StackSummary.format() returns lines that already end in '\n'.
          logging.info('== ERR[%d] at %s', k, ''.join(errinfo[k].format()))
      raise ValueError('Exception flag raised')
    return tuple(ret[:-1])

  return RunCheckCompileStep

...
def train_and_evaluate_pmap(...):
  ...
  p_train_step = SafePmap(jax_task.model, train_step, donate_argnums=(0,), axis_name='batch')
```
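
For a single device, or for trying the pattern on CPU, the same wrapper can be sketched with `jax.jit` instead of `jax.pmap`. Here `safe_jit` and `step` are hypothetical names, and the step is assumed to return its flag vector as the last output rather than exposing it through a model object.

```python
import traceback

import jax
import jax.numpy as jnp
import numpy as np


def safe_jit(step, errinfo, **jit_kwargs):
  """Wraps a step returning (outputs..., errflag); raises on the host if any flag is set."""
  compiled_step = jax.jit(step, **jit_kwargs)

  def run_checked(*args, **kwargs):
    *outputs, errflag = compiled_step(*args, **kwargs)
    errflag = np.asarray(errflag)
    if errflag.any():
      for k in np.flatnonzero(errflag):
        print('== ERR[%d] at %s' % (k, ''.join(errinfo[k].format())))
      raise ValueError('Exception flag raised')
    return tuple(outputs)

  return run_checked


ERRINFO = []


def step(x):
  ok = jnp.all(x > 0)
  ERRINFO.append(traceback.extract_stack())  # Recorded at trace time.
  errflag = jnp.logical_not(ok)[jnp.newaxis]
  return x * 2, errflag


checked_step = safe_jit(step, ERRINFO)
(y,) = checked_step(jnp.array([1.0, 2.0]))
try:
  checked_step(jnp.array([1.0, -2.0]))
except ValueError as e:
  print(e)
```

Returning the flags as an ordinary output keeps the compiled step pure; only `run_checked`, which runs in Python, ever raises.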

There is also a pjit.pjit call in trainer_lib.py, but it is not on my code path.
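
For comparison, newer JAX releases include `jax.experimental.checkify`, which functionalizes runtime checks in a similar spirit: `checkify.check` records a failed predicate inside traced code, and `checkify.checkify` turns the function into one that also returns an error value for the host to inspect. This is JAX's mechanism, not the one proposed here, and `loss_fn` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def loss_fn(z):
  # Recorded as a functional error value instead of raising inside the trace.
  checkify.check(jnp.all(z > 0), 'z must be positive')
  return jnp.sum(jnp.log(z))


checked_loss = jax.jit(checkify.checkify(loss_fn))

err, loss = checked_loss(jnp.array([1.0, 2.0]))
err.throw()  # No-op: the check passed.

err, loss = checked_loss(jnp.array([1.0, -2.0]))
print(err.get())  # Message for the failed check.
```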
