
Latest version: v0.9.0





- fixed rng guide outputs by chiamp in https://github.com/google/flax/pull/3685
- enforce mask kwarg in norm layers by chiamp in https://github.com/google/flax/pull/3663
- added kwargs to self.param and self.variable by chiamp in https://github.com/google/flax/pull/3675
- added nnx normalization tests by chiamp in https://github.com/google/flax/pull/3689
- added NNX init_cache docstring example by chiamp in https://github.com/google/flax/pull/3688
- added nnx attention equivalence test by chiamp in https://github.com/google/flax/pull/3687
- Fix bug that assumed frozen-dict keys were strings. by copybara-service in https://github.com/google/flax/pull/3692
- added nnx rmsnorm by chiamp in https://github.com/google/flax/pull/3691
- updated nnx compute_stats by chiamp in https://github.com/google/flax/pull/3693
- fixed intercept_methods docstring by chiamp in https://github.com/google/flax/pull/3694
- [nnx] Add Sphinx Docs by cgarciae in https://github.com/google/flax/pull/3678
- Fix pointless docstring example of nn.checkpoint / nn.remat. by levskaya in https://github.com/google/flax/pull/3703
- added default params rng to .apply by chiamp in https://github.com/google/flax/pull/3698
- [nnx] add partial_init by cgarciae in https://github.com/google/flax/pull/3674
- make make_rng default to 'params' by chiamp in https://github.com/google/flax/pull/3699
- Add SimpleCell. by carlosgmartin in https://github.com/google/flax/pull/3697
- fix Module.module_paths docstring by cgarciae in https://github.com/google/flax/pull/3709
- Guarantee the latest JAX version on CI by cgarciae in https://github.com/google/flax/pull/3705
- Replace deprecated API `jax.tree.map` by copybara-service in https://github.com/google/flax/pull/3715
- Use `jax.tree_util.tree_map` instead of deprecated `jax.tree.map`. by copybara-service in https://github.com/google/flax/pull/3714
- [nnx] simplify readme by cgarciae in https://github.com/google/flax/pull/3707
- [nnx] add demo.ipynb by cgarciae in https://github.com/google/flax/pull/3680
- Fix Tabulate's compute_flops by cgarciae in https://github.com/google/flax/pull/3721
- [nnx] simplify TraceState by cgarciae in https://github.com/google/flax/pull/3724
- Add broadcast of `strides` and `kernel_dilation` to `nn.ConvTranspose` by IvyZX in https://github.com/google/flax/pull/3731
- [nnx] Fix State.__sub__ by cgarciae in https://github.com/google/flax/pull/3704
- [nnx] always fold_in on fork + new ForkedKeys return type by cgarciae in https://github.com/google/flax/pull/3722
- [nnx] explicit Variables by cgarciae in https://github.com/google/flax/pull/3720
- Improves fingerprint definition for Modules in nn.jit. by copybara-service in https://github.com/google/flax/pull/3736
- Flax: avoid key reuse in tests by copybara-service in https://github.com/google/flax/pull/3740
- added Einsum layer by chiamp in https://github.com/google/flax/pull/3710
- nn.jit: automatic fingerprint definition for dataclass attributes by cgarciae in https://github.com/google/flax/pull/3737
- [NVIDIA] Use custom grad accumulation for FP8 params by kaixih in https://github.com/google/flax/pull/3623
- removed nnx dataclass by chiamp in https://github.com/google/flax/pull/3742
- [nnx] cleanup graph_utils by cgarciae in https://github.com/google/flax/pull/3728
- Fix doctest and unbreak head by IvyZX in https://github.com/google/flax/pull/3753
- [nnx] add pytree support by cgarciae in https://github.com/google/flax/pull/3732
- fixed intercept_methods docstring by chiamp in https://github.com/google/flax/pull/3752
- Add ConvLSTMCell to docs. by carlosgmartin in https://github.com/google/flax/pull/3712
- [nnx] remove flagslib by cgarciae in https://github.com/google/flax/pull/3733
- Fix tests after applying JAX key-reuse checker. See: by copybara-service in https://github.com/google/flax/pull/3748


- Added default collection in `make_rng`.
- Added `InstanceNorm` and renamed `channel_axes` to `feature_axes`.
- Added norm equivalence tests.
- Added `Module.module_paths` and doc.
- Made `Sequential.__call__` compact.
- Added `nn.compact_name_scope` v3.
- Add explicit control over frozen/slots setting in `flax.struct.dataclass`.
- Replaced `jax.tree_util.tree_map` with mapping over leaves.
- Fixed docs and docstrings.


- Added [NNX](https://github.com/google/flax/tree/main/flax/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch.
- Added `nn.compact_name_scope` decorator that enables methods to act as compact name scopes as with regular Haiku methods. This makes porting Haiku code easier.
- Added `copy()` method to `Module`. This is a user-friendly version of the internal `clone()` method with better defaults for common use cases.
- Added [`BatchApply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#batchapply) class.
- Added `sow_weights` option in attention layer.
- Added [`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.MultiHeadAttention.html) alias.
- Added kwargs support for `nn.jit`.
- Deprecated the `normalize` activation function in favor of `standardize`.
- Added `GeGLU` activation function.
- Added `Enum` support for `tabulate` function.
- Added simple argument-only lifted `nn.grad` function.


- Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns.
- Refactored the call signature of `MultiHeadDotProductAttention` by adding `inputs_k` and `inputs_v` arguments and making `inputs_kv`, `mask` and `deterministic` keyword arguments. See [3389](https://github.com/google/flax/discussions/3389) for more details.
- Flax now uses the new typed PRNG keys throughout: uses of `jax.random.PRNGKey` were changed to `jax.random.key` (see [JEP 9263](https://github.com/google/jax/pull/17297) for details). If you notice dispatch performance regressions after this change, make sure `jax` is at version 0.4.16 or newer.
- Added a `has_improved` field to `EarlyStopping` and changed `EarlyStopping.update` to return only the updated instance instead of a tuple. See [3385](https://github.com/google/flax/pull/3385) for more details.


New features:
- Add QK-normalization to MultiHeadDotProductAttention
- Allow apply's method argument to accept submodules
- Add module path to `nn.Module`.
- [JAX] Generate new type of PRNG keys

Bug fixes:
- Directly call original method if method interceptor stack is empty.
- Fix stack overflow when loading a pickled module.
- Improve kw_only_dataclass.
- Allow pass-through implementation of state dict
- Promote dot_general injections from a function to a module.


New features:
- Make the `add_or_replace` argument of `flax.core.copy` optional.
- Add `use_fast_variance` option to `GroupNorm` and `BatchNorm` to allow disabling it.

Bug fixes:
- Use `field_specifiers` instead of `field_descriptors` in `dataclass_transform`.
- Fix `nn.Module` typing.
- [JAX] Replace uses of `jax.experimental.pjit.with_sharding_constraint` with `jax.lax.with_sharding_constraint`.


© 2024 Safety CLI Cybersecurity Inc. All Rights Reserved.