Flax

Latest version: v0.9.0

0.4.2

-----

New features:
- Add lifted conditional `nn.cond` (see the sketch below).
- Improved error messages: parameters not found, loading checkpoints.
- Replace `jax.tree_multimap` (deprecated) with `jax.tree_map`.
- Add the "Module Lifecycle" design note.
- Add support for JAX dynamic stack-based named_call.
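
A minimal sketch of the lifted conditional, assuming the `nn.cond(pred, true_fun, false_fun, module, *operands)` calling convention; the module, layer size, and branch logic below are illustrative only. Both branches must return outputs of the same shape and dtype, and submodules created under the same name share parameters.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class CondBlock(nn.Module):
  features: int = 4

  @nn.compact
  def __call__(self, x, pred: bool):
    # Both branches take (module, x) and must return outputs with the
    # same shape/dtype; using the same submodule name shares parameters.
    def scaled(mdl, x):
      return nn.Dense(self.features, name='dense')(x) * 2.0

    def plain(mdl, x):
      return nn.Dense(self.features, name='dense')(x)

    return nn.cond(pred, scaled, plain, self, x)

x = jnp.ones((1, 4))
variables = CondBlock().init(jax.random.PRNGKey(0), x, True)
y = CondBlock().apply(variables, x, False)
```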

Bug fixes:
- Handle the rate==1.0 edge case in Dropout.
- Fix bug where Linen Module state is reused.
- Bug fixes and generalizations of nn.partitioning API.

0.4.1

-----

New features:
- Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`.
- Improved seq2seq example: factored out model and input pipeline code.
- Added Optax update guide and deprecated `flax.optim`.
- Added `sep` argument to `flax.traverse_util.flatten_dict()`.
- Implemented `Sequential` module in `flax.linen.combinators` (see the sketch below).
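
A short sketch combining the new `Sequential` combinator with the `sep` argument to `flax.traverse_util.flatten_dict()`; the layer sizes and the '/' separator are illustrative choices.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.combinators import Sequential
from flax.traverse_util import flatten_dict

# Sequential chains callables (Modules or plain functions) in order.
model = Sequential([nn.Dense(8), nn.relu, nn.Dense(2)])
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))

# With sep, flatten_dict returns string keys joined by the separator
# (e.g. 'params/...') instead of tuple keys.
flat = flatten_dict(variables, sep='/')
print(list(flat.keys()))
```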

0.4.0

------
Breaking changes:
- `flax.deprecated.nn` is removed. Please pin to `flax==0.3.6` if you are still using it.
- PixelCNN++ example is removed. It was not working well on TPU.
- Linen Normalization layers no longer downcast double and complex floats to float32
when computing the mean and variance.

New features:
- Added `flax.linen.custom_vjp` for custom derivatives inside a `Module`.
- Add `param_dtype` attribute to standard Linen Modules for specifying parameter dtypes.
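
A brief sketch of the new `param_dtype` attribute; the half-precision choice below is only for illustration, and by default parameters remain float32.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# dtype controls the computation dtype; param_dtype controls the dtype
# in which the kernel and bias are created (float32 by default).
layer = nn.Dense(features=16, dtype=jnp.float16, param_dtype=jnp.float16)
variables = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8), jnp.float16))
print(jax.tree_util.tree_map(lambda p: p.dtype, variables))
```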

0.3.6

------
Breaking changes:
- Move `flax.nn` to `flax.deprecated.nn`.

New features:
- Add experimental checkpoint policy argument; see `flax.linen.checkpoint` and the sketch below.
- Add lifted versions of jvp and vjp.
- Add lifted transformation for mapping variables. See `flax.linen.map_variables`.
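
A hedged sketch of the experimental checkpoint policy; the choice of `jax.checkpoint_policies.checkpoint_dots` and the layer sizes below are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(128)(x))
    return nn.Dense(10)(x)

# Rematerialize the MLP during backprop; the policy controls which
# intermediates are kept (here: the outputs of dot products).
RematMLP = nn.checkpoint(MLP, policy=jax.checkpoint_policies.checkpoint_dots)

x = jnp.ones((1, 32))
variables = RematMLP().init(jax.random.PRNGKey(0), x)
y = RematMLP().apply(variables, x)
```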

0.3.5

------

Breaking changes:
- You can no longer pass an int as the `kernel_size` for a `flax.linen.Conv`.
Instead, a TypeError is raised stating that a tuple/list should be provided.
Stride and dilation arguments now support broadcasting a single int value,
because this is not ambiguous when the kernel rank is known (see the sketch below).
- `flax.linen.enable_named_call` and `flax.linen.disable_named_call` now work anywhere, instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now `flax.linen.override_named_call`, which provides a context manager to locally disable/enable named_call.
- NamedTuples are no longer converted to tuples on assignment to a `linen.Module`.
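
A quick sketch of the new calling convention; the feature count, kernel size, and input shape are illustrative.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# kernel_size must be a tuple/list; strides and dilation may still be
# a single int, which is broadcast once the kernel rank is known.
conv = nn.Conv(features=16, kernel_size=(3, 3), strides=1)
y, variables = conv.init_with_output(
    jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 3)))

# nn.Conv(features=16, kernel_size=3)  # now raises a TypeError
```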

New features:
- Flax internal stack frames are now removed from exception state traces.
- Added `flax.linen.nowrap` to decorate methods that should not be transformed
because they are stateful (see the sketch below).
- Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with `--jax_numpy_rank_promotion=raise`.
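
A small sketch of `flax.linen.nowrap`, loosely following the constructor-helper pattern; the module and helper names are illustrative.

```python
import flax.linen as nn

class Block(nn.Module):
  @nn.nowrap
  def _make_dense(self, features):
    # Constructor helper; nowrap keeps it from being wrapped by
    # named_call and lifted transforms, which stateful helpers
    # like this do not support.
    return nn.Dense(features)

  @nn.compact
  def __call__(self, x):
    return self._make_dense(8)(x)
```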

Bugfixes:
- linen Modules and dataclasses made with `flax.struct.dataclass` or `flax.struct.PyTreeNode` are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode.
- Fixed a bug in FrozenDict which didn't allow copying dicts with reserved names.
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)).
- Mixed precision training with float16 now works correctly with the attention layers.
- auto-generated linen Module `__hash__`, `__eq__`, `__repr__` no longer fail by default on non-init attributes.

0.3.4

------

Possibly breaking changes:
- When calling `init` the 'intermediates' collection is no longer mutable.
Therefore, intermediates will no longer be returned from initialization by default (see the sketch after this list).
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`.
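
A short sketch of how intermediates can still be requested at initialization, using the `mutable` argument on `init` mentioned under "Other changes" below; the module and variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.Dense(8)(x)
    self.sow('intermediates', 'hidden', h)
    return nn.Dense(1)(h)

x = jnp.ones((1, 4))
# By default init no longer returns 'intermediates'; request the
# collection explicitly if you need it at initialization time.
variables = Model().init(
    jax.random.PRNGKey(0), x, mutable=['params', 'intermediates'])
```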


Other changes:
- Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
- Added an NLP text classification example (on the SST-2 dataset) to
[`examples/sst2`](https://github.com/google/flax/tree/main/examples/sst2)
that uses a bidirectional LSTM (BiLSTM) to encode the input text.
- Added `flax.training.train_state` to simplify using Optax optimizers (see the sketch after this list).
- `mutable` argument is now available on `Module.init` and `Module.init_with_output`.
- Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
- Expose `dot_product_attention_weights`, allowing access to attention weights.
- `BatchNorm` instances will behave correctly during init when called multiple times.
- Added a more extensive "how to contribute" guide in `contributing.md`.
- Add proper cache behavior for [`lift.jit`](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.jit.html#flax.linen.jit),
fixing cache misses.
- Fix bug in Embed layer: make sure it behaves correctly when embedding is np.array.
- Fix `linen.Module` for deep inheritance chains.
- Fix bug in DenseGeneral: correctly expand bias to account for batch & noncontracting dimensions.
- Allow Flax lifted transforms to work on partially applied Modules.
- Make `MultiOptimizer` use `apply_gradient` instead of `apply_param_gradient`.
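
A hedged sketch of the new `flax.training.train_state.TrainState` helper with an Optax optimizer; the model, learning rate, and the zero placeholder gradients are illustrative.

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

model = nn.Dense(features=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))['params']

# TrainState bundles the apply function, the parameters, and the Optax
# optimizer state into a single pytree.
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-3))

# In a real train step the gradients would come from jax.grad.
grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
state = state.apply_gradients(grads=grads)
```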
