Pytreeclass

Latest version: v0.11.0


0.9

Breaking changes:

- To simplify the API, the following functions will be removed:

1. `tree_repr_with_trace`
2. `tree_map_with_trace`
3. `tree_flatten_with_trace`
4. `tree_leaves_with_trace`

A variant of these will be included in the common recipes.
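
The removed `*_with_trace` helpers appear to pair each leaf with information about where it sits in the tree. The sketch below is not the recipe mentioned above; it assumes a recent JAX (0.4.6 or later), whose `jax.tree_util` module ships path-aware utilities that cover similar ground for plain pytrees:

```python
import jax

tree = {"a": 1, "b": (2, 3)}

# (path, leaf) pairs, in the spirit of the removed `tree_leaves_with_trace`;
# `keystr` renders a path of key entries as a string such as "['b'][0]".
traced_leaves = [
    (jax.tree_util.keystr(path), leaf)
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree)
]
print(traced_leaves)  # [("['a']", 1), ("['b'][0]", 2), ("['b'][1]", 3)]

# Path-aware map, in the spirit of the removed `tree_map_with_trace`:
# scale only the leaves stored under the "b" key.
scaled = jax.tree_util.tree_map_with_path(
    lambda path, x: x * 10 if path[0].key == "b" else x, tree
)
print(scaled)  # {'a': 1, 'b': (20, 30)}
```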

0.8

Additions:

- Add `on_getattr` to `field` to apply a function on `__getattr__`.

Breaking changes:

- Rename `callbacks` in `field` to `on_setattr` to match `attrs` and better reflect its functionality.

_These changes enable:_

1. Stricter data validation on instance values, as in the following example:

<details>

`on_setattr` ensures the value is of a certain type (e.g. integer) during _initialization_, and `on_getattr` ensures the value is of a certain type (e.g. integer) whenever it is accessed.

```python
import pytreeclass as tc
import jax

def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x

@tc.autoinit
class Tree(tc.TreeClass):
    a: int = tc.field(on_getattr=[assert_int], on_setattr=[assert_int])

    def __call__(self, x):
        # ensure `a` is an int before using it in computation by calling `assert_int`
        a: int = self.a
        return a + x

tree = Tree(a=1)
print(tree(1.0))  # 2.0
tree = jax.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
tree(1.0)  # AssertionError: must be an int
```


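To illustrate the mechanism behind these callbacks, here is a minimal pure-Python sketch built on a data descriptor. `CallbackField` is a hypothetical name and the sketch is not pytreeclass's implementation; it only assumes, as described above, that `on_setattr` functions run on assignment and `on_getattr` functions run on every read, each list applied in order:

```python
class CallbackField:
    """Hypothetical descriptor: runs `on_setattr` callbacks when a value is
    stored and `on_getattr` callbacks whenever it is read."""

    def __init__(self, on_getattr=(), on_setattr=()):
        self.on_getattr = list(on_getattr)
        self.on_setattr = list(on_setattr)

    def __set_name__(self, owner, name):
        self.name = "_" + name  # private slot holding the raw value

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        value = instance.__dict__[self.name]
        for func in self.on_getattr:  # applied in order on every access
            value = func(value)
        return value

    def __set__(self, instance, value):
        for func in self.on_setattr:  # applied in order on every assignment
            value = func(value)
        instance.__dict__[self.name] = value


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


class Tree:
    a = CallbackField(on_getattr=[assert_int], on_setattr=[assert_int])

    def __init__(self, a):
        self.a = a  # goes through CallbackField.__set__


tree = Tree(a=1)
print(tree.a + 1.0)  # 2.0

# Store a float directly, bypassing __set__, then read it back via __get__.
tree.__dict__["_a"] = 1.0
try:
    tree.a
    error_message = None
except AssertionError as err:
    error_message = str(err)
print(error_message)  # must be an int
```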
</details>

2. Frozen field without using `tree_mask`/`tree_unmask`

<details>

The following shows a pattern where the value is frozen on `__setattr__` and unfrozen whenever it is accessed; this ensures that `jax` transformations do not see the value. The following example showcases this functionality:

```python
import pytreeclass as tc
import jax

@tc.autoinit
class Tree(tc.TreeClass):
    frozen_a: int = tc.field(on_getattr=[tc.unfreeze], on_setattr=[tc.freeze])

    def __call__(self, x):
        return self.frozen_a + x

tree = Tree(frozen_a=1)  # 1 is a non-jax type,
# but the tree can still be used in jax transformations

@jax.jit
def f(tree, x):
    return tree(x)

f(tree, 1.0)  # 2.0

grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=1)
```


Compared with other libraries that implement `static_field`, this pattern has _lower_ overhead and does not alter the `tree_flatten`/`tree_unflatten` methods of the tree.

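A minimal pure-JAX sketch of why freezing hides a value from transformations, assuming "frozen" means "wrapped in a pytree node with no leaves". `Frozen`, `freeze` and `unfreeze` here are hypothetical stand-ins for `tc.freeze`/`tc.unfreeze`, not their actual implementation:

```python
import jax


@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper: a pytree node with *no* leaves. The wrapped value
    travels in the static aux data, so JAX transformations never trace it."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def freeze(x):  # hypothetical stand-in for an `on_setattr` callback
    return x if isinstance(x, Frozen) else Frozen(x)


def unfreeze(x):  # hypothetical stand-in for an `on_getattr` callback
    return x.value if isinstance(x, Frozen) else x


params = {"frozen_a": freeze(1), "b": 2.0}
print(jax.tree_util.tree_leaves(params))  # [2.0]: the frozen int is not a leaf


@jax.jit
def f(params, x):
    return unfreeze(params["frozen_a"]) + params["b"] * x


print(f(params, 1.0))  # 3.0
grads = jax.grad(f)(params, 1.0)  # no error, although `frozen_a` holds an int
print(grads["frozen_a"].value)  # 1
```

Because the int never becomes a leaf, `jax.grad` does not reject it as a non-float input, and it comes back unchanged in the gradient tree.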
</details>

3. Easier way to create a buffer (non-trainable array)

<details>

Just use `jax.lax.stop_gradient` in `on_getattr`

```python
import pytreeclass as tc
import jax
import jax.numpy as jnp

def assert_array(x):
    assert isinstance(x, jax.Array)
    return x

@tc.autoinit
class Tree(tc.TreeClass):
    buffer: jax.Array = tc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])

    def __call__(self, x):
        return self.buffer**x

tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)

@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))

f(tree, 1.0)  # Array(6., dtype=float32)
print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
```


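Stripped of the class machinery, the buffer pattern amounts to reading the array through `jax.lax.stop_gradient`. A minimal pure-JAX sketch, assuming a plain dict of parameters; the `scale` entry is a hypothetical trainable parameter added only for contrast:

```python
import jax
import jax.numpy as jnp


def f(params, x):
    # Read the buffer through stop_gradient: it still contributes to the
    # value, but autodiff treats it as a constant.
    buffer = jax.lax.stop_gradient(params["buffer"])
    return jnp.sum(params["scale"] * buffer**x)


params = {"buffer": jnp.array([1.0, 2.0, 3.0]), "scale": jnp.array(2.0)}
print(f(params, 1.0))  # 12.0
grads = jax.grad(f)(params, 1.0)
print(grads["buffer"])  # [0. 0. 0.]
print(grads["scale"])  # 6.0
```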
</details>

0.7

- Remove `.at` as an alias for `__getitem__` when specifying a path entry in `AtIndexer`. This leads to a less verbose style.

Example:

```python
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> indexer = tc.AtIndexer(tree)

>>> # Before:
>>> # style 1 (with at):
>>> indexer.at["level1_0"].at["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # style 2 (no at):
>>> indexer["level1_0"]["level2_0", "level2_1"].get()

>>> # After:
>>> # only style 2 is valid
>>> indexer["level1_0"]["level2_0", "level2_1"].get()
```


```diff
- tree = indexer.at["level1_0"].at["level2_0", "level2_1"].get()
+ tree = indexer["level1_0"]["level2_0", "level2_1"].get()
```

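To make the path-building idea concrete, here is a toy, pure-Python model of the `AtIndexer` behavior shown above. `PathIndexer` and `_none_like` are hypothetical names, the model only handles nested dicts, and it is not how pytreeclass implements indexing. Each `[...]` call appends one path entry, and `.get()` keeps the selected leaves while replacing everything else with `None`:

```python
def _none_like(node):
    # Replace every leaf with None while keeping the dict structure.
    if isinstance(node, dict):
        return {k: _none_like(v) for k, v in node.items()}
    return None


class PathIndexer:
    """Hypothetical, simplified indexer over nested dicts."""

    def __init__(self, tree, where=()):
        self.tree = tree
        self.where = where  # one set of keys per [] call

    def __getitem__(self, keys):
        keys = keys if isinstance(keys, tuple) else (keys,)
        return PathIndexer(self.tree, self.where + (set(keys),))

    def get(self):
        def select(node, depth):
            if depth == len(self.where):
                return node  # path fully matched: keep the subtree as-is
            if not isinstance(node, dict):
                return None  # path goes deeper than this leaf: not selected
            return {
                k: select(v, depth + 1) if k in self.where[depth] else _none_like(v)
                for k, v in node.items()
            }

        return select(self.tree, 0)


tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
selected = PathIndexer(tree)["level1_0"]["level2_0", "level2_1"].get()
print(selected)  # {'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
```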

For `TreeClass`, `at` is specified _once_ for each change:

```diff
@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree\
     .at["a"].set(100.0)\
-    .at["b"].at[0].set(10.0)\
+    .at["b"][0].set(10.0)\
     .at[mask].set(100.0)
```

0.6.0post0

- Using `tree_{repr,str}` with an object containing cyclic references will raise `RecursionError` instead of displaying `cyclicref`.
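
A repr that recurses into children without tracking which objects it has already visited cannot terminate on a cycle, so Python eventually raises `RecursionError`. The sketch below uses a hypothetical `naive_repr`, not pytreeclass's `tree_repr`, and contrasts it with the built-in `repr`, which detects the cycle and prints `[...]`:

```python
def naive_repr(node):
    # Recurse into lists without remembering which objects were already visited.
    if isinstance(node, list):
        return "[" + ", ".join(naive_repr(item) for item in node) + "]"
    return repr(node)


cyclic = [1, 2]
cyclic.append(cyclic)  # the list now contains itself

print(repr(cyclic))  # [1, 2, [...]]: built-in repr detects the cycle

try:
    naive_repr(cyclic)
    raised = False
except RecursionError:
    raised = True
print(raised)  # True
```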

0.6.0

- Allow nested mutations using `.at[method](*args, **kwargs)`.
After the change, inner methods can mutate **_copied_** new instances at any level, not just the top level.
A motivation for this is to experiment with a _lazy initialization scheme_, where inner layers need to mutate their own state. See the example below for `flax`-like lazy initialization as described [here](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0/edit#slide=id.g8d686e6bf0_1_57).

<details>

```python
import pytreeclass as tc
import jax
import jax.numpy as jnp
import jax.random as jr
from typing import Callable, TypeVar

T = TypeVar("T")

@tc.autoinit
class LazyLinear(tc.TreeClass):
    outdim: int
    weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
    bias_init: Callable[..., T] = jax.nn.initializers.zeros

    def param(self, name: str, init_func: Callable[..., T], *args) -> T:
        # create the parameter on first use; later calls reuse the stored value
        if name not in vars(self):
            setattr(self, name, init_func(*args))
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jax.Array = jr.PRNGKey(0)):
        w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
        y = x @ w
        if self.bias_init is not None:
            b = self.param("bias", self.bias_init, key, (self.outdim,))
            return y + b
        return y


@tc.autoinit
class StackedLinear(tc.TreeClass):
    l1: LazyLinear = LazyLinear(outdim=10)
    l2: LazyLinear = LazyLinear(outdim=1)

    def call(self, x: jax.Array):
        return self.l2(jax.nn.relu(self.l1(x)))

lazy_layer = StackedLinear()
print(repr(lazy_layer))
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   )
# )

# `.at["call"]` runs `call` on a copy and returns (output, mutated copy)
_, materialized_layer = lazy_layer.at["call"](jnp.ones((1, 5)))
materialized_layer
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]),
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]),
#     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   )
# )

materialized_layer.call(jnp.ones((1, 5)))
# Array([[0.16712935]], dtype=float32)
```


</details>
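
The key property described above is that the method runs on a copy, so even nested layers can set new attributes without touching the original tree. The following pure-Python model of that idea uses hypothetical names (`Lazy`, `Stack`, `call_on_copy`) and `copy.deepcopy`; it is not claimed to match pytreeclass's implementation:

```python
import copy


class Lazy:
    """Hypothetical layer that creates its `weight` state on first call."""

    def __init__(self, outdim):
        self.outdim = outdim

    def __call__(self, x):
        if "weight" not in vars(self):
            self.weight = [1.0] * self.outdim  # stand-in for a real initializer
        return [x * w for w in self.weight]


class Stack:
    """Hypothetical container whose method mutates a *nested* layer."""

    def __init__(self):
        self.l1 = Lazy(outdim=2)

    def call(self, x):
        return self.l1(x)


def call_on_copy(tree, method_name, *args, **kwargs):
    """Hypothetical helper: run a possibly mutating method on a deep copy
    and return (output, mutated_copy), leaving `tree` untouched."""
    new_tree = copy.deepcopy(tree)
    output = getattr(new_tree, method_name)(*args, **kwargs)
    return output, new_tree


stack = Stack()
out, materialized = call_on_copy(stack, "call", 3.0)
print(out)  # [3.0, 3.0]
print("weight" in vars(stack.l1))  # False: the original is unchanged
print("weight" in vars(materialized.l1))  # True: the copy holds the new state
```

Copying first keeps the original tree usable as an immutable value; the trade-off is one extra copy per call.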

0.5post0

- Fix `__init_subclass__` not accepting arguments. This bug was introduced in `v0.5`.
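
For context, `__init_subclass__` receives any keyword arguments written in a class statement, and a base class that neither accepts nor forwards them makes such subclasses fail with `TypeError`. The sketch below uses hypothetical classes (`Base`, `Tagged`, `Strict`) and is not the actual pytreeclass code or fix; it only shows the forwarding pattern:

```python
class Base:
    def __init_subclass__(cls, **kwargs):
        # Accept arbitrary class keywords and forward them along the MRO.
        super().__init_subclass__(**kwargs)


class Tagged:
    def __init_subclass__(cls, tag=None, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.tag = tag  # consume the `tag` keyword


class Child(Base, Tagged, tag="demo"):
    pass


print(Child.tag)  # demo


class Strict:
    def __init_subclass__(cls):  # takes no keywords, so class keywords are rejected
        super().__init_subclass__()


try:
    class Broken(Strict, Tagged, tag="demo"):
        pass
    rejected = False
except TypeError:
    rejected = True
print(rejected)  # True
```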
