Flax

Latest version: v0.10.5

Page 2 of 7

0.7.1
-----
Breaking changes:
- Flax is migrating from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](https://github.com/google/flax/discussions/3191).

New features:
- Use pyink
- Added dict migration guide to index
- Added scan over layers section
- Expose options to customize rich.Table
- Added support for initializing carry variables in scan
- Let Flax-Orbax not port the shape of `target` arrays when it ports the `target` shardings.

Bug fixes:
- Use `import orbax.checkpoint`, which is a better import pattern.
- Use `import orbax.checkpoint as ocp` to avoid the verbosity of writing `orbax.checkpoint` every time.
- [linen] Add alternative, more numerically stable, variance calculation to `LayerNorm`.
- [linen] Minor cleanup to normalization code.
- Fix norm calculation bug for 0-rank arrays.
- [JAX] Remove references to jax.config.jax_array.
- [linen] Use `stack` instead of `concatenate` in `compute_stats`, to handle scalar stats case.
- [linen] More minor cleanup in normalization `compute_stats`.
- Fix warnings from atari gym.
- Refactor TypeHandler to operate over batches of values, rather than individual ones. This allows more flexibility for implementations that may operate more efficiently on batches.
- Fix carry slice logic
- make flax_basics guide use utility fns
- Fix checkpointing guide error at head
- Improve scan docs
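
The `LayerNorm` entry above mentions a more numerically stable variance calculation. The sketch below contrasts the two textbook formulas in plain `jax.numpy`; it is an illustration of the general idea under the assumption that the alternative is the two-pass form, not a copy of Flax's internal code, and the function names are hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def fast_variance(x):
  # One pass: E[x^2] - E[x]^2, clipped at zero because cancellation
  # can push the difference slightly negative.
  mean = jnp.mean(x)
  mean_of_squares = jnp.mean(jnp.square(x))
  return jnp.maximum(0.0, mean_of_squares - jnp.square(mean))


def stable_variance(x):
  # Two passes: subtract the mean first, then average the squared deviations.
  mean = jnp.mean(x)
  return jnp.mean(jnp.square(x - mean))


# float32 data with a large mean and a small spread stresses the fast formula.
x = jnp.asarray(10_000.0 + 0.1 * np.arange(8), dtype=jnp.float32)
reference = np.var(np.asarray(x, dtype=np.float64))

fast = float(fast_variance(x))
stable = float(stable_variance(x))
```

The one-pass form subtracts two numbers near 1e8 in float32, so most significant digits cancel; the two-pass form works with small deviations and stays close to the float64 reference.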

0.7.0
-----
- RNNCellBase refactor.

0.6.11
-----
- Set Orbax-as-backend to be the default checkpointing method.
- Fix setup trigger issue under sharing and transforms.
- Add collection to self.scope.reserve(name, col) so that sow works with the same name in different collections.
- Minor improvements for Sequential.
- Improve the error message in MultiHeadDotProductAttention.
- Allow manually specifying the rng key for Dropout.
- RNN refactor.
- Fixed missing separator for rng fold-in.

0.6.10
-----
- Rudimentary quantization support: some layers can be parametrized with custom dot_general and conv_general_dilated.

0.6.9
-----
- Depend on `orbax-checkpoint` package instead of `orbax`.
- Refactored setup scripts to `pyproject.toml`.
- Added pretty_repr utility fn.
- Fix get_partition_spec on replicated array.
- Updated imagenet.ipynb to use the GPU Colab runtime, and fixed its config.
- Upgraded checkpointing code to `jax.sharding` and added more warnings.

0.6.8
-----
- The automatic checkpoint migration was temporarily rolled back due to legacy compatibility issues.
  - We still recommend following the [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and migrating completely to the Orbax API to ensure stability.
  - Alternatively, add `flax.config.update('flax_use_orbax_checkpointing', True)` to your project to avoid being impacted by the automatic migration process.
- Added utility functions to frozen_dict api.
- Migrated Flax away from `register_keypaths`.
- Fixes kwargs in convert_to_graphs_tuple_fn.
- Fixed examples in a few ways:
  - Bumped the TF version
  - Used latest checkpoint formats
  - Other misc fixes.
