### Changed
- **Remove einx dependency in compiled code:** The code for a traced function now directly imports and uses the namespace
  of the backend (e.g. `import torch`). For example:

  ```python
  >>> print(einx.dot("b q (h c), b k (h c) -> b q k h", x, y, h=16, graph=True))
  import torch
  def op0(i0, i1):
      x0 = torch.reshape(i0, (16, 768, 16, 64))
      x1 = torch.reshape(i1, (16, 768, 16, 64))
      x2 = torch.einsum("abcd,aecd->abec", x0, x1)
      return x2
  ```

  In most cases, compiled functions now contain no reference to other einx code.
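
  A quick way to confirm this on a concrete call (a minimal sketch; `x` and `y` are assumed to be `torch` tensors of shape `(16, 768, 1024)`, matching the example above):

  ```python
  import torch
  import einx

  x = torch.randn(16, 768, 1024)
  y = torch.randn(16, 768, 1024)

  # graph=True returns the traced function; its string representation is the generated code
  graph = einx.dot("b q (h c), b k (h c) -> b q k h", x, y, h=16, graph=True)
  assert "einx" not in str(graph)  # for this example, the compiled code references only torch
  ```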
- **Improve handling of Python scalars:** (see https://github.com/fferflo/einx/issues/7) einx now converts `int`, `float` and `bool` inputs to tensor
  objects (e.g. via `torch.asarray`) only if the backend function that is called does not support Python scalars (previously, all inputs were converted
  to tensor objects). When using PyTorch, the `device` argument is used to place the constructed tensor on the correct
  device.<br>For example, `torch.add` supports Python scalars

  ```python
  >>> print(einx.add("a,", x, 1, graph=True))
  import torch
  def op0(i0, i1):
      x0 = torch.add(i0, i1)
      return x0
  ```

  while `torch.maximum` does not:

  ```python
  >>> print(einx.maximum("a,", x, 1, graph=True))
  import torch
  def op0(i0, i1):
      x0 = torch.asarray(i1, device=i0.device)
      x1 = torch.maximum(i0, x0)
      return x1
  ```
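
  A small usage sketch of the device handling (assuming a CUDA device is available; otherwise everything simply stays on the CPU):

  ```python
  import torch
  import einx

  device = "cuda" if torch.cuda.is_available() else "cpu"
  x = torch.randn(16, device=device)

  # The Python scalar 1 is converted via torch.asarray(1, device=x.device),
  # so the result lives on the same device as x.
  y = einx.maximum("a,", x, 1)
  assert y.device == x.device
  ```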
- Run unit tests for PyTorch and JAX also on the GPU (if one is available).
- Run unit tests also with `jax.jit` and `torch.compile`.
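
  A hypothetical sketch of what such a test might look like (illustrative only, not the actual einx test suite):

  ```python
  import pytest
  import torch
  import einx

  @pytest.mark.parametrize("compile_fn", [lambda f: f, torch.compile])
  def test_add_under_compile(compile_fn):
      x = torch.randn(16)
      # Wrap the einx call with torch.compile (or leave it uncompiled) and compare results
      op = compile_fn(lambda t: einx.add("a,", t, 1))
      torch.testing.assert_close(op(x), x + 1)
  ```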
### Fixed
- Add workarounds for issues with `torch.compile`: https://github.com/pytorch/pytorch/issues/94674 and https://github.com/pytorch/pytorch/issues/124269