What's Changed
* Automatically determine num_actions and num_chance_outcomes in stochastic_muzero_policy (thanks Carlos Martin).
* Explicitly use int32 for the argmax output even when using jax_enable_x64.
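
To illustrate the int32 item, here is a minimal sketch of casting an argmax result explicitly. The helper name `argmax_int32` is hypothetical and not taken from the library. The sketch assumes that, without the cast, the index dtype follows the `jax_enable_x64` setting.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; best done at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)


def argmax_int32(x, axis=-1):
    """Hypothetical helper: argmax whose output is always int32."""
    # Without the cast, the index dtype depends on the x64 setting
    # (typically a 64-bit integer when x64 is enabled).
    return jnp.argmax(x, axis=axis).astype(jnp.int32)


logits = jnp.array([0.1, 2.0, -1.0])
action = argmax_int32(logits)
```

Pinning the dtype keeps downstream shapes and dtypes identical whether or not 64-bit mode is switched on, which helps avoid surprises when indices feed into jitted code.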
What's Changed
* Add a link to a0-jax on Connect Four, Gomoku, and Go.
* Improve debugging by avoiding an unused NaN.
* Use jax.tree_util to avoid deprecation warnings.
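
For the jax.tree_util item, the following sketch shows the general pattern of mapping a function over a pytree through the `jax.tree_util` namespace. The `params` dictionary and the doubling function are illustrative assumptions, not code from this project.

```python
import jax
import jax.numpy as jnp

# Illustrative pytree: a dict of arrays (hypothetical, for demonstration only).
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Apply a function to every leaf via the jax.tree_util namespace.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
```

The result has the same structure as the input, with each leaf transformed.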