Instead, we have a new attribute `.zero_flags`, a tuple of booleans indicating whether each corresponding chunk is zero.
```python
import jax.numpy as jnp
import e3nn_jax as e3nn

y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
y.chunks  # [jnp.array([[1.0]]), None]
y.zero_flags  # (False, True)
```
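To make the link between `None` chunks and `.zero_flags` concrete, here is a minimal NumPy sketch of how a list of chunks could be packed into one flat array. The helper `from_chunks_sketch` and its `(mul, dim)` description of the irreps are hypothetical and for illustration only; this is not the `e3nn.from_chunks` implementation.

```python
import numpy as np

def from_chunks_sketch(muls_dims, chunks, leading_shape=()):
    """Hypothetical helper: pack chunks into one array and record which were None."""
    parts, flags = [], []
    for (mul, dim), chunk in zip(muls_dims, chunks):
        if chunk is None:
            # A missing chunk is materialized as zeros and flagged as zero
            chunk = np.zeros(leading_shape + (mul, dim))
            flags.append(True)
        else:
            flags.append(False)
        # Flatten (mul, dim) into a single trailing axis of size mul * dim
        parts.append(np.reshape(chunk, leading_shape + (mul * dim,)))
    return np.concatenate(parts, axis=-1), tuple(flags)

# "0e + 1o" described as [(mul=1, dim=1), (mul=1, dim=3)]
array, flags = from_chunks_sketch([(1, 1), (1, 3)], [np.array([[1.0]]), None])
```

Here `array` holds `[1.0, 0.0, 0.0, 0.0]` and `flags` is `(False, True)`.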
`.chunks` is the new attribute that replaces `.list` (now deprecated).
The new name is consistent with the existing `.slice_by_chunk`.
```python
x.chunks  # list of the two chunks
x.slice_by_chunk[:1]  # get the first chunk
```
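The chunk view is a reshaping of the last axis into one `(mul, dim)` block per irrep. The NumPy sketch below, built on a hypothetical `to_chunks_sketch` helper, shows that split for a `"0e + 1o"` layout; note that slicing its plain list only approximates `.slice_by_chunk`, which returns an `IrrepsArray` rather than raw arrays.

```python
import numpy as np

def to_chunks_sketch(array, muls_dims):
    """Hypothetical helper: split the last axis into (mul, dim) blocks, one per irrep."""
    leading = array.shape[:-1]
    chunks, start = [], 0
    for mul, dim in muls_dims:
        stop = start + mul * dim
        chunks.append(array[..., start:stop].reshape(leading + (mul, dim)))
        start = stop
    return chunks

data = np.arange(8.0).reshape(2, 4)  # leading shape (2,), last axis laid out as "0e + 1o"
chunks = to_chunks_sketch(data, [(1, 1), (1, 3)])
first = chunks[:1]  # loosely analogous to x.slice_by_chunk[:1]
```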
`jax.vmap` can be used with a negative axis:
```python
jax.vmap(lambda x: x, in_axes=-2)(x)  # maps over the last leading axis
```
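Since an `IrrepsArray` stores its irreps along the last axis, `in_axes=-2` selects the last leading axis. The plain-JAX sketch below illustrates this on a raw array (an assumption: it does not use `IrrepsArray` at all); the irreps axis of size 4 is left untouched inside the mapped function.

```python
import jax
import jax.numpy as jnp

# Raw array with leading shape (3, 5); the last axis (size 4) plays the irreps role
a = jnp.ones((3, 5, 4))

# in_axes=-2 maps over the axis of size 5; each slice v has shape (3, 4)
out = jax.vmap(lambda v: v.sum(axis=0), in_axes=-2)(a)
```

With the default `out_axes=0`, `out` has shape `(5, 4)` and every entry equals `3.0`.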
And the gradient behaves as expected:
```python
g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array  # expected value
g.chunks  # expected value
```
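What "as expected" means can be seen on a plain array, assuming the usual layout in which the last axis of a `"0e + 1o"` array stores the scalar at index 0 and the vector at indices 1 to 3. Summing only the scalar part gives a gradient of 1 on every scalar entry and 0 on every vector entry. The function `scalar_sum` is a hypothetical stand-in for `e3nn.sum(x)["0e"]`, not the e3nn internals.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0).reshape(2, 4)  # leading shape (2,), assumed "0e + 1o" layout

def scalar_sum(a):
    # Sum the 0e component over all leading axes (index 0 of the last axis)
    return a[..., 0].sum()

g = jax.grad(scalar_sum)(x)  # ones on the scalar column, zeros on the vector columns
```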
To avoid any trouble that `.zero_flags` might cause in [`jax` transformations](https://github.com/google/jax#transformations), the flags are dropped whenever a transformation is applied: every chunk of the result is then flagged as non-zero.
```python
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
print(y.zero_flags)  # (False, True)
z = jax.jit(lambda x: x)(y)
print(z.zero_flags)  # (False, False)
z = jax.tree_util.tree_map(lambda x: x, z)
print(z.zero_flags)  # (False, False)
z = jax.vmap(lambda x: x)(z[None, ...])
print(z.zero_flags)  # (False, False)
```
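One reason the flags cannot simply survive a transformation: `jax.tree_util.tree_map(lambda a: a + 1.0, ...)` turns a zero chunk into a non-zero one, so flags carried along as static metadata would become wrong. The sketch below shows one way to get the "drop on transform" behaviour with a custom pytree; the class `FlaggedArray` is hypothetical, and we do not claim it mirrors how e3nn registers `IrrepsArray`. Only the array is a leaf, and `tree_unflatten` deliberately does not receive the flags.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class FlaggedArray:
    """Hypothetical container: one array plus per-chunk zero flags."""

    def __init__(self, array, num_chunks, zero_flags=None):
        self.array = array
        self.num_chunks = num_chunks
        # Without explicit flags, assume no chunk is known to be zero
        self.zero_flags = tuple(zero_flags) if zero_flags is not None else (False,) * num_chunks

    def tree_flatten(self):
        # The array is the only leaf; zero_flags is intentionally left out of the aux data
        return (self.array,), self.num_chunks

    @classmethod
    def tree_unflatten(cls, num_chunks, children):
        (array,) = children
        return cls(array, num_chunks)  # flags reset to all False

y = FlaggedArray(jnp.array([1.0, 0.0, 0.0, 0.0]), 2, (False, True))
z = jax.jit(lambda t: t)(y)                              # flags dropped
shifted = jax.tree_util.tree_map(lambda a: a + 1.0, y)   # "zero" chunk is now non-zero
batched = jax.vmap(lambda t: t)(FlaggedArray(jnp.zeros((2, 4)), 2, (False, True)))
```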
## Changelog
### Changed
- **[BREAKING]** `e3nn.flax.Linear` and `e3nn.haiku.Linear` no longer output irreps that cannot be reached from the input irreps. To force the output of all irreps, use `force_irreps_out=True`. For instance, `e3nn.flax.Linear("0e + 1o")("0e")` now returns `"0e"` instead of `"0e + 1o"`.
- **[BREAKING]** `e3nn.utils.assert_equivariant` now has the same signature as `e3nn.utils.equivariance_test`
- **[BREAKING]** Move `as_irreps_array`, `zeros` and `zeros_like` from `e3nn.IrrepsArray` to `e3nn`
- **[BREAKING]** Move `IrrepsArray.from_list` to `e3nn.from_chunks`
- **[BREAKING]** Rename `IrrepsArray.list` to `IrrepsArray.chunks`
- **[BREAKING]** Rename `IrrepsArray.remove_nones` to `IrrepsArray.remove_zero_chunks`
- `e3nn.IrrepsArray` now has only `.array` as a data attribute.
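The first entry of the Changed list follows from Schur's lemma: an equivariant linear map can only connect copies of the same irrep, so an output irrep absent from the input would always be zero. A pure-Python sketch of that filtering, using a hypothetical `reachable_irreps` helper with a deliberately simplified string parser (no validation), might look like:

```python
def reachable_irreps(irreps_out: str, irreps_in: str) -> str:
    """Hypothetical helper: keep only output irreps that also appear in the input."""

    def parse(s):
        # Simplified parser: terms like "0e", "1o" or "3x1o" separated by "+"
        terms = [t.strip() for t in s.split("+") if t.strip()]
        return [t.split("x")[-1] for t in terms], terms

    in_irs, _ = parse(irreps_in)
    out_irs, out_terms = parse(irreps_out)
    allowed = set(in_irs)
    # Drop output terms whose irrep cannot be produced by a linear map from the input
    return " + ".join(term for ir, term in zip(out_irs, out_terms) if ir in allowed)

print(reachable_irreps("0e + 1o", "0e"))  # 0e
```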
### Added
- `e3nn.IrrepsArray.rechunk`
- `e3nn.IrrepsArray.zero_flags` a tuple of bools that indicates which chunks are zero