It would be nice to have runtime asserts, but we can't use them inside a pure (traced and compiled) function.
On x86, when the processor encounters a floating-point error, it sets a sticky exception flag and keeps going. You can check that flag at the end of the computation.
I prototyped a similar mechanism in my code: an error flag is raised during the computation, and at the end of the computation we check whether it has been raised.
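A minimal sketch of the idea in plain JAX, assuming a hypothetical `step` function: each check is just an extra boolean output of the jitted function, so the function stays pure, and the host inspects the flags after the call returns.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def step(x):
  """Hypothetical step that returns its result plus a vector of error flags."""
  flags = []
  # Check 0: inputs must be non-negative before taking the square root.
  flags.append(jnp.any(x < 0))
  y = jnp.sqrt(x)
  # Check 1: the result must be finite.
  flags.append(jnp.logical_not(jnp.all(jnp.isfinite(y))))
  # The flags are ordinary outputs, so nothing here interrupts the computation.
  return y, jnp.stack(flags)


y, errflag = step(jnp.array([4.0, -1.0]))
errflag = np.asarray(errflag)
if errflag.any():
  print('errors at checks:', np.nonzero(errflag)[0].tolist())
```

Here `sqrt(-1.0)` yields NaN, so both checks fire and the host reports `[0, 1]`.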
Unlike asserts, this method:
- Can be enabled in prod on the same graph. You could place the checks on a device of your choosing.
- Does not know which error came first. If there is a cascade of errors, you will not know. I don't think you can tell without annotating the graph or providing a JAX op that timestamps execution.
- Cannot interrupt errors that cause infinite loops, since the flag is only checked after the step finishes.
The second issue is the annoying one. My assumption is that errors should generally happen in the order in which their checks are constructed.
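If that assumption holds, the first error is simply the lowest index whose flag is set. A small host-side helper (the name `first_error` is hypothetical) could report it; note that construction order is only a proxy, since the compiler is free to reorder independent operations.

```python
import numpy as np


def first_error(errflag):
  """Index of the first raised check, or None if no check failed.

  Assumes checks are appended to errflag in construction order, so a lower
  index means the check was built earlier in the graph.
  """
  errflag = np.asarray(errflag, dtype=bool).reshape(-1)
  if not errflag.any():
    return None
  # On a boolean array, argmax returns the position of the first True.
  return int(np.argmax(errflag))


print(first_error([False, True, True]))  # 1
print(first_error([False, False]))       # None
```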
I've implemented it in my `model.py`, but it should go in `layers`.
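`model.py`:

```python
import traceback
from typing import Any, List

import jax.numpy as jnp
# JTensor, AssertShape and base_model come from model.py's existing imports.


def Check(ok_cond: JTensor, errflag: JTensor, errinfo: List[Any]) -> JTensor:
  """Appends one error bit to errflag: True when ok_cond is False."""
  AssertShape(ok_cond, [])  # ok_cond must be a scalar.
  # Record where this check was built; errinfo[k] describes errflag[k].
  errinfo.append(traceback.extract_stack())
  err = jnp.logical_not(ok_cond)
  return jnp.concatenate([errflag, err[jnp.newaxis]])


class Model(base_model.BaseModel):

  def __init__(...):
    self.errinfo = []
    self.errflag = jnp.zeros(shape=[0], dtype=bool)

  def compute_loss(...):
    # Example invariant; Check() raises the flag when it is False.
    invariant = jnp.prod((z > 0).astype(int)) == 0
    self.errflag = Check(invariant, self.errflag, self.errinfo)
```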
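To see how `errflag` and `errinfo` line up, here is `Check` exercised eagerly outside the model. This sketch uses a plain `assert` in place of `AssertShape` and drops the lingvo type hints so it runs on its own; bit `k` of `errflag` belongs to the stack captured in `errinfo[k]`.

```python
import traceback
from typing import Any, List

import jax.numpy as jnp


def Check(ok_cond, errflag, errinfo: List[Any]):
  """Appends one error bit to errflag: True when ok_cond is False."""
  assert ok_cond.shape == (), ok_cond.shape
  errinfo.append(traceback.extract_stack())
  return jnp.concatenate([errflag, jnp.logical_not(ok_cond)[jnp.newaxis]])


errinfo = []
errflag = jnp.zeros(shape=[0], dtype=bool)
errflag = Check(jnp.array(True), errflag, errinfo)   # passes
errflag = Check(jnp.array(False), errflag, errinfo)  # fails
for k, flag in enumerate(errflag.tolist()):
  if flag:
    # errinfo[k][-1] is inside Check itself; [-2] is the line that called it.
    print(f'ERR[{k}] at line {errinfo[k][-2].lineno}')
```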
Then, in `lingvo/jax/train.py`:
```python
import jax
import numpy as np
from absl import logging


def SafePmap(errsrc, step, **pmap_kwargs):
  """Wraps `step` in jax.pmap and raises if any Check() flag was set."""
  # errsrc = jax_task.model, which holds errflag and errinfo.

  def StepWithErrflag(*args, **kwargs):
    # Tracing the step appends to errsrc.errflag and errsrc.errinfo.
    # This is pure and can be compiled.
    ret = list(step(*args, **kwargs))
    ret += [errsrc.errflag]
    return ret

  def RunCheckCompileStep(*args, **kwargs):
    ret = compiled_step(*args, **kwargs)
    # This runs in Python land, so we can access Python objects.
    errinfo = errsrc.errinfo
    # pmap adds a leading device axis; a check failed if it failed anywhere.
    errflag = np.any(np.asarray(ret[-1]), axis=0)
    if np.any(errflag):
      logging.info('==== ERRORS found: %s', np.sum(errflag))
      for k, flag in enumerate(errflag):
        if flag:
          # StackSummary.format() lines already end with a newline.
          logging.info('== ERR[%d] at %s', k, ''.join(errinfo[k].format()))
      raise ValueError('Exception flag raised')
    return ret[:-1]

  compiled_step = jax.pmap(StepWithErrflag, **pmap_kwargs)
  return RunCheckCompileStep

...

def train_and_evaluate_pmap(...):
  ...
  p_train_step = SafePmap(
      jax_task.model, train_step, donate_argnums=(0,), axis_name='batch')
```
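A self-contained toy version of the same wrapper, assuming a hypothetical `ToyModel` in place of `jax_task.model`. It runs on however many local devices exist (one on a CPU-only machine) and reduces the flags across the device axis with `np.any(..., axis=0)`, so a failure on any replica is reported. It also resets the model's flag state at the start of each trace, which is an addition of this sketch, not part of the code above.

```python
import traceback

import jax
import jax.numpy as jnp
import numpy as np


class ToyModel:
  """Hypothetical stand-in for jax_task.model: owns errflag and errinfo."""

  def __init__(self):
    self.reset()

  def reset(self):
    self.errinfo = []
    self.errflag = jnp.zeros(shape=[0], dtype=bool)

  def check(self, ok_cond):
    # Same bookkeeping as Check(): one stack entry and one error bit per call.
    self.errinfo.append(traceback.extract_stack())
    self.errflag = jnp.concatenate(
        [self.errflag, jnp.logical_not(ok_cond)[jnp.newaxis]])


def SafePmap(errsrc, step, **pmap_kwargs):

  def StepWithErrflag(*args, **kwargs):
    # Added in this sketch: start each trace from an empty flag list so that
    # a retrace never mixes in flags (tracers) left over from an earlier one.
    errsrc.reset()
    ret = list(step(*args, **kwargs))
    return ret + [errsrc.errflag]

  def RunCheckCompileStep(*args, **kwargs):
    ret = compiled_step(*args, **kwargs)
    # Flags have shape [num_devices, num_checks]; OR them across devices.
    errflag = np.any(np.asarray(ret[-1]), axis=0)
    if errflag.any():
      failed = np.nonzero(errflag)[0].tolist()
      raise ValueError(f'Exception flag raised by checks {failed}')
    return ret[:-1]

  compiled_step = jax.pmap(StepWithErrflag, **pmap_kwargs)
  return RunCheckCompileStep


model = ToyModel()


def train_step(z):
  model.check(jnp.all(z > 0))
  return (jnp.sum(z),)


p_train_step = SafePmap(model, train_step, axis_name='batch')
n = jax.local_device_count()
(loss,) = p_train_step(jnp.ones((n, 3)))  # all checks pass
try:
  p_train_step(-jnp.ones((n, 3)))  # same shape, so no retrace; check 0 fails
except ValueError as e:
  print(e)
```

Without the reset, `model.errflag` would still hold a tracer from the previous trace when pmap retraces (for example on a new input shape), and the indices in `errinfo` would drift.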
There is also a `pjit.pjit` path in `trainer_lib.py`, but it isn't called in my setup.