- Allow nested mutations using `.at[method](*args, **kwargs)`.
After the change, inner methods can mutate **_copied_** new instances at any level, not just at the top level; a minimal sketch of this behavior follows the example below.
A motivation for this is to experiment with a _lazy initialization scheme_, where inner layers need to mutate their inner state. See the example below for `flax`-like lazy initialization as described [here](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0/edit#slide=id.g8d686e6bf0_1_57).
<details>

```python
import pytreeclass as tc
import jax.random as jr
from typing import Any
import jax
import jax.numpy as jnp
from typing import Callable, TypeVar

T = TypeVar("T")


@tc.autoinit
class LazyLinear(tc.TreeClass):
    outdim: int
    weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
    bias_init: Callable[..., T] = jax.nn.initializers.zeros

    def param(self, name: str, init_func: Callable[..., T], *args) -> T:
        # materialize the parameter on first use and cache it on the instance
        if name not in vars(self):
            setattr(self, name, init_func(*args))
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
        w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
        y = x @ w
        if self.bias_init is not None:
            b = self.param("bias", self.bias_init, key, (self.outdim,))
            return y + b
        return y


@tc.autoinit
class StackedLinear(tc.TreeClass):
    l1: LazyLinear = LazyLinear(outdim=10)
    l2: LazyLinear = LazyLinear(outdim=1)

    def call(self, x: jax.Array):
        return self.l2(jax.nn.relu(self.l1(x)))


lazy_layer = StackedLinear()
print(repr(lazy_layer))
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   )
# )
_, materialized_layer = lazy_layer.at["call"](jnp.ones((1, 5)))

print(repr(materialized_layer))
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]),
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]),
#     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   )
# )

print(repr(materialized_layer.call(jnp.ones((1, 5)))))
# Array([[0.16712935]], dtype=float32)
```

</details>
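As a quick reference for the mutation semantics themselves (separate from the lazy-initialization example above), the sketch below uses hypothetical `Counter` and `Wrapper` classes, not part of the library, to illustrate the pattern: calling a method through `.at[...]` returns the method output together with a new, mutated copy, and with this change the mutation can happen on a *nested* instance while the original tree stays unchanged.

```python
import pytreeclass as tc


@tc.autoinit
class Counter(tc.TreeClass):
    count: int = 0

    def increment(self, by: int = 1) -> int:
        # attribute assignment is only permitted inside an `.at[...]` call
        self.count = self.count + by
        return self.count


@tc.autoinit
class Wrapper(tc.TreeClass):
    inner: Counter = Counter()

    def bump(self) -> int:
        # the nested instance mutates its own state; previously only
        # attributes of the top-level instance could be mutated this way
        return self.inner.increment()


wrapper = Wrapper()
out, new_wrapper = wrapper.at["bump"]()
print(out)                      # 1
print(new_wrapper.inner.count)  # 1  (mutation captured in the returned copy)
print(wrapper.inner.count)      # 0  (original instance is unchanged)
```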