What's Changed
* Expose `normalize_probabilities` as a suitable normalization for `SoftmaxLoss`.
* Remove use of the `initial` argument to `jax.nn.softmax` and `jax.nn.log_softmax`.
* Drop Python 3.8 checks and add Python 3.11 checks.
* Refactor lambda weights to reduce boilerplate and add new options.
* Fix pytype errors and clean up type annotations across the codebase.
* Fix minor typos in the documentation.
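
As a rough, conceptual illustration of what a probability normalization does for a softmax loss, here is a plain-Python sketch. It assumes labels are non-negative relevance values and that the normalization rescales the valid (unmasked) labels so they sum to 1. The names `normalize_labels` and `softmax_cross_entropy` are hypothetical; this is not Rax's API, nor its implementation of `normalize_probabilities` or `SoftmaxLoss`.

```python
import math
from typing import Optional, Sequence


def normalize_labels(labels: Sequence[float],
                     where: Optional[Sequence[bool]] = None) -> list[float]:
    """Hypothetical helper: rescale valid labels so they sum to 1.

    Items masked out by `where` get probability 0. If every valid label is 0,
    this sketch returns all zeros (an assumption, not documented Rax behavior).
    """
    if where is None:
        where = [True] * len(labels)
    valid = [label if keep else 0.0 for label, keep in zip(labels, where)]
    total = sum(valid)
    if total == 0.0:
        return [0.0] * len(labels)
    return [v / total for v in valid]


def softmax_cross_entropy(scores: Sequence[float],
                          labels: Sequence[float],
                          where: Optional[Sequence[bool]] = None) -> float:
    """Hypothetical helper: cross-entropy of normalized labels vs. softmax(scores).

    Assumes at least one item is valid (unmasked).
    """
    if where is None:
        where = [True] * len(scores)
    # Log-sum-exp over valid items only, shifted by the max for numerical stability.
    valid_scores = [s for s, keep in zip(scores, where) if keep]
    m = max(valid_scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in valid_scores))
    probs = normalize_labels(labels, where)
    # Sum of -p * log softmax(score) over valid items.
    return -sum(p * (s - log_z) for p, s, keep in zip(probs, scores, where) if keep)
```

For example, `normalize_labels([2.0, 1.0, 1.0], [True, True, False])` yields roughly `[0.667, 0.333, 0.0]`. In this sketch, masked items are dropped from the log-sum-exp rather than filled with a sentinel value, so a masked item's score has no effect on the loss.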

**Full Changelog**: https://github.com/google/rax/compare/v0.3.0...v0.4.0