What's Changed
* Changing the version in the citation text in the README. by copybara-service in https://github.com/deepmind/kfac-jax/pull/29
* Adding attributes for the number of training and evaluation devices. by copybara-service in https://github.com/deepmind/kfac-jax/pull/31
* Adding some methods to `ImplicitExactCurvature`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/32
* Adding a `put_stop_grad_on_loss_factor` argument to `multiply_fisher_factor`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/36
* Making `ScaleAndShift` blocks capable of having parameters that are broadcast by construction, e.g. batch norm with scale parameters of shape `[1, 1, 1, d]`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/33
* Changing `jax.tree_map` -> `jax.tree_util.tree_map` (and related functions) due to their recent deprecation. by copybara-service in https://github.com/deepmind/kfac-jax/pull/37
* Removed the unused `precedence` argument from `GraphPattern`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/38
* Fixing a small bug where the jaxpr constvars were not checked. by copybara-service in https://github.com/deepmind/kfac-jax/pull/39
* Adding an `estimator` attribute to the optimizer. by copybara-service in https://github.com/deepmind/kfac-jax/pull/34
* Updating the docs to correctly refer to `update_cache`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/40
* Compare with slightly less numerical precision. by copybara-service in https://github.com/deepmind/kfac-jax/pull/41
* Revamping the graph matching code so that it can detect layers and register tags inside arbitrary higher-order JAX primitives. by copybara-service in https://github.com/deepmind/kfac-jax/pull/42
* Revising the docstring for the optimizer class; it now includes the previously missing details about `value_and_grad_func`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/43
* Internal change. by copybara-service in https://github.com/deepmind/kfac-jax/pull/44
* Making `LossTag` return only the parameter-dependent arrays. by copybara-service in https://github.com/deepmind/kfac-jax/pull/46
* Improving `LossTag`s so that they correctly handle `None` arguments, by passing in argument names. by copybara-service in https://github.com/deepmind/kfac-jax/pull/47
* Minor fix for a bug introduced in the previous commit. by copybara-service in https://github.com/deepmind/kfac-jax/pull/48
* Correcting issues with the optimizer docstring. by copybara-service in https://github.com/deepmind/kfac-jax/pull/45
* Fixing a bug in the graph matcher introduced in a recent change. by copybara-service in https://github.com/deepmind/kfac-jax/pull/49
* Removing the unneeded `jax.jit` in `get_mean` and `get_sum`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/50
* Adding per-parameter norm stats to the optimizer. by copybara-service in https://github.com/deepmind/kfac-jax/pull/51
* Allowing the pi-adjusted psd inverse to accept diagonal factors. by copybara-service in https://github.com/deepmind/kfac-jax/pull/55
* Fixing the wrong type annotation of `pmap_axis_name`. by copybara-service in https://github.com/deepmind/kfac-jax/pull/56
* Adding optional offloading of `eigh` computation to the host because of a bug in CUDA 11.7.0 cuSOLVER library. by copybara-service in https://github.com/deepmind/kfac-jax/pull/57
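For downstream code affected by the `jax.tree_map` -> `jax.tree_util.tree_map` change listed above, the migration amounts to calling the same mapping function from the `jax.tree_util` namespace. The following is a minimal sketch, assuming a hypothetical parameter dictionary `params`; it is not taken from kfac-jax itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree of parameters, used only for illustration.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Previously spelled jax.tree_map(...); now called via jax.tree_util.
scaled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)

# Related pytree helpers live in the same namespace.
leaves = jax.tree_util.tree_leaves(scaled)
total = sum(float(jnp.sum(leaf)) for leaf in leaves)
print(total)  # 12.0
```

Only the call site changes; the function's arguments and the structure of the returned pytree stay the same.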
**Full Changelog**: https://github.com/deepmind/kfac-jax/compare/v0.0.2...v0.0.3