Algorithms
New algorithms are introduced in this version.
- Critic Regularized Regression (CRR)
  - https://arxiv.org/abs/2006.15134
- Model-based Offline Policy Optimization (MOPO)
  - https://arxiv.org/abs/2005.13239
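For readers unfamiliar with CRR, the core idea from the paper is weighted behavior cloning: dataset actions are imitated more strongly when the critic rates them above the current policy's own actions. The following is a minimal, library-independent sketch of that weighting, not d3rlpy's implementation. The function name `crr_weights`, its parameters, and the clipping value are illustrative assumptions.

```python
import math
from typing import List, Sequence


def crr_weights(
    q_taken: Sequence[float],
    q_sampled: Sequence[Sequence[float]],
    mode: str = "exp",
    beta: float = 1.0,
    max_weight: float = 20.0,
) -> List[float]:
    """Illustrative CRR-style weights for a batch (hypothetical helper).

    q_taken[i]   : critic value Q(s_i, a_i) of the dataset action
    q_sampled[i] : critic values of actions sampled from the current policy at s_i
    """
    weights = []
    for q, samples in zip(q_taken, q_sampled):
        # Advantage estimate: dataset action vs. mean over policy samples.
        advantage = q - sum(samples) / len(samples)
        if mode == "binary":
            # Imitate only actions the critic considers better than average.
            weights.append(1.0 if advantage > 0 else 0.0)
        else:
            # Exponential weighting; clamp the exponent so the weight never
            # exceeds max_weight (the clipping value is an assumption here).
            exponent = min(advantage / beta, math.log(max_weight))
            weights.append(math.exp(exponent))
    return weights
```

The resulting weights would multiply the per-sample log-likelihood of the dataset actions in the policy loss.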
Model-based RL
Model-based RL was already supported in earlier versions, with the model-based logic implemented on the `dynamics` side. That approach made it possible to combine model-based algorithms with arbitrary model-free algorithms. However, it required complex designs to implement recent model-based RL algorithms. The dynamics interface has therefore been refactored, and MOPO is the first algorithm to show how d3rlpy now supports model-based RL algorithms.
```py
# train dynamics model
from d3rlpy.datasets import get_pendulum
from d3rlpy.dynamics import ProbabilisticEnsembleDynamics
from d3rlpy.metrics.scorer import dynamics_observation_prediction_error_scorer
from d3rlpy.metrics.scorer import dynamics_reward_prediction_error_scorer
from d3rlpy.metrics.scorer import dynamics_prediction_variance_scorer
from sklearn.model_selection import train_test_split

dataset, _ = get_pendulum()

# hold out part of the episodes for evaluation
train_episodes, test_episodes = train_test_split(dataset)

dynamics = ProbabilisticEnsembleDynamics(learning_rate=1e-4, use_gpu=True)

dynamics.fit(train_episodes,
             eval_episodes=test_episodes,
             n_epochs=100,
             scorers={
                 'observation_error': dynamics_observation_prediction_error_scorer,
                 'reward_error': dynamics_reward_prediction_error_scorer,
                 'variance': dynamics_prediction_variance_scorer,
             })

# train Model-based RL algorithm
from d3rlpy.algos import MOPO

# pass the trained dynamics model through the `dynamics` argument
mopo = MOPO(dynamics=dynamics)

mopo.fit(dataset, n_steps=100000)
```
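To illustrate why MOPO needs a probabilistic ensemble, here is a small standalone sketch of the uncertainty-penalized reward described in the MOPO paper: the predicted reward is reduced by a multiple of the largest predictive spread across ensemble members. This is a simplified illustration and may differ from how d3rlpy computes it internally. The function name `mopo_penalized_reward` and the parameter `lam` are assumptions made for this example.

```python
import math
from typing import Sequence


def mopo_penalized_reward(
    reward: float,
    ensemble_stds: Sequence[Sequence[float]],
    lam: float = 1.0,
) -> float:
    """Illustrative MOPO-style reward penalty (hypothetical helper).

    reward        : reward predicted by the dynamics model for one transition
    ensemble_stds : per-member predicted standard deviations, one vector of
                    per-dimension values for each ensemble member
    lam           : penalty coefficient (larger means more conservative)
    """
    # Norm of each member's predicted standard deviation vector.
    norms = [math.sqrt(sum(s * s for s in stds)) for stds in ensemble_stds]
    # Use the most uncertain member as the uncertainty estimate.
    uncertainty = max(norms)
    return reward - lam * uncertainty


# Two ensemble members; the first is more uncertain about this transition.
penalized = mopo_penalized_reward(1.0, [[0.3, 0.4], [0.0, 0.1]], lam=2.0)
```

Transitions that the ensemble is unsure about receive lower rewards, which discourages the policy from exploiting regions where the learned model is unreliable.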
Enhancements
- `fitter` method has been implemented (thanks jamartinh)
- `tensorboard_dir` replaces the `tensorboard` flag in the `fit` method (thanks navidmdn)
- show warning messages when unused arguments are passed
- show comprehensive error messages when the action space is not compatible
- the `fit` method accepts an `MDPDataset` object
- `dropout` option has been implemented in encoders
- add appropriate `__repr__` methods to show pretty outputs when calling `print(algo)`
- metrics collection is refactored
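As a rough illustration of the iterator-style training that a `fitter` method enables, the sketch below uses a toy class, assuming `fitter` behaves like a generator counterpart of `fit` that yields the epoch number and a metrics dictionary after each epoch. `ToyAlgo` and everything inside it are hypothetical stand-ins, not d3rlpy classes.

```python
from typing import Dict, Iterator, List, Sequence, Tuple


class ToyAlgo:
    """Hypothetical stand-in for an algorithm object; not part of d3rlpy."""

    def __init__(self, learning_rate: float = 0.1) -> None:
        self.learning_rate = learning_rate
        self.param = 0.0

    def fitter(
        self, targets: Sequence[float], n_epochs: int
    ) -> Iterator[Tuple[int, Dict[str, float]]]:
        """Yield (epoch, metrics) after every epoch so the caller keeps control."""
        for epoch in range(1, n_epochs + 1):
            for t in targets:
                # Simple update moving the parameter toward each target.
                self.param += self.learning_rate * (t - self.param)
            loss = sum((t - self.param) ** 2 for t in targets) / len(targets)
            yield epoch, {"loss": loss}

    def fit(self, targets: Sequence[float], n_epochs: int) -> List[Dict[str, float]]:
        """Run all epochs in one call by exhausting the generator."""
        return [metrics for _, metrics in self.fitter(targets, n_epochs)]


algo = ToyAlgo()
history = []
for epoch, metrics in algo.fitter([1.0, 1.0, 1.0], n_epochs=10):
    history.append(metrics["loss"])
    if metrics["loss"] < 0.1:
        break  # the caller decides when to stop
```

The design choice is that the training loop stays inside the library while the caller can inspect metrics, adjust settings, or stop early between epochs.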
Bugfixes
- fix `core dumped` errors by pinning the numpy version
- fix CQL backup