New features
SelfSupervisedLoss
You don't have to create labels for self-supervised learning anymore:
```python
from pytorch_metric_learning.losses import SelfSupervisedLoss, TripletMarginLoss

loss_func = SelfSupervisedLoss(TripletMarginLoss())
embeddings = model(data)
augmented = model(augmented_data)
loss = loss_func(embeddings, augmented)
```
Thanks cwkeam!
API changes
AccuracyCalculator.get_accuracy
The order and naming of the arguments have changed.
Before:
```python
get_accuracy(
    query,
    reference,
    query_labels,
    reference_labels,
    embeddings_come_from_same_source=False,
)
```
Now:
```python
get_accuracy(
    query,
    query_labels,
    reference=None,
    reference_labels=None,
    ref_includes_query=False,
)
```
The benefits of this change are:
- if `query is reference`, then you only need to pass in `query, query_labels`
- `ref_includes_query` is shorter and clearer in meaning than `embeddings_come_from_same_source`
Some example usage of the new format:
```python
# Accuracy of a query set, where the query set is also the reference set:
get_accuracy(query, query_labels)

# Accuracy of a query set with a separate reference set:
get_accuracy(query, query_labels, ref, ref_labels)

# Accuracy of a query set with a reference set that includes the query set:
get_accuracy(query, query_labels, ref, ref_labels, ref_includes_query=True)
```
`BaseMiner` instead of `BaseTupleMiner`
Custom miners must now extend `BaseMiner`, because `BaseTupleMiner` no longer exists.
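As a rough sketch of the new base class (assuming `mine` keeps the same signature as before and that `get_all_pairs_indices` from `loss_and_miner_utils` is available, the miner name below is just an illustration):

```python
from pytorch_metric_learning.miners import BaseMiner
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

# Illustrative only: a miner that returns every positive and negative pair.
class AllPairsMiner(BaseMiner):
    def mine(self, embeddings, labels, ref_emb, ref_labels):
        # Returns the tuple (anchor1, positives, anchor2, negatives).
        return lmu.get_all_pairs_indices(labels, ref_labels)
```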
CrossBatchMemory's `enqueue_idx` is now `enqueue_mask`
Before, `enqueue_idx` specified the indices of `embeddings` that should be added to the memory bank.
Now, `enqueue_mask[i]` should be `True` if `embeddings[i]` should be added to the memory bank.
The benefit of this change is that it fixed an issue in distributed training.
Here's an example of the new usage:
```python
import torch

# Enqueue the second half of the batch into the memory bank.
enqueue_mask = torch.zeros(batch_size).bool()
enqueue_mask[batch_size // 2 :] = True
```
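A hedged sketch of passing the mask to `CrossBatchMemory` (the wrapped loss, `embedding_size`, `memory_size`, and the `embeddings`/`labels` variables below are illustrative placeholders, not part of the release notes):

```python
from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory

loss_fn = CrossBatchMemory(ContrastiveLoss(), embedding_size=128, memory_size=1024)

# Only rows of `embeddings` where enqueue_mask is True are added to the memory bank.
loss = loss_fn(embeddings, labels, enqueue_mask=enqueue_mask)
```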
VICRegLoss requires keyword argument
Before:
```python
loss_fn = VICRegLoss()
loss_fn(emb, ref_emb)
```
Now:
```python
loss_fn = VICRegLoss()
loss_fn(emb, ref_emb=ref_emb)
```
The reason is that VICRegLoss now uses the `forward` method of `BaseMetricLossFunction`, to allow for possible generalizations in the future without causing more breaking changes.
BaseTrainer `mining_funcs` and `dataset` have swapped order
This is to allow `mining_funcs` to be optional.
Before, if you didn't want to use miners:
```python
MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    mining_funcs={},
    dataset=dataset,
)
```
Now:
```python
MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    dataset,
)
```
Deletions
The following classes/functions were removed:
- `losses.CentroidTripletLoss` (it contained a bug that I don't have time to figure out)
- `miners.BaseTupleMiner` (use `miners.BaseMiner` instead)
- `miners.BaseSubsetBatchMiner` (rarely used)
- `miners.MaximumLossMiner` (rarely used)
- `trainers.UnsupervisedEmbeddingsUsingAugmentations` (rarely used)
- `utils.common_functions.Identity` (use `torch.nn.Identity` instead)
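For the removed `Identity` helper, `torch.nn.Identity` is a drop-in replacement. A minimal sketch:

```python
import torch

# torch.nn.Identity passes its input through unchanged, just like the
# removed utils.common_functions.Identity.
layer = torch.nn.Identity()
x = torch.randn(4, 8)
assert torch.equal(layer(x), x)
```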
Other minor changes
- VICRegLoss should now work with DistributedLossWrapper (https://github.com/KevinMusgrave/pytorch-metric-learning/issues/535)
- Dynamic recordable attribute names were removed (https://github.com/KevinMusgrave/pytorch-metric-learning/issues/436)
- AccuracyCalculator now returns NaN instead of 0 when none of the query labels appear in the reference set (https://github.com/KevinMusgrave/pytorch-metric-learning/issues/397)