I tried

```python
import jax.numpy as jnp

def AssertShape(x: jnp.array, shape) -> None:
    if not jnp.array_equal(x.shape, shape):
        raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')
```

and got

```
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
```

(BTW, note the double period.)
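A likely explanation, assuming `AssertShape` was called while the function was being traced (for example under `jax.jit`, which the "traced array" wording suggests): `x.shape` is a plain Python tuple even during tracing, but `jnp.array_equal` is a JAX operation, so inside a trace it returns a traced `bool[]` array. The `if not ...` then tries to turn that tracer into a Python `bool`, which JAX cannot do at trace time. One possible workaround is to compare the shapes as Python tuples so no JAX op is involved. The sketch below does this; `assert_shape` and `double` are hypothetical names, and the annotation uses `jax.Array`, since `jnp.array` is a function rather than a type.

```python
import jax
import jax.numpy as jnp


def assert_shape(x: jax.Array, shape) -> None:
    """Raise ValueError if x's shape differs from `shape`.

    x.shape is a tuple of Python ints even when x is a tracer, so this
    comparison produces an ordinary Python bool and is safe under jax.jit.
    """
    if tuple(x.shape) != tuple(shape):
        raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')


@jax.jit
def double(x):
    # The check runs once at trace time; it adds nothing to the compiled code.
    assert_shape(x, (2, 3))
    return 2 * x


y = double(jnp.ones((2, 3)))
```

Because shapes are static under `jit`, a mismatch is reported as a `ValueError` while tracing, before anything runs on the device.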