## New Features
Thanks to mlopezantequera for adding the following features!
### Testers: allow any combination of query and reference sets (#250)
To evaluate different combinations of query and reference sets, use the `splits_to_eval` argument of `tester.test()`.
For example, let's say your `dataset_dict` has two keys: `"dataset_a"` and `"train"`.
- The default `splits_to_eval = None` is equivalent to:
```python
splits_to_eval = [('dataset_a', ['dataset_a']), ('train', ['train'])]
```
- `dataset_a` as the query, and `train` as the reference:
```python
splits_to_eval = [('dataset_a', ['train'])]
```
- `dataset_a` as the query, and `dataset_a` + `train` as the reference:
```python
splits_to_eval = [('dataset_a', ['dataset_a', 'train'])]
```
Then pass `splits_to_eval` to `tester.test`:
```python
tester.test(dataset_dict, epoch, model, splits_to_eval=splits_to_eval)
```
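For reference, here is a minimal end-to-end sketch. The `GlobalEmbeddingSpaceTester`, the random `TensorDataset`s, and the linear stand-in model are illustrative assumptions, not part of this release:
```python
import torch
from pytorch_metric_learning import testers

# Hypothetical datasets and embedding model, just to make the call concrete.
def random_dataset(n=100, dim=32, num_classes=5):
    return torch.utils.data.TensorDataset(
        torch.randn(n, dim), torch.randint(0, num_classes, (n,))
    )

dataset_dict = {"dataset_a": random_dataset(), "train": random_dataset()}
model = torch.nn.Linear(32, 8)  # stand-in embedding network

tester = testers.GlobalEmbeddingSpaceTester()

# Evaluate dataset_a against dataset_a + train.
splits_to_eval = [("dataset_a", ["dataset_a", "train"])]

# As of this release, test() returns the accuracies (see "Other Changes" below).
all_accuracies = tester.test(dataset_dict, 0, model, splits_to_eval=splits_to_eval)
```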
Note that this new feature makes the old `reference_set` init argument obsolete, so `reference_set` has been removed.
### AccuracyCalculator: allow arbitrary label comparison functions (#254)
AccuracyCalculator now has an optional init argument, `label_comparison_fn`, which is a function that compares two numpy arrays of labels and returns a boolean array. The default is `numpy.equal`. If a custom function is used, then you must exclude the clustering-based metrics ("NMI" and "AMI"). The following is an example of a custom function for two-dimensional labels. It returns `True` where the 0th column matches and the 1st column does **not** match:
```python
def example_label_comparison_fn(x, y):
    return (x[:, 0] == y[:, 0]) & (x[:, 1] != y[:, 1])

AccuracyCalculator(exclude=("NMI", "AMI"),
                   label_comparison_fn=example_label_comparison_fn)
```
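As a usage sketch, the example below computes accuracies for random embeddings with two-dimensional labels. The data is made up, and the `get_accuracy` signature assumed here is `get_accuracy(query, reference, query_labels, reference_labels, embeddings_come_from_same_source)`:
```python
import numpy as np
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

def example_label_comparison_fn(x, y):
    # A "match" means column 0 agrees and column 1 differs.
    return (x[:, 0] == y[:, 0]) & (x[:, 1] != y[:, 1])

calculator = AccuracyCalculator(
    exclude=("NMI", "AMI"),  # clustering metrics need the default comparison
    label_comparison_fn=example_label_comparison_fn,
)

# Hypothetical embeddings and two-dimensional labels.
rng = np.random.default_rng(0)
query = rng.random((50, 8)).astype("float32")
reference = rng.random((200, 8)).astype("float32")
query_labels = rng.integers(0, 3, size=(50, 2))
reference_labels = rng.integers(0, 3, size=(200, 2))

accuracies = calculator.get_accuracy(
    query, reference, query_labels, reference_labels,
    embeddings_come_from_same_source=False,
)
```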
## Other Changes
- `BaseTrainer` and `BaseTester` now take an optional `dtype` argument. This is the type that the dataset output will be converted to, e.g. `torch.float16`. If set to the default value of `None`, then no type casting will be done. (See the sketch after this list.)
- Removed `self.dim_reduced_embeddings` from `BaseTester` and the associated code in `HookContainer`, due to lack of use.
- `tester.test()` now returns `all_accuracies`. Previously it returned nothing, and you had to access `all_accuracies` either through the `end_of_testing_hook` or via `tester.all_accuracies`.
- `tester.embeddings_and_labels` is deleted at the end of `tester.test()` to free up memory.
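A brief sketch of the `dtype` option and the new return value, reusing `dataset_dict` and `model` from the earlier sketch. The `GlobalEmbeddingSpaceTester` is assumed to forward the new init argument, and half precision generally also requires casting the model and running on a GPU:
```python
import torch
from pytorch_metric_learning import testers

# Cast dataset outputs to float16 before they reach the model;
# dtype=None (the default) leaves them untouched.
tester = testers.GlobalEmbeddingSpaceTester(dtype=torch.float16)

# test() now returns the accuracies directly, and tester.embeddings_and_labels
# is freed when test() finishes, so keep the returned dictionary.
all_accuracies = tester.test(dataset_dict, 0, model.half())
```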