Additions:
- Add `on_getattr` to `field` to apply functions to the value on `__getattr__` (i.e., whenever the field is accessed).
Breaking changes:
- Rename `callbacks` in `field` to `on_setattr` to match `attrs` and better reflect its functionality.
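One way to picture the two hooks is as a Python descriptor that pipes the value through the `on_setattr` functions when it is assigned and through the `on_getattr` functions when it is read. The sketch below is a minimal, hypothetical illustration of that behaviour in plain Python, not how pytreeclass implements `tc.field`. `HookedField` and `Point` are invented names, and the sketch assumes the functions in each list are applied left to right.

```python
from typing import Any, Callable, Optional, Sequence


class HookedField:
    """Hypothetical descriptor mimicking `on_setattr`/`on_getattr` hooks.

    Illustrates the semantics only; this is not pytreeclass's implementation.
    """

    def __init__(
        self,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
    ) -> None:
        self.on_setattr = tuple(on_setattr)
        self.on_getattr = tuple(on_getattr)

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    def __set__(self, instance: Any, value: Any) -> None:
        # on_setattr functions run left to right, each receiving the previous
        # output; only the final result is stored
        for func in self.on_setattr:
            value = func(value)
        instance.__dict__[self.name] = value

    def __get__(self, instance: Any, owner: Optional[type] = None) -> Any:
        if instance is None:
            return self
        # on_getattr functions run on every access; the stored value is untouched
        value = instance.__dict__[self.name]
        for func in self.on_getattr:
            value = func(value)
        return value


class Point:
    x = HookedField(on_setattr=[int, abs], on_getattr=[str])

    def __init__(self, x: Any) -> None:
        self.x = x  # goes through HookedField.__set__


p = Point(-3.7)
print(p.__dict__["x"])  # 3   (int(-3.7) == -3, then abs)
print(repr(p.x))        # '3' (str is applied on access only)
```

With the rename, code that previously passed `callbacks=[...]` to `tc.field` passes the same list as `on_setattr=[...]`.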
_These changes enable:_
1) Stricter data validation on instance values, as in the following example:
<details>

`on_setattr` ensures the value has a certain type (e.g. integer) during _initialization_, and `on_getattr` ensures it still has that type whenever it is accessed.

```python
import jax
import pytreeclass as tc


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


@tc.autoinit
class Tree(tc.TreeClass):
    a: int = tc.field(on_getattr=[assert_int], on_setattr=[assert_int])

    def __call__(self, x):
        # ensure `a` is an int before using it in computation by calling `assert_int`
        a: int = self.a
        return a + x


tree = Tree(a=1)
print(tree(1.0))  # 2.0
tree = jax.tree_util.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
tree(1.0)  # AssertionError: must be an int
```

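Why does the float get past `on_setattr` but not past `on_getattr`? A plain-Python sketch, with no pytreeclass or `jax`, makes the difference concrete. It assumes (this is not stated above) that the `tree_map` call rebuilds the tree without routing the new leaves through `__setattr__`. `rebuild_without_setattr` is a hypothetical stand-in for that step, and the `__setattr__`/`__getattribute__` overrides only play the roles of the two hooks.

```python
from typing import Any, Callable


def assert_int(x: Any) -> Any:
    assert isinstance(x, int), "must be an int"
    return x


class Tree:
    """Plain-Python stand-in for the `Tree` above; no pytreeclass or jax."""

    def __init__(self, a: int) -> None:
        self.a = a

    def __setattr__(self, name: str, value: Any) -> None:
        # plays the role of on_setattr=[assert_int]
        if name == "a":
            value = assert_int(value)
        super().__setattr__(name, value)

    def __getattribute__(self, name: str) -> Any:
        value = super().__getattribute__(name)
        # plays the role of on_getattr=[assert_int]
        if name == "a":
            value = assert_int(value)
        return value


def rebuild_without_setattr(tree: Tree, func: Callable[[Any], Any]) -> Tree:
    # Hypothetical tree_map-style rebuild: the new instance's attributes are
    # written straight into its __dict__, so __setattr__ never runs.
    new = object.__new__(type(tree))
    vars(new).update({k: func(v) for k, v in vars(tree).items()})
    return new


tree = Tree(a=1)
print(tree.a + 1.0)  # 2.0
tree = rebuild_without_setattr(tree, lambda x: x + 0.0)  # stored `a` is now 1.0
try:
    tree.a
except AssertionError as err:
    print(err)  # must be an int
```

In this model the setter check only fires on assignment, so a rebuild that bypasses assignment can only be caught by the check that runs on access.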
</details>
2) Frozen field without using `tree_mask`/`tree_unmask`
<details>

The following pattern freezes the value on `__setattr__` and unfreezes it whenever it is accessed, so `jax` transformations never see the value, as this example shows:

```python
import jax
import pytreeclass as tc


@tc.autoinit
class Tree(tc.TreeClass):
    frozen_a: int = tc.field(on_getattr=[tc.unfreeze], on_setattr=[tc.freeze])

    def __call__(self, x):
        return self.frozen_a + x


tree = Tree(frozen_a=1)  # 1 is a non-jax type (a Python int)


# the tree can still be used in jax transformations
@jax.jit
def f(tree, x):
    return tree(x)


f(tree, 1.0)  # 2.0
grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=1)
```

Compared with other libraries that implement `static_field`, this pattern has *lower* overhead and does not alter the `tree_flatten`/`tree_unflatten` methods of the tree.
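A toy model in plain Python may help show why a frozen value never reaches a transformation. `Frozen`, `freeze`, `unfreeze` and `toy_leaves` are hypothetical stand-ins, not `tc.freeze`/`tc.unfreeze` or JAX's own flattening. The sketch assumes that a frozen wrapper contributes no leaves when the tree is flattened, so a transformation never receives it as an input, while the accessor hands the unwrapped value to user code.

```python
from typing import Any, List


class Frozen:
    """Hypothetical wrapper marking a value as frozen (not pytreeclass's type)."""

    def __init__(self, value: Any) -> None:
        self.value = value


def freeze(x: Any) -> Frozen:
    # wrap once; freezing an already-frozen value is a no-op
    return x if isinstance(x, Frozen) else Frozen(x)


def unfreeze(x: Any) -> Any:
    # unwrap frozen values; pass anything else through unchanged
    return x.value if isinstance(x, Frozen) else x


def toy_leaves(tree: Any) -> List[Any]:
    """Toy stand-in for pytree flattening: recurse into dicts, lists and
    tuples, and treat a Frozen wrapper as a node with no leaves."""
    if isinstance(tree, Frozen):
        return []
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in toy_leaves(v)]
    if isinstance(tree, (list, tuple)):
        return [leaf for v in tree for leaf in toy_leaves(v)]
    return [tree]


# what an instance would store after on_setattr=[freeze] has run
stored = {"frozen_a": freeze(1), "b": 2.0}
print(toy_leaves(stored))                  # [2.0]: only `b` would be traced
print(unfreeze(stored["frozen_a"]) + 1.0)  # 2.0: what on_getattr=[unfreeze] returns
```

In this picture the container's own flattening logic is untouched; only the wrapper decides that its contents are not leaves.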
</details>
3) Easier way to create a buffer (non-trainable array)
<details>

Just use `jax.lax.stop_gradient` in `on_getattr`:

```python
import jax
import jax.numpy as jnp
import pytreeclass as tc


def assert_array(x):
    assert isinstance(x, jax.Array), "must be a jax.Array"
    return x


@tc.autoinit
class Tree(tc.TreeClass):
    # stop_gradient on access keeps `buffer` non-trainable;
    # assert_array on assignment validates its type
    buffer: jax.Array = tc.field(
        on_getattr=[jax.lax.stop_gradient],
        on_setattr=[assert_array],
    )

    def __call__(self, x):
        return self.buffer**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))


f(tree, 1.0)  # Array(6., dtype=float32)
print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
```

</details>