This release, v0.2.0, brings big changes!
You can now easily train, save, and load your KAN model with your own data. For now, this is only supported for the basic KAN model. Support for other models and architectures will come later, as will more training algorithms; currently only `SimpleTrainer` is available.
Here is an example using the MNIST dataset:
```python
# Import necessary libraries
import mlx.core as mx  # mlx core module for array manipulation

# Import the KAN model and its arguments
from mlx_kan.kan import KAN
from mlx_kan.kan.args import ModelArgs

# Import SimpleTrainer and the training arguments
from mlx_kan.trainer.simpletrainer import SimpleTrainer
from mlx_kan.trainer.trainer_args import TrainArgs

# Import the MNIST dataset loader
import mlx_kan.quick_scripts.mnist as mnist

# Define the model parameters
num_layers = 2     # Number of KAN layers in the model
image_width = 28   # MNIST image width in pixels
image_height = 28  # MNIST image height in pixels
hidden_dim = 64    # Dimension of the hidden layers
num_classes = 10   # Number of output classes (digits 0-9)

# Initialize the KAN model with the specified architecture
kan_model = KAN(
    # Layer sizes: flattened input image, hidden layers, output classes
    layers_hidden=[image_width * image_height] + [hidden_dim] * (num_layers - 1) + [num_classes],
    args=ModelArgs,  # Pass the model arguments
)

# Load the MNIST dataset and convert each split to an mlx array
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())

# Set the training arguments
TrainArgs.max_steps = 1000  # Maximum number of training steps

# Initialize and run the SimpleTrainer
SimpleTrainer(
    model=kan_model,                            # Model to be trained
    args=TrainArgs,                             # Training arguments
    train_set=(train_images, train_labels),     # Training dataset
    validation_set=(test_images, test_labels),  # Validation dataset (the test set is reused here)
    test_set=(test_images, test_labels),        # Test dataset
    validation_interval=1000,                   # Validation interval
    logging_interval=10,                        # Logging interval
)
```
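To make the `layers_hidden` expression in the example concrete, here is a small pure-Python sketch that reproduces the list of layer sizes. The helper `build_layer_sizes` is hypothetical and not part of mlx_kan; it only mirrors the list arithmetic used above, where the input size is the flattened 28 × 28 image and `num_layers` sets how many layer-to-layer transitions the model has.

```python
from typing import List


def build_layer_sizes(
    image_width: int,
    image_height: int,
    hidden_dim: int,
    num_layers: int,
    num_classes: int,
) -> List[int]:
    """Hypothetical helper mirroring the `layers_hidden` list from the MNIST example.

    The first entry is the flattened image size, followed by (num_layers - 1)
    hidden layers of width hidden_dim, and finally one unit per class.
    The returned list always has num_layers + 1 entries, i.e. num_layers transitions.
    """
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    input_dim = image_width * image_height  # 28 * 28 = 784 for MNIST
    return [input_dim] + [hidden_dim] * (num_layers - 1) + [num_classes]


if __name__ == "__main__":
    print(build_layer_sizes(28, 28, 64, 2, 10))  # [784, 64, 10]
    print(build_layer_sizes(28, 28, 64, 3, 10))  # [784, 64, 64, 10]
```

With the example's settings (`num_layers = 2`), the model maps 784 inputs to 64 hidden units and then to 10 class outputs; raising `num_layers` inserts additional hidden layers of size `hidden_dim`.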
More examples can be found in the `examples` folder.