TextBrewer

Latest version: v0.2.1.post1

0.2.1.post1

New Features

* **More flexible distillation**: The student and teacher can now receive different batches, so their inputs no longer need to be identical. This enables distilling between models with different vocabularies (e.g., from RoBERTa to BERT). See the documentation for details.

* **Faster distillation**: Users can pre-compute and cache the teacher's outputs, then feed the cache to the distiller, saving the cost of the teacher's forward passes. See the documentation for details.
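The caching idea can be sketched framework-agnostically. The class below is a hypothetical illustration (not TextBrewer's actual implementation): each batch's teacher forward pass runs only once, and later epochs reuse the cached outputs.

```python
class CachedTeacher:
    """Wraps a teacher forward function and caches its outputs per batch,
    so each batch's teacher forward pass runs only once across epochs."""

    def __init__(self, teacher_forward):
        self.teacher_forward = teacher_forward
        self.cache = {}          # batch index -> cached teacher outputs
        self.forward_calls = 0   # counts real (uncached) forward passes

    def __call__(self, batch_index, batch):
        if batch_index not in self.cache:
            self.forward_calls += 1
            self.cache[batch_index] = self.teacher_forward(batch)
        return self.cache[batch_index]


# Toy "teacher": pretend the logits are just the doubled inputs.
teacher = CachedTeacher(lambda batch: [2 * x for x in batch])

batches = [[1, 2], [3, 4]]
for epoch in range(3):            # 3 epochs over 2 batches
    for i, batch in enumerate(batches):
        logits = teacher(i, batch)

print(teacher.forward_calls)      # 2 forward passes instead of 6
```

The same trade-off the release note warns about applies here: the cache trades memory for speed, so it only pays off when the whole set of teacher outputs fits in memory.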

Improvements

* `MultiTaskDistiller` is now a subclass of `GeneralDistiller` and supports the intermediate feature matching loss.
* TensorBoard now records losses in more detail (KD loss, hard-label loss, matching losses, etc.).
* `pkd_loss` now accepts tensors of shape (batch_size, length, hidden_size) or (batch_size, hidden_size). In the latter case, the loss is computed directly on the input tensors, without first extracting the hidden states at the first position.

0.2.0.1

Bug Fixes

* Fixed bugs in `MultiTaskDistiller`.
* Fixed an endless training loop when training for a fixed `num_steps`; distillers now stop correctly.

0.2.0

New Features

* Now supports distributed data-parallel training with `torch.nn.parallel.DistributedDataParallel`! Pass `local_rank` to `TrainingConfig` to set up distributed training. For detailed usage of `DistributedDataParallel`, see the PyTorch docs.

* We also added an example (a Chinese NER task) demonstrating how to use TextBrewer with distributed data-parallel training.
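A minimal sketch of the wiring, assuming the process is launched with `torch.distributed.launch` (which supplies `--local_rank`) and that `torch.distributed` has been initialized beforehand:

```python
# Hedged sketch: enabling distributed training in TextBrewer.
# Assumes launch via torch.distributed.launch, which passes --local_rank.
import argparse
from textbrewer import TrainingConfig

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()

# Per the release note, passing local_rank to TrainingConfig sets up
# DistributedDataParallel training; -1 keeps single-process training.
train_config = TrainingConfig(local_rank=args.local_rank)
```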

0.1.10

New Features

* Now supports mixed precision training with Apex! Just set `fp16` to `True` in `TrainingConfig`. See the documentation of `TrainingConfig` for details.
* Added `data_parallel` option in `TrainingConfig` to enable data parallel training within TextBrewer.
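Both flags are plain `TrainingConfig` options; a minimal sketch (assuming NVIDIA Apex is installed for `fp16`):

```python
from textbrewer import TrainingConfig

# fp16=True enables Apex mixed precision training;
# data_parallel=True enables data-parallel training inside TextBrewer.
train_config = TrainingConfig(fp16=True, data_parallel=True)
```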

0.1.9

New Features

* Added an `is_caching_logits` option to `DistillationConfig`. If it is `True`, the distiller caches the batches and the teacher model's output logits, so the logits are calculated only once, which speeds up distillation. This feature is **only available** for `BasicDistiller` and `MultiTeacherDistiller`. **Be cautious about setting it to `True` on large datasets, since the batches and logits are stored in memory.**

Improvements

* Added a new argument `max_grad_norm` to the distillers' `train` method to set the strength of gradient clipping. The default is -1, i.e., no gradient clipping.
* Added new arguments `scheduler_class` and `scheduler_args` to the distillers' `train` method. The old `scheduler` argument may cause convergence problems and is deprecated in favor of `scheduler_class` and `scheduler_args`. See the documentation for details.
* Removed the `print` in `display_parameters`. It no longer prints the statistics directly to the screen.
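A sketch of how the new `train` arguments fit together. `distiller`, `optimizer`, and `dataloader` are assumed to be built already; `StepLR` is only an example scheduler, and `scheduler_args` is presumably forwarded to `scheduler_class` as keyword arguments (argument order here is illustrative, see the documentation for the exact signature):

```python
from torch.optim.lr_scheduler import StepLR

# Sketch only: distiller, optimizer, and dataloader are assumed to exist.
distiller.train(
    optimizer,
    dataloader,
    num_epochs=3,
    scheduler_class=StepLR,                          # example scheduler class
    scheduler_args={"step_size": 1000, "gamma": 0.9},
    max_grad_norm=1.0,   # clip gradients to L2 norm 1.0; -1 disables clipping
)
```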

Bug Fixes

* Fixed wrong call of zero_grad().

0.1.8

Improvements

* `TrainingConfig.log_dir` can be set to `None` to disable TensorBoard.
* Added an attribute `print_freq` to the distiller to control the frequency of logging.
* Added a new argument `num_steps` to the `train` method of the distiller. If `num_steps` is specified, the distiller ignores `num_epochs` and accepts a dataloader of unknown size (i.e., one with no `__len__` attribute).
* Added a new argument `batch_postprocessor` to the `train` method of the distiller to allow post-processing of batches.
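To illustrate, a `batch_postprocessor` is just a callable applied to each batch before it is fed to the models. The sketch below is hypothetical (it assumes dict-shaped batches and the field names are made up): it strips a bookkeeping key and renames a field to what the forward pass expects.

```python
def batch_postprocessor(batch):
    """Hypothetical post-processing: drop keys the models don't accept
    and rename a field to match the model's argument name."""
    batch = dict(batch)               # don't mutate the dataloader's batch
    batch.pop("example_id", None)     # bookkeeping field, not a model input
    if "label" in batch:
        batch["labels"] = batch.pop("label")
    return batch


raw_batch = {"input_ids": [101, 2054], "label": 1, "example_id": "ex-42"}
clean_batch = batch_postprocessor(raw_batch)
print(sorted(clean_batch))            # ['input_ids', 'labels']
```

The postprocessor would then be passed to `train(..., batch_postprocessor=batch_postprocessor)` so every batch is cleaned on the fly.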
