e3nn-jax

Latest version: v0.20.6


4.0

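```python
import jax
import e3nn_jax as e3nn
x1 = e3nn.normal("1o", jax.random.PRNGKey(1))  # example inputs (assumed)
x2 = e3nn.normal("0e + 1o", jax.random.PRNGKey(2))
y1, y2 = e3nn.utils.equivariance_test(
    e3nn.tensor_product, jax.random.PRNGKey(0), x1, x2
)
# y1 = R x1 ⊗ R x2  (rotate the inputs, then apply the function)
# y2 = R (x1 ⊗ x2)  (apply the function, then rotate the output)
```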


Changelog
Changed
- **[BREAKING]** Renamed `e3nn.util` to `e3nn.utils`

Added
- `Irreps.set_mul(int)` to set the multiplicity of all irreps
- `Irreps.filter(lmax=int)` to filter out irreps with `l > lmax`
- `IrrepsArray.filter(lmax=int)` to filter out irreps with `l > lmax`
- `IrrepsArray.__radd__` and `IrrepsArray.__rsub__` to support `scalar + IrrepsArray` and `scalar - IrrepsArray`
- `0 + IrrepsArray` and `0 - IrrepsArray` are now always accepted as special cases.
- Support for `IrrepsArray / array`
- `utils` is now exposed as a submodule

Fixed
- `e3nn.scatter` operations now handle indices with `ndim > 1`

2.0

```python
# one activation per chunk of x: None leaves the first chunk unchanged
e3nn.norm_activation(x, [None, jnp.tanh])
```

1.7767712

Changed
- `e3nn.normalize_function` now uses a deterministic (not pseudorandom) algorithm to compute the normalization factor.

Added
- `normalize_act` option to `e3nn.scalar_activation` and `e3nn.gate`, so the normalization of the activation can now be turned off.
- `e3nn.norm_activation` as a new activation function.

1.0

Instead, there is a new attribute `.zero_flags`, a tuple of booleans indicating whether each corresponding chunk is zero.

```python
import jax.numpy as jnp
import e3nn_jax as e3nn
y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
y.chunks  # [jnp.array([[1.0]]), None]
y.zero_flags  # (False, True)
```


`.chunks` is the new attribute that replaces `.list` (now deprecated).
The new name is consistent with the existing `.slice_by_chunk`.

```python
# x: an IrrepsArray with two chunks, e.g. irreps "0e + 1o"
x.chunks  # list of the two chunks
x.slice_by_chunk[:1]  # the first chunk
```


`jax.vmap` can be used with a negative axis:

```python
# the irreps axis is -1, so in_axes=-2 maps over the last leading axis
jax.vmap(lambda x: x, in_axes=-2)(x)
```


The gradient also behaves as expected:

```python
g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array  # expected value
g.chunks  # expected value
```


To avoid problems that `.zero_flags` might cause in [`jax` transformations](https://github.com/google/jax#transformations), the flags are dropped (reset to `False`) whenever a transformation is applied.

```python
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
print(y.zero_flags)  # (False, True)

z = jax.jit(lambda x: x)(y)
print(z.zero_flags)  # (False, False)

z = jax.tree_util.tree_map(lambda x: x, z)
print(z.zero_flags)  # (False, False)

z = jax.vmap(lambda x: x)(z[None, ...])
print(z.zero_flags)  # (False, False)
```



Changelog

Changed
- **[BREAKING]** `e3nn.flax.Linear` and `e3nn.haiku.Linear` no longer output irreps that cannot be produced linearly from the input irreps. To force the output of all irreps, use `force_irreps_out=True`. For instance, `e3nn.flax.Linear("0e + 1o")("0e")` now returns `"0e"` instead of `"0e + 1o"`.
- **[BREAKING]** `e3nn.utils.assert_equivariant` has the same signature as `e3nn.utils.equivariance_test`
- **[BREAKING]** Move `as_irreps_array`, `zeros` and `zeros_like` from `e3nn.IrrepsArray` to `e3nn`
- **[BREAKING]** Move `IrrepsArray.from_list` to `e3nn.from_chunks`
- **[BREAKING]** Rename `IrrepsArray.list` to `IrrepsArray.chunks`
- **[BREAKING]** Rename `IrrepsArray.remove_nones` to `IrrepsArray.remove_zero_chunks`
- `e3nn.IrrepsArray` now has `.array` as its only data attribute.

Added
- `e3nn.IrrepsArray.rechunk`
- `e3nn.IrrepsArray.zero_flags` a tuple of bools that indicates which chunks are zero

0.20.6

Added
- `e3nn.where` function
- Optional `mask` argument in `e3nn.flax.BatchNorm`

Changed
- Replaced `jnp.ndarray` with `jax.Array`

0.20.5

Added
- `e3nn.ones` and `e3nn.ones_like` functions
- `e3nn.equinox` submodule

Fixed
- Python 3.9 compatibility


Thanks to ameya98, SauravMaheshkar and pabloferz
