What's Changed
This refactor was largely inspired by me wanting to better keep track of which parameter is doing what! Because many JAX operations work on arbitrary pytrees, parameters are now tracked as dictionaries, and the fit-level logic has been adjusted to assume this. CI has been updated to include a different kind of HistFactory model structure that looks much more like what I'm building out elsewhere!
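As a rough illustration of why dictionaries work well here, the sketch below differentiates a toy objective with respect to a dict of named parameters. The parameter names (`mu`, `shapesys`) and the objective itself are made up for this example and are not taken from the library.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter names, chosen only for illustration.
pars = {"mu": jnp.array(1.0), "shapesys": jnp.array([1.0, 1.0])}


def loss(pars):
    # Toy scalar objective that looks parameters up by name.
    return (pars["mu"] - 2.0) ** 2 + jnp.sum((pars["shapesys"] - 0.5) ** 2)


# jax.grad treats the dict as a pytree, so the gradients come back
# as a dict with the same keys and array shapes as `pars`.
grads = jax.grad(loss)(pars)
```

Each gradient can then be read off by name (e.g. `grads["mu"]`) instead of by position in a flat array.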
The assumptions for this library to work are now just this:
```python3
import equinox as eqx  # turn our class into a PyTree
from jax import Array
from jax.typing import ArrayLike


class Model(eqx.Module):
    # any attributes here that are not valid jax types (e.g. str) need to be declared like:
    name: str = eqx.field(static=True)

    def logpdf(self, pars: dict[str, ArrayLike], data: Array) -> float | Array: ...
    def expected_data(self, pars: dict[str, ArrayLike]) -> Array: ...
```
In particular, note that `logpdf` returns a float (or scalar array). This means that pyhf models, if they are ever compatible again, would need to patch in `lambda pars, data: model.logpdf(pars, data)[0]`.
Python 3.8 support has been officially dropped, in step with the libraries this package depends on (e.g. equinox).
* remove intel macs from CI since jaxopt LBFGS-B does not correctly converge by phinate in https://github.com/gradhep/relaxed/pull/59
* Refactor that assumes parameters are in a key-value mapping by phinate in https://github.com/gradhep/relaxed/pull/61
**Full Changelog**: https://github.com/gradhep/relaxed/compare/v0.3.0...v0.4.0