Pytreeclass

Latest version: v0.9.2


0.9.2

Changes:

- Rename `threads_count` to `max_workers` in the parallel kwargs of `apply`.

0.9.1

Additions:

- Add a parallel mapping option to `AtIndexer`. This enables a myriad of tasks, such as reading a pytree of image file names.

```python
# benchmarking parallel vs. serial image read
# on mac m1 cpu with image of size 512x512x3
import pytreeclass as tc
from matplotlib.pyplot import imread

paths = ["lenna.png"] * 10
indexer = tc.AtIndexer(paths)
%timeit indexer[...].apply(imread, parallel=True)  # parallel: 24.9
%timeit indexer[...].apply(imread)  # not parallel: 84.8
```
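The idea behind the parallel option can be sketched with the standard library alone. This is a minimal illustration, not `pytreeclass`'s implementation: `parallel_leaf_map` and `fake_read` are hypothetical names, and a thread pool is assumed as the backend. The `max_workers` argument used here is the one accepted by `concurrent.futures.ThreadPoolExecutor`.

```python
from concurrent.futures import ThreadPoolExecutor


def parallel_leaf_map(func, leaves, max_workers=None):
    """Apply `func` to every leaf on a thread pool, preserving input order.

    Hypothetical helper for illustration only; not part of pytreeclass.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map returns results in the same order as the inputs
        return list(pool.map(func, leaves))


def fake_read(path):
    """Stand-in for an I/O-bound reader such as `imread` (assumption)."""
    return len(path)


paths = ["a.png", "b.png", "c.png"]
sizes = parallel_leaf_map(fake_read, paths, max_workers=2)
```

Thread pools tend to pay off when the mapped function is I/O-bound, as with reading image files from disk.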

0.9

Breaking changes:
- To simplify the API, the following will be removed:
1) `tree_repr_with_trace`
2) `tree_map_with_trace`
3) `tree_flatten_with_trace`
4) `tree_leaves_with_trace`

0.8

Additions:
- Add `on_getattr` in `field` to apply functions to the value on `__getattr__`.

Breaking changes:
- Rename `callbacks` in `field` to `on_setattr` to match `attrs` and better reflect its functionality.



_These changes enable:_

1) Stricter data validation on instance values, as in the following example:

<details>

`on_setattr` ensures the value is of a certain type (e.g. an integer) during _initialization_, and `on_getattr` ensures the value is of a certain type (e.g. an integer) whenever it is accessed.


```python
import pytreeclass as tc
import jax


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


@tc.autoinit
class Tree(tc.TreeClass):
    a: int = tc.field(on_getattr=[assert_int], on_setattr=[assert_int])

    def __call__(self, x):
        # ensure `a` is an int before using it in computation by calling `assert_int`
        a: int = self.a
        return a + x


tree = Tree(a=1)
print(tree(1.0))  # 2.0
tree = jax.tree_util.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
tree(1.0)  # AssertionError: must be an int
```
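One way to picture how such callbacks could be chained is a plain Python descriptor. This is a rough sketch under that assumption, not how `pytreeclass` implements `field`; `CallbackField` and `Point` are hypothetical names introduced only for illustration.

```python
class CallbackField:
    """Hypothetical descriptor that runs callbacks on set and on get."""

    def __init__(self, on_setattr=(), on_getattr=()):
        self.on_setattr = tuple(on_setattr)
        self.on_getattr = tuple(on_getattr)

    def __set_name__(self, owner, name):
        # store the raw value under a private attribute name
        self.slot = "_" + name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        value = getattr(obj, self.slot)
        for func in self.on_getattr:  # applied in order on every access
            value = func(value)
        return value

    def __set__(self, obj, value):
        for func in self.on_setattr:  # applied in order on every assignment
            value = func(value)
        setattr(obj, self.slot, value)


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


class Point:
    a = CallbackField(on_setattr=[assert_int], on_getattr=[assert_int])

    def __init__(self, a):
        self.a = a


p = Point(1)
result = p.a + 1

# Bypass the setter, loosely mimicking a tree map that changes the stored leaf
p._a = 1.5
try:
    p.a
    error = None
except AssertionError as exc:
    error = str(exc)
```

Validating on access as well as on assignment catches values that were changed without going through the setter.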


</details>

2) Frozen field without using `tree_mask`/`tree_unmask`

<details>

The following pattern freezes the value on `__setattr__` and unfreezes it whenever it is accessed, which ensures that `jax` transformations do not see the value. The example below showcases this:

```python
import pytreeclass as tc
import jax


@tc.autoinit
class Tree(tc.TreeClass):
    frozen_a: int = tc.field(on_getattr=[tc.unfreeze], on_setattr=[tc.freeze])

    def __call__(self, x):
        return self.frozen_a + x


tree = Tree(frozen_a=1)  # 1 is non-jaxtype
# can be used in jax transformations


@jax.jit
def f(tree, x):
    return tree(x)


f(tree, 1.0)  # 2.0

grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=1)
```


Compared with other libraries that implement `static_field`, this pattern has *lower* overhead and does not alter the `tree_flatten`/`tree_unflatten` methods of the tree.
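To see why a frozen value stays out of a transformation's reach, consider a toy tree map that treats a wrapper as opaque. This is only an illustration: `Frozen`, `toy_freeze`, `toy_unfreeze`, and `toy_tree_map` are hypothetical stand-ins written for this sketch, not the `pytreeclass` or `jax` functions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Frozen:
    """Hypothetical opaque wrapper; the toy map below never looks inside it."""

    value: object


def toy_freeze(x):
    return Frozen(x)


def toy_unfreeze(x):
    return x.value if isinstance(x, Frozen) else x


def toy_tree_map(func, tree):
    """Toy stand-in for a pytree map over dicts, lists, and tuples."""
    if isinstance(tree, Frozen):
        return tree  # frozen values are skipped, so `func` never sees them
    if isinstance(tree, dict):
        return {k: toy_tree_map(func, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_tree_map(func, v) for v in tree)
    return func(tree)


params = {"weight": 2.0, "stride": toy_freeze(3)}
doubled = toy_tree_map(lambda x: x * 2, params)
```

Here `doubled["weight"]` is transformed while `toy_unfreeze(doubled["stride"])` still yields the original value, which mirrors the freeze-on-set, unfreeze-on-get pattern described above.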


</details>

3) Easier way to create a buffer (non-trainable array)

<details>

Just use `jax.lax.stop_gradient` in `on_getattr`:

```python
import pytreeclass as tc
import jax
import jax.numpy as jnp


def assert_array(x):
    assert isinstance(x, jax.Array)
    return x


@tc.autoinit
class Tree(tc.TreeClass):
    buffer: jax.Array = tc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])

    def __call__(self, x):
        return self.buffer**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))


f(tree, 1.0)  # Array(6., dtype=float32)
print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
```


</details>
