## Major changes
- Experimental MuData support for {class}`~scvi.model.TOTALVI` via the method
{meth}`~scvi.model.TOTALVI.setup_mudata`. For several of the existing `AnnDataField` classes,
there is now a MuData counterpart with an additional `mod_key` argument used to indicate the
modality where the data lives (e.g. {class}`~scvi.data.fields.LayerField` to
  {class}`~scvi.data.fields.MuDataLayerField`). These modified classes are thin wrappers around
  the original `AnnDataField` classes, created via the new
  {class}`scvi.data.fields.MuDataWrapper` [1474].
- Modification of the {meth}`~scvi.module.VAE.generative` method's outputs to return prior and
likelihood properties as {class}`~torch.distributions.distribution.Distribution` objects.
  Affected modules are {class}`~scvi.module.AmortizedLDAPyroModule`, {class}`~scvi.module.AutoZIVAE`,
{class}`~scvi.module.MULTIVAE`, {class}`~scvi.module.PEAKVAE`, {class}`~scvi.module.TOTALVAE`,
  {class}`~scvi.module.SCANVAE`, {class}`~scvi.module.VAE`, and {class}`~scvi.module.VAEC`. This
  change facilitates manipulating these distributions during model training and inference
  [1356].
- Major changes to Jax support for scvi-tools models to generalize beyond
{class}`~scvi.model.JaxSCVI`. Support for Jax remains experimental and is subject to breaking
changes:
- Consistent module interface for Flax modules (Jax-backed) via
{class}`~scvi.module.base.JaxModuleWrapper`, such that they are compatible with the
existing {class}`~scvi.model.base.BaseModelClass` [1506].
- {class}`~scvi.train.JaxTrainingPlan` now leverages Pytorch Lightning to factor out
Jax-specific training loop implementation [1506].
- Enable basic device management in Jax-backed modules [1585].
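The modality-routing idea behind the new MuData fields can be illustrated with a toy sketch; everything below (the `LayerField` stand-in, `mudata_wrapper`, and the nested dicts standing in for AnnData/MuData objects) is hypothetical and not the scvi-tools implementation:

```python
# Toy sketch of the wrapper pattern: adapt a field class that reads from an
# AnnData-like object so that it first selects a modality by `mod_key`.

class LayerField:
    """Toy stand-in for an AnnDataField that reads a named layer."""

    def __init__(self, registry_key, layer):
        self.registry_key = registry_key
        self.layer = layer

    def get_field(self, adata):
        return adata["layers"][self.layer]


def mudata_wrapper(field_cls):
    """Return a subclass whose methods receive the `mod_key` modality."""

    class MuDataField(field_cls):
        def __init__(self, *args, mod_key=None, **kwargs):
            super().__init__(*args, **kwargs)
            self.mod_key = mod_key

        def get_field(self, mdata):
            # Route to the requested modality before delegating.
            adata = mdata["mod"][self.mod_key]
            return super().get_field(adata)

    return MuDataField


MuDataLayerField = mudata_wrapper(LayerField)

# Nested dicts standing in for a MuData object with an "rna" modality.
mdata = {"mod": {"rna": {"layers": {"counts": [[1, 2], [3, 4]]}}}}
field = MuDataLayerField("x", "counts", mod_key="rna")
print(field.get_field(mdata))  # [[1, 2], [3, 4]]
```

The real fields wrap the `AnnDataField` classes the same way, which is why each MuData counterpart only adds the `mod_key` argument on top of the original signature.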
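The Distribution-valued generative outputs described above can be consumed directly, without reassembling parameter tensors; a minimal sketch using `torch.distributions.Normal` as a stand-in for a module's likelihood (the `generative_outputs` dict shown here is illustrative):

```python
import torch

# Stand-in for what a module's generative() might now return: the likelihood
# as a Distribution object instead of separate parameter tensors.
generative_outputs = {
    "px": torch.distributions.Normal(loc=torch.zeros(3), scale=torch.ones(3))
}

px = generative_outputs["px"]
x = torch.tensor([0.0, 1.0, -1.0])

log_prob = px.log_prob(x)  # per-feature log-likelihood of observed data
sample = px.sample((10,))  # draw 10 reconstruction samples
mean = px.mean             # distribution mean, no manual bookkeeping
print(log_prob.shape, sample.shape, mean.shape)
```

Working with a single `Distribution` object replaces passing pairs of mean/variance tensors through training and inference code.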
## Minor changes
- Add {meth}`~scvi.module.base.PyroBaseModuleClass.on_load` callback which is called on
  {meth}`~scvi.model.base.BaseModelClass.load` prior to loading the module state dict [1542].
- Refactor metrics code and use {class}`~torchmetrics.MetricCollection` to update metrics in bulk
[1529].
- Add `max_kl_weight` and `min_kl_weight` to {class}`~scvi.train.TrainingPlan` [1595].
- Add a warning to {class}`~scvi.model.base.UnsupervisedTrainingMixin` that is raised if
`max_kl_weight` is not reached during training [1595].
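The annealing these parameters control can be sketched as follows, assuming a linear warmup from `min_kl_weight` to `max_kl_weight` over an `n_epochs_kl_warmup` window (an illustrative function, not the `TrainingPlan` implementation):

```python
def kl_weight(epoch, n_epochs_kl_warmup=400, min_kl_weight=0.0, max_kl_weight=1.0):
    """Linearly anneal the KL weight, then hold it at max_kl_weight."""
    if epoch >= n_epochs_kl_warmup:
        return max_kl_weight
    slope = (max_kl_weight - min_kl_weight) / n_epochs_kl_warmup
    return min_kl_weight + slope * epoch

print(kl_weight(0))    # -> 0.0
print(kl_weight(200))  # -> 0.5
print(kl_weight(400))  # -> 1.0
```

If training stops before the warmup window completes, the weight never reaches `max_kl_weight`, which is the situation the new warning flags.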
## Breaking changes
- Any methods relying on the output of `inference` and `generative` from existing scvi-tools models
(e.g. {class}`~scvi.model.SCVI`, {class}`~scvi.model.SCANVI`) will need to be modified to
accept `torch.Distribution` objects rather than tensors for each parameter (e.g. `px_m`,
`px_v`) [1356].
- The signature of {meth}`~scvi.train.TrainingPlan.compute_and_log_metrics` has changed to support
the use of {class}`~torchmetrics.MetricCollection`. The typical modification required will look
like changing `self.compute_and_log_metrics(scvi_loss, self.elbo_train)` to
`self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")`. The same is necessary
for validation metrics except with `self.val_metrics` and the mode `"validation"` [1529].
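The rationale for the new signature can be sketched with a toy collection (illustrative only, not the `torchmetrics` API): a single object updates every registered metric in bulk, and the mode string prefixes the logged names.

```python
class RunningMean:
    """Toy metric that tracks a running average."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count


class MetricCollection:
    """Toy collection: update and compute all metrics at once."""

    def __init__(self, metrics):
        self.metrics = metrics  # name -> metric object

    def update(self, **values):
        for name, value in values.items():
            self.metrics[name].update(value)

    def compute(self, mode):
        # Prefix results with the mode, e.g. "train" or "validation".
        return {f"{mode}_{name}": m.compute() for name, m in self.metrics.items()}


train_metrics = MetricCollection({"elbo": RunningMean(), "kl": RunningMean()})
train_metrics.update(elbo=120.0, kl=3.0)
train_metrics.update(elbo=100.0, kl=1.0)
print(train_metrics.compute("train"))  # {'train_elbo': 110.0, 'train_kl': 2.0}
```

Passing the collection plus a mode string is what replaces passing one metric object per call in the old signature.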
## Bug fixes
- Fix issue with {meth}`~scvi.model.SCVI.get_normalized_expression` with multiple samples and
  additional continuous covariates. This bug originated from {meth}`~scvi.module.VAE.generative`
  failing to match the dimensions of the continuous covariates with the input when
  `n_samples > 1` was passed to {meth}`~scvi.module.VAE.inference`, affecting multiple module
  classes [1548].
- Add support for padding layers in {meth}`~scvi.model.SCVI.prepare_query_anndata` which is
  necessary to run {meth}`~scvi.model.SCVI.load_query_data` for a model that was set up with a
  layer instead of `X` [1575].
## Contributors
- [jjhong922]
- [adamgayoso]
- [PierreBoyeau]
- [RK900]
- [FlorianBarkmann]