Changes
1. User-provided `re.Pattern` objects (created with `re.compile`) are now used to match keys by regex, replacing `RegexKey`.
<details>
Example:

```python
>>> import pytreeclass as tc
>>> import re
>>> tree = {"l1": 1, "l2": 2, "b": 3}
>>> tree = tc.AtIndexer(tree)
>>> tree.at[re.compile("l.*")].get()
{'b': None, 'l1': 1, 'l2': 2}
```
</details>
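To make the selection rule concrete, here is a pure-Python sketch of what the regex example above computes on a flat dict. `regex_get` is a hypothetical helper, not part of pytreeclass, and it assumes full-match semantics (the whole key must match the pattern); check the library source before relying on that detail.

```python
import re
from typing import Any


def regex_get(tree: dict[str, Any], pattern: re.Pattern) -> dict[str, Any]:
    """Hypothetical helper: keep leaves whose key fully matches `pattern`
    and replace every other leaf with None, like `.at[pattern].get()`."""
    # Assumption: full-match semantics, so the pattern "l" would not select "l1".
    return {
        key: (value if pattern.fullmatch(key) else None)
        for key, value in tree.items()
    }


tree = {"l1": 1, "l2": 2, "b": 3}
print(regex_get(tree, re.compile("l.*")))  # {'l1': 1, 'l2': 2, 'b': None}
```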
Deprecations
1. `RegexKey` is deprecated; use compiled `re` patterns instead.
2. `tree_indent` is deprecated; use `tree_diagram(tree).replace(...)` to replace the edge characters with spaces.
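A minimal sketch of the suggested replacement, assuming the diagram draws its edges with the box-drawing characters `│`, `├`, `└` and `─`. `indent_only` and the hard-coded `diagram` string are illustrative stand-ins for real `tc.tree_diagram(tree)` output; `str.translate` does the same job as chaining one `str.replace` call per character.

```python
# Assumed set of edge characters in `tree_diagram` output (not taken from the source).
EDGE_CHARS = "│├└─"


def indent_only(diagram: str) -> str:
    """Replace every edge character with a space, keeping the indentation layout."""
    return diagram.translate({ord(ch): " " for ch in EDGE_CHARS})


# Hypothetical diagram text; in practice this would come from `tc.tree_diagram(tree)`.
diagram = "list\n├── [0]=1\n└── [1]=2"
print(indent_only(diagram))
```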
New features
1. Add `tree_mask` and `tree_unmask` to freeze and unfreeze tree leaves based on a callable or boolean-pytree mask. By default, leaves of non-inexact types (i.e. not floating-point or complex) are masked with a frozen wrapper.
<details>
Example: pass non-`jax` types through a `jax` transformation without error.

```python
>>> # pass non-differentiable values to `jax.grad`
>>> import pytreeclass as tc
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = tc.tree_unmask(tree)
...     return tree[0]**2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(tc.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), 2)
```
</details>
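The masking idea can be sketched in plain Python: non-inexact leaves are wrapped so that anything walking the tree only sees the inexact ones, and unmasking restores the originals. `Frozen`, `mask`, `unmask` and `visible_leaves` are hypothetical names for this illustration; in JAX the wrapper would additionally need to be registered as a pytree node with no children so that transformations such as `jax.grad` skip it.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Frozen:
    """Hypothetical stand-in for the frozen wrapper used by `tree_mask`."""
    value: Any


def is_inexact(x: Any) -> bool:
    # Assumption: Python float/complex play the role of JAX's inexact dtypes here.
    return isinstance(x, (float, complex))


def mask(tree: tuple) -> tuple:
    # Wrap every non-inexact leaf so a transformation would skip it.
    return tuple(x if is_inexact(x) else Frozen(x) for x in tree)


def unmask(tree: tuple) -> tuple:
    # Unwrap frozen leaves back to their original values.
    return tuple(x.value if isinstance(x, Frozen) else x for x in tree)


def visible_leaves(tree: tuple) -> list:
    # Mimics a transformation that only sees unfrozen leaves.
    return [x for x in tree if not isinstance(x, Frozen)]


masked = mask((1.0, 2))
print(masked)                  # (1.0, Frozen(value=2))
print(visible_leaves(masked))  # [1.0]
print(unmask(masked))          # (1.0, 2)
```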
2. Add the abstract base class `BaseKey` to support user-defined match keys; see its docstring for an example.
3. Support multi-indexing with any combination of accepted index forms, e.g. a boolean pytree, a key, an int, or a `BaseKey` instance.
<details>
Example:

```python
>>> import pytreeclass as tc
>>> tree = {"l1": 1, "l2": 2, "b": 3}
>>> tree = tc.AtIndexer(tree)
>>> tree.at["l1", "l2"].get()
{'b': None, 'l1': 1, 'l2': 2}
```
</details>
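The following plain-Python sketch combines several index forms, including a user-defined key class in the spirit of `BaseKey`. `MatchKey`, `PrefixKey` and `multi_get` are hypothetical and do not reproduce the `BaseKey` interface; the sketch assumes a leaf is selected when any index selects it, which agrees with the `"l1", "l2"` example above, and it only handles flat dicts.

```python
from abc import ABC, abstractmethod
from typing import Any


class MatchKey(ABC):
    """Hypothetical user-extensible key; not pytreeclass's `BaseKey`."""

    @abstractmethod
    def matches(self, key: str) -> bool: ...


class PrefixKey(MatchKey):
    """Selects every key that starts with a given prefix."""

    def __init__(self, prefix: str) -> None:
        self.prefix = prefix

    def matches(self, key: str) -> bool:
        return key.startswith(self.prefix)


def is_selected(key: str, index: Any) -> bool:
    if isinstance(index, MatchKey):
        return index.matches(key)
    if isinstance(index, dict):  # boolean mask keyed like the tree
        return bool(index.get(key, False))
    return key == index  # plain key


def multi_get(tree: dict[str, Any], *indices: Any) -> dict[str, Any]:
    # A leaf is kept if ANY index selects it (assumed union semantics).
    return {k: (v if any(is_selected(k, i) for i in indices) else None)
            for k, v in tree.items()}


tree = {"l1": 1, "l2": 2, "b": 3}
print(multi_get(tree, "l1", "l2"))                   # {'l1': 1, 'l2': 2, 'b': None}
print(multi_get(tree, PrefixKey("l"), {"b": True}))  # {'l1': 1, 'l2': 2, 'b': 3}
```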
4. Add `scan` to `AtIndexer` to carry a state while applying a function to the selected leaves.
<details>
Example:

```python
>>> import pytreeclass as tc
>>> def scan_func(leaf, state):
...     # increase the state by 1 for each function call
...     return leaf**2, state+1
>>> tree = {"l1": 1, "l2": 2, "b": 3}
>>> tree = tc.AtIndexer(tree)
>>> tree, state = tree.at["l1", "l2"].scan(scan_func, 0)
>>> state
2
>>> tree
{'b': 3, 'l1': 1, 'l2': 4}
```
</details>
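A plain-Python sketch of the `scan` semantics on a flat dict: the function is applied to each selected leaf in turn, and the state it returns is fed into the next call. `scan_selected` is a hypothetical helper, and the visiting order (sorted keys, as in JAX's dict flattening) is an assumption.

```python
from typing import Any, Callable


def scan_selected(
    tree: dict[str, Any],
    keys: set[str],
    func: Callable[[Any, Any], tuple[Any, Any]],
    state: Any,
) -> tuple[dict[str, Any], Any]:
    """Hypothetical re-implementation of `.at[keys].scan(func, state)` for a flat dict."""
    out = dict(tree)
    # Assumption: selected leaves are visited in sorted-key order.
    for key in sorted(out):
        if key in keys:
            out[key], state = func(out[key], state)
    return out, state


def scan_func(leaf, state):
    # square the leaf and count how many leaves were visited
    return leaf**2, state + 1


tree, state = scan_selected({"l1": 1, "l2": 2, "b": 3}, {"l1", "l2"}, scan_func, 0)
print(state)  # 2
print(tree)   # {'l1': 1, 'l2': 4, 'b': 3}
```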
5. `tree_summary` improvements:
   - Add a size column to `tree_summary`.
   - Add `def_count` to register a type-specific rule for the count column.
   - Add `def_size` to register a type-specific rule for the size column.
   - Add `def_type` to register how a type is displayed in the type column.
<details>
Example:

```python
>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> x = jnp.ones((5, 5))
>>> print(tc.tree_summary([1, 2, 3, x]))
┌────┬────────┬─────┬───────┐
│Name│Type    │Count│Size   │
├────┼────────┼─────┼───────┤
│[0] │int     │1    │       │
├────┼────────┼─────┼───────┤
│[1] │int     │1    │       │
├────┼────────┼─────┼───────┤
│[2] │int     │1    │       │
├────┼────────┼─────┼───────┤
│[3] │f32[5,5]│25   │100.00B│
├────┼────────┼─────┼───────┤
│Σ   │list    │28   │100.00B│
└────┴────────┴─────┴───────┘
>>> # make list display its number of elements
>>> # in the type row
>>> @tc.tree_summary.def_type(list)
... def _(_: list) -> str:
...     return f"List[{len(_)}]"
>>> print(tc.tree_summary([1, 2, 3, x]))
┌────┬────────┬─────┬───────┐
│Name│Type    │Count│Size   │
├────┼────────┼─────┼───────┤
│[0] │int     │1    │       │
├────┼────────┼─────┼───────┤
│[1] │int     │1    │       │
├────┼────────┼─────┼───────┤
│[2] │int     │1    │       │
├────┼────────┼─────┼───────┤
│[3] │f32[5,5]│25   │100.00B│
├────┼────────┼─────┼───────┤
│Σ   │List[4] │28   │100.00B│
└────┴────────┴─────┴───────┘
```
</details>
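The `Count` and `Size` columns follow from per-type rules, which is what `def_count` and `def_size` let you register. The sketch below reproduces the totals with `functools.singledispatch`; `count_rule` and `size_rule` are hypothetical stand-ins, not pytreeclass internals, and a NumPy `float32` array replaces the JAX array. The arithmetic is 1 + 1 + 1 + 25 = 28 elements and 25 × 4 bytes = 100 bytes.

```python
from functools import singledispatch

import numpy as np


@singledispatch
def count_rule(node) -> int:
    # Fallback: treat any scalar leaf as a single element.
    return 1


@count_rule.register(np.ndarray)
def _(node: np.ndarray) -> int:
    return node.size  # number of array elements


@singledispatch
def size_rule(node) -> int:
    # Fallback: no byte size (the blank Size cells for the ints above).
    return 0


@size_rule.register(np.ndarray)
def _(node: np.ndarray) -> int:
    return node.nbytes  # elements * bytes per element


leaves = [1, 2, 3, np.ones((5, 5), dtype=np.float32)]
total_count = sum(count_rule(x) for x in leaves)
total_size = sum(size_rule(x) for x in leaves)
print(total_count, total_size)  # 28 100
```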
6. Export pytrees to the DOT graph language using `tree_graph`.
<details>

```python
>>> # define a custom style for a node by dispatching on its type;
>>> # the function should return a dict of attributes
>>> # that will be passed to graphviz.
>>> import graphviz
>>> import pytreeclass as tc
>>> tree = [1, 2, dict(a=3)]
>>> @tc.tree_graph.def_nodestyle(list)
... def _(_) -> dict[str, str]:
...     return dict(shape="circle", style="filled", fillcolor="lightblue")
>>> dot_graph = graphviz.Source(tc.tree_graph(tree))
>>> dot_graph
```
</details>
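For readers new to the DOT language, here is a hypothetical, dependency-free exporter for nested lists, tuples and dicts. It only illustrates the node and edge statements a DOT file contains; its labels and layout are not the output of `tc.tree_graph`. The resulting string can be rendered with `graphviz.Source(...)` in the same way as the example above.

```python
def to_dot(tree) -> str:
    """Hypothetical minimal DOT exporter for nested lists/tuples/dicts (no escaping)."""
    lines = ["digraph G {"]
    counter = 0

    def visit(node, label: str) -> int:
        nonlocal counter
        node_id = counter
        counter += 1
        if isinstance(node, dict):
            children = list(node.items())
            text = f"{label}: dict"
        elif isinstance(node, (list, tuple)):
            children = [(f"[{i}]", child) for i, child in enumerate(node)]
            text = f"{label}: {type(node).__name__}"
        else:
            children = []
            text = f"{label}={node!r}"
        # One node statement per pytree node, then one edge per child.
        lines.append(f'    {node_id} [label="{text}"];')
        for child_label, child in children:
            child_id = visit(child, str(child_label))
            lines.append(f"    {node_id} -> {child_id};")
        return node_id

    visit(tree, "root")
    lines.append("}")
    return "\n".join(lines)


print(to_dot([1, 2, dict(a=3)]))
```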
7. Add variable positional (`VAR_POS`) and variable keyword (`VAR_KW`) argument kinds to the `kind` option of `tc.field`.
<details>

```python
>>> import pytreeclass as tc
>>> class Tree(tc.TreeClass):
...     a: int = tc.field(kind="VAR_POS")
...     b: int = tc.field(kind="POS_ONLY")
...     c: int = tc.field(kind="VAR_KW")
...     d: int
...     e: int = tc.field(kind="KW_ONLY")
>>> Tree.__init__
<function __main__.Tree.__init__(self, b: int, /, d: int, *a: int, e: int, **c: int) -> None>
```
</details>
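The generated signature follows Python's rule that parameters appear grouped by kind (positional-only, positional-or-keyword, `*args`, keyword-only, `**kwargs`). The sketch below rebuilds it with the standard `inspect` module; `KIND_MAP` is an assumed mapping from the `kind` strings above, fields without a `kind` are treated as positional-or-keyword, and this is not necessarily how pytreeclass constructs `__init__`.

```python
import inspect

P = inspect.Parameter

# Assumed mapping from the `kind` strings above to `inspect` parameter kinds.
KIND_MAP = {
    "POS_ONLY": P.POSITIONAL_ONLY,
    None: P.POSITIONAL_OR_KEYWORD,  # field declared without an explicit kind
    "VAR_POS": P.VAR_POSITIONAL,
    "KW_ONLY": P.KEYWORD_ONLY,
    "VAR_KW": P.VAR_KEYWORD,
}

# Fields in declaration order, as in the `Tree` class above.
fields = [("a", "VAR_POS"), ("b", "POS_ONLY"), ("c", "VAR_KW"), ("d", None), ("e", "KW_ONLY")]

params = [P("self", P.POSITIONAL_ONLY)]
params += [P(name, KIND_MAP[kind], annotation=int) for name, kind in fields]
# Parameters must be grouped by kind; a stable sort keeps declaration order within a group.
params.sort(key=lambda p: p.kind)

sig = inspect.Signature(params, return_annotation=None)
print(sig)  # (self, b: int, /, d: int, *a: int, e: int, **c: int) -> None
```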
This release makes extensive use of `functools.singledispatch` to enable greater customization:
- `{freeze,unfreeze,is_nondiff}.def_type` defines how a type is frozen, how it is unfrozen, and whether it is considered non-differentiable. These rules are used by those functions and by `tree_mask`/`tree_unmask`.
- `tree_graph.def_nodestyle` and `tree_summary.def_{count,type,size}` customize pretty printing.
- `BaseKey.def_alias` defines how a type is aliased to a key inside `AtIndexer`/`.at`.
- Internally, most pretty printing uses dispatch to define `repr`/`str` rules for each type.
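`functools.singledispatch` selects an implementation from the type of the first argument, so registering a rule for one type leaves every other type on the default. A minimal sketch, where `is_nondiff` is a hypothetical standalone function mirroring the idea behind `is_nondiff.def_type` rather than pytreeclass's own object, and `Fraction` is an arbitrary example type:

```python
from fractions import Fraction
from functools import singledispatch
from typing import Any


@singledispatch
def is_nondiff(value: Any) -> bool:
    """Hypothetical default rule: only float/complex leaves are differentiable."""
    return not isinstance(value, (float, complex))


# Register a type-specific rule, analogous in spirit to `is_nondiff.def_type`.
@is_nondiff.register(Fraction)
def _(value: Fraction) -> bool:
    return False  # treat exact fractions as differentiable in this sketch


print(is_nondiff(1.0), is_nondiff(2), is_nondiff(Fraction(1, 3)))  # False True False
```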