allowing you to get rid of tensordict altogether when exporting your models:
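```python
import torch
import torch.distributions as dists
from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule as Prob,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
)
from torch import nn
from torch.export import export

model = Seq(
    # 1. A small network for embedding
    Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
    Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
    Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
    # 2. Extracting params
    Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
    # 3. Probabilistic module
    Prob(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=dists.Normal,
    ),
)

# Example input: a batch containing one 3-dimensional vector
x = torch.randn(1, 3)
# Inputs are passed as plain tensors (keyword arguments), with no TensorDict involved
model_export = export(model, args=(), kwargs={"x": x})
```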
See our [new tutorial](https://pytorch.org/tensordict/main/tutorials/export.html) to learn more about this feature.
The library's integration with the PT2 stack is further improved by the introduction of [`CudaGraphModule`](https://pytorch.org/tensordict/main/reference/generated/tensordict.nn.CudaGraphModule.html),
which can be used to speed up model execution under a certain set of assumptions. Mainly, the inputs and outputs must be
non-differentiable, they must all be tensors or constants, and the whole graph must be executable on CUDA with
buffers of constant shape (i.e., dynamic shapes are not allowed).
We also introduce a new tutorial on [streaming tensordicts](https://pytorch.org/tensordict/main/tutorials/streamed_tensordict.html).
*Note*: The `aarch64` binaries are attached to these release notes and are not available on PyPI at the moment.
## Deprecations
* [Deprecate] Make calls to make_functional error #1034 by vmoens
* [Deprecation] Act warned deprecations for v0.6 #1001 by vmoens
* [Refactor] make TD.get default to None, like dict (#948) by vmoens
## Features
* [Feature] Allow to specify log_prob_key in CompositeDistribution (#961) by albertbou92
* [Feature] Better typing for tensorclass #983 by vmoens
* [Feature] Cudagraphs (#986) by vmoens
* [Feature] Densify lazy tensordicts #955 by vmoens
* [Feature] Frozen tensorclass #984 by vmoens
* [Feature] Make NonTensorData a callable (#939) by vmoens
* [Feature] NJT with lengths #1021 by vmoens
* [Feature] Non-blocking for consolidated TD #1020 by vmoens
* [Feature] Propagate `existsok` in memmap* methods #990 by vmoens
* [Feature] TD+NJT to(device) support #1022 by vmoens
* [Feature] TensorDict.record_stream #1016 by vmoens
* [Feature] Unify composite dist method signatures with other dists (#981) by albertbou92, co-authored by Vincent Moens
* [Feature] `_foreach_copy_` for `update_` #1032 by vmoens
* [Feature] `data_ptr()` method #1024 by vmoens
* [Feature] `inplace` arg in TDM constructor #992 by vmoens
* [Feature] `selected_out_keys` arg in TDS constructor #993 by vmoens
* [Feature] better sync and instantiation of cudagraphs (#1013) by vmoens
* [Feature] callables for merge_tensordicts #1033 by vmoens
* [Feature] cat and stack_from_tensordict #1018 by vmoens
* [Feature] cat_tensors and stack_tensors #1017 by vmoens
* [Feature] from_struct_array and to_struct_array (#938) by vmoens
* [Feature] give a `__name__` to TDModules #1045 by vmoens
* [Feature] param_count #1046 by vmoens
* [Feature] sorted keys, values and items #965 by vmoens
* [Feature] str2td #953 by vmoens
* [Feature] torch.export and onnx compatibility #991 by vmoens
## Code improvements
* [Quality] Better error for mismatching TDs (#964) by vmoens
* [Quality] Better type hints for `__init__` (#1014) by vmoens
* [Quality] Expose private classmethods (#982) by vmoens
* [Quality] Fewer recompiles with tensordict (#1015) by vmoens
* [Quality] Type checks #976 by vmoens
* [Refactor, Tests] Move TestCudagraphs #1007 by vmoens
* [Refactor] Make tensorclass work properly with pyright (#1042) by Mxbonn
* [Refactor] Update nn inline_inbuilt check #1029 by vmoens
* [Refactor] Use IntEnum for interaction types (#989) by vmoens
* [Refactor] better AddStateIndependentNormalScale #1028 by vmoens
## Fixes
* [BugFix] Add nullbyte in memmap files to make fbcode happy (#943) by vmoens
* [BugFix] Add sync to cudagraph module (#1026) by vmoens
* [BugFix] Another compiler fix for older pytorch #980 by vmoens
* [BugFix] Compatibility with non-tensor inputs in CudaGraphModule #1039 by vmoens
* [BugFix] Deserializing a consolidated TD reproduces a consolidated TD #1019 by vmoens
* [BugFix] Fix `_foreach_copy_` for older versions of PT #1035 by vmoens
* [BugFix] Fix buffer identity in `Params._apply` (#1027) by vmoens
* [BugFix] Fix key errors catch in del_ and related (#949) by vmoens
* [BugFix] Fix number check in array parsing (np>=2 compatibility) #999 by vmoens
* [BugFix] Fix pre 2.1 `_apply` compatibility #1050 by vmoens
* [BugFix] Fix select in tensorclass (#936) by vmoens
* [BugFix] Fix td device sync when error is raised #988 by vmoens
* [BugFix] Fix tree_leaves import for older versions of PT #995 by vmoens
* [BugFix] Fix vmap monkey patching #1009 by vmoens
* [BugFix] Make probabilistic sequential modules compatible with compile #1030 by vmoens
* [BugFix] Other dynamo fixes #977 by vmoens
* [BugFix] Propagate maybe_dense_stack in `_stack` #1036 by vmoens
* [BugFix] Regular swap_tensor for to_module in dynamo (#963) by vmoens
* [BugFix] Remove ForkingPickler to account for change of API in torch.mp #998 by vmoens
* [BugFix] Remove forkingpickler (#1049) by bhack
* [BugFix] Resilient deterministic_sample for CompositeDist #1000 by vmoens
* [BugFix] Simple syncs (#942) by vmoens
* [BugFix] Softly revert get changes (#950) by vmoens
* [BugFix] TDParams.to(device) works as nn.Module, not TDParams contained TD #1025 by vmoens
* [BugFix] Use separate streams for cudagraph warmup #1010 by vmoens
* [BugFix] dynamo compat refactors #975 by vmoens
* [BugFix] resilient `_exclude_td_from_pytree` #1038 by vmoens
* [BugFix] restrict usage of Buffers to non-batched, non-tracked tensors #979 by vmoens
## Doc
* [Doc] Broken links in GETTING_STARTED.md (#945) by vmoens
* [Doc] Fail-on-warning in sphinx #1005 by vmoens
* [Doc] Fix tutorials #1002 by vmoens
* [Doc] Refactor README and add GETTING_STARTED.md (#944) by vmoens
* [Doc] Streaming tensordicts #956 by vmoens
* [Doc] export tutorial, TDM tuto refactoring #994 by vmoens
## Performance
* [Performance] Faster `__setitem__` (#985) by vmoens
* [Performance] Faster clone #1043 by vmoens
## Not user-facing
* [Benchmark] Benchmark H2D transfer #1044 by vmoens
* [CI, BugFix] Fix nightly build (#941) by vmoens
* [CI] Add aarch64-linux wheels (#987) by vmoens
* [CI] Fix versioning of h2d tests #1053 by vmoens
* [CI] Fix windows wheels #1006 by vmoens
* [CI] Upgrade 3.8 workflows (#967) by vmoens
* [Minor, Format] Fix fbcode lint (#940) by vmoens
* [Minor] Refactor is_dynamo_compiling for older torch versions (#978) by vmoens
* [Setup] Correct read_file encoding in setup (#962) by vmoens
* [Test] Keep a tight control over warnings (#951) by vmoens
* [Test] Make h5py tests optional if no h5py installed (#947) by vmoens
* [Test] Mark MP tests as slow (#946) by vmoens
* [Test] Rename duplicated test #997 by vmoens
* [Test] Skip compile tests that require 2.5 for stable #996 by vmoens
* [Versioning] Versions for 0.6 (#1052) by vmoens
## New Contributors
* Mxbonn made their first contribution in https://github.com/pytorch/tensordict/pull/1042
* bhack made their first contribution in https://github.com/pytorch/tensordict/pull/1049
**Full Changelog**: https://github.com/pytorch/tensordict/compare/v0.5.0...v0.6.0