Mlx-kan

Latest version: v0.2.0

0.2.0

This release (v0.2.0) brings major changes.

You can now train, save and load your KAN model with your own data. This is currently supported only for the basic KAN model. Support for other models and architectures will come later, as will more training algorithms (currently only `SimpleTrainer` is available).

Here is an example using the MNIST dataset:

```python
# Import necessary libraries
import mlx.core as mx  # MLX core module for array manipulation

# Import the KAN model and its arguments
from mlx_kan.kan import KAN
from mlx_kan.kan.args import ModelArgs

# Import SimpleTrainer and its training arguments
from mlx_kan.trainer.simpletrainer import SimpleTrainer
from mlx_kan.trainer.trainer_args import TrainArgs

# Import the MNIST dataset loader
import mlx_kan.quick_scripts.mnist as mnist

# Define the model parameters
num_layers = 2     # Number of layers in the model
in_features = 28   # Image width for MNIST
out_features = 28  # Image height for MNIST (the input is 28 * 28 = 784 pixels)
hidden_dim = 64    # Width of each hidden layer
num_classes = 10   # Number of output classes (digits 0-9)

# Initialize the KAN model with the specified architecture
kan_model = KAN(
    # Layer widths: flattened input, (num_layers - 1) hidden layers, output layer
    layers_hidden=[in_features * out_features] + [hidden_dim] * (num_layers - 1) + [num_classes],
    args=ModelArgs,  # Pass the model arguments
)

# Load the MNIST dataset and convert each split to an MLX array
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())

# Set the training arguments
TrainArgs.max_steps = 1000  # Maximum number of training steps

# Initialize and run the SimpleTrainer
SimpleTrainer(
    model=kan_model,                            # Model to be trained
    args=TrainArgs,                             # Training arguments
    train_set=(train_images, train_labels),     # Training dataset
    validation_set=(test_images, test_labels),  # Validation dataset (the test set is reused here)
    test_set=(test_images, test_labels),        # Test dataset
    validation_interval=1000,                   # Interval for validation
    logging_interval=10,                        # Interval for logging
)
```
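The `layers_hidden` argument is plain Python list arithmetic. The sketch below evaluates it for the values used above through a hypothetical helper, `kan_layer_sizes` (not part of mlx-kan). It assumes that `KAN` builds one layer per consecutive pair of widths, which is consistent with `num_layers = 2` producing two width transitions.

```python
def kan_layer_sizes(num_layers, in_features, out_features, hidden_dim, num_classes):
    # Hypothetical helper mirroring the `layers_hidden` expression in the example.
    # Input layer: the 28 x 28 image is flattened into 28 * 28 = 784 values.
    input_dim = in_features * out_features
    # (num_layers - 1) hidden widths, followed by the classifier output width.
    return [input_dim] + [hidden_dim] * (num_layers - 1) + [num_classes]


sizes = kan_layer_sizes(2, 28, 28, 64, 10)
print(sizes)  # [784, 64, 10]
# Assumption: each consecutive pair is the (input, output) width of one layer.
print(list(zip(sizes, sizes[1:])))  # [(784, 64), (64, 10)]
```

Raising `num_layers` to 3 would insert one more hidden width, giving `[784, 64, 64, 10]`.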


More examples can be found in the `examples` folder.

0.1.71

0.1.9

Adding architectures into the mix.

Added architectures

`SmallKANMLP` Class

The `SmallKANMLP` class consists of two `KANLinear` layers. It is designed for small-scale models.

`MiddleKANMLP` Class

The `MiddleKANMLP` class consists of three `KANLinear` layers. It is designed for medium-scale models.

`BigKANMLP` Class

The `BigKANMLP` class consists of four `KANLinear` layers. It is designed for large-scale models.
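The notes above only state how many `KANLinear` layers each class stacks. The following minimal pure-Python sketch illustrates what such a stack computes. `ToyKANLinear` and `ToyKANStack` are hypothetical names. The layer is a simplification of the usual KAN layer: each edge function is a SiLU base term plus degree-1 (hat-function) B-splines rather than cubic ones. It is not mlx-kan's actual `KANLinear` implementation.

```python
import math
import random


def silu(x):
    # Numerically stable SiLU: x * sigmoid(x).
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def hat_basis(x, grid):
    # Degree-1 B-splines ("hat" functions) centred on each point of a uniform grid.
    h = grid[1] - grid[0]
    return [max(0.0, 1.0 - abs(x - g) / h) for g in grid]


class ToyKANLinear:
    """Hypothetical KAN layer: y_j = sum_i phi_ji(x_i), one learnable phi per edge."""

    def __init__(self, in_features, out_features, grid_size=5, seed=0):
        rng = random.Random(seed)
        self.in_features = in_features
        self.out_features = out_features
        # Uniform grid on [-1, 1]; the spline term vanishes more than one
        # grid step outside this range, leaving only the SiLU base term.
        self.grid = [-1.0 + 2.0 * k / (grid_size - 1) for k in range(grid_size)]
        # Per edge (j, i): one base weight and one coefficient per hat function.
        self.base_w = [[rng.uniform(-1.0, 1.0) for _ in range(in_features)]
                       for _ in range(out_features)]
        self.coef = [[[rng.uniform(-0.1, 0.1) for _ in self.grid]
                      for _ in range(in_features)]
                     for _ in range(out_features)]

    def __call__(self, x):
        if len(x) != self.in_features:
            raise ValueError(f"expected {self.in_features} inputs, got {len(x)}")
        basis = [hat_basis(xi, self.grid) for xi in x]
        out = []
        for j in range(self.out_features):
            total = 0.0
            for i, xi in enumerate(x):
                # phi_ji(x_i) = base weight * silu(x_i) + spline(x_i)
                spline = sum(c * b for c, b in zip(self.coef[j][i], basis[i]))
                total += self.base_w[j][i] * silu(xi) + spline
            out.append(total)
        return out


class ToyKANStack:
    """Hypothetical stack: one ToyKANLinear per consecutive pair of widths."""

    def __init__(self, widths, seed=0):
        self.layers = [ToyKANLinear(a, b, seed=seed + k)
                       for k, (a, b) in enumerate(zip(widths, widths[1:]))]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Three widths -> two layers, the layer count the notes give for SmallKANMLP.
small = ToyKANStack([4, 8, 3])
print(len(small.layers), len(small([0.1, -0.2, 0.3, 0.5])))  # 2 3
```

Under the same assumption, four widths would give a three-layer stack (as described for `MiddleKANMLP`) and five widths a four-layer stack (as described for `BigKANMLP`).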

`LlamaKANMLP` Class

The `LlamaKANMLP` class consists of three `KANLinear` layers configured the same way as Llama's MLP layer. It is designed for models that need this Llama-style layer arrangement.
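Llama's MLP combines three projections as `down(silu(gate(x)) * up(x))`. Here, `gate` and `up` map the model dimension to a hidden dimension, and `down` maps it back. The pure-Python sketch below shows that data flow. It assumes `LlamaKANMLP` follows the same pattern with `KANLinear` layers in place of linear ones. Plain linear maps are used as stand-ins, and the helpers `linear` and `llama_style_mlp` are hypothetical.

```python
import math


def silu(x):
    # Numerically stable SiLU: x * sigmoid(x).
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def linear(weights):
    # Plain linear map x -> W @ x, standing in for a KANLinear layer.
    def apply(x):
        return [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return apply


def llama_style_mlp(gate, up, down, x):
    # Llama MLP pattern: down(silu(gate(x)) * up(x)), product taken element-wise.
    gated = [silu(g) * u for g, u in zip(gate(x), up(x))]
    return down(gated)


# dim = 2, hidden = 3: gate and up map 2 -> 3, down maps 3 -> 2.
gate = linear([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
up = linear([[2.0, 0.0], [0.0, 2.0], [1.0, -1.0]])
down = linear([[1.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
y = llama_style_mlp(gate, up, down, [1.0, 0.0])
print([round(v, 4) for v in y])  # [2.1932, 0.7311]
```

The gating is why this arrangement needs three layers instead of two: one branch (`gate`) modulates the other (`up`) before the result is projected back down.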


How to import them:

```python
from mlx_kan.kan.architectures.KANMLP import LlamaKANMLP, SmallKANMLP, MiddleKANMLP, BigKANMLP
```

0.1.8

This release optimizes the model and training code.

Added:
- progress bar for the steps in every epoch
- gradient clipping in the quick training code
- more cleanups
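The notes do not specify which form of gradient clipping the quick training code uses. A common choice is clipping by global L2 norm: every gradient value is scaled by the same factor whenever their combined norm exceeds a threshold. The minimal sketch below works on a flat list of floats. The helper `clip_by_global_norm` is hypothetical and not mlx-kan code. In MLX, the gradients form a nested tree of arrays rather than a flat list.

```python
import math


def clip_by_global_norm(grads, max_norm):
    # Rescale all gradient values together so their combined L2 norm is at most max_norm.
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm  # already small enough: leave unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm)  # 5.0
print([round(g, 6) for g in clipped])  # [0.6, 0.8]
```

Because every value is scaled by the same factor, the update direction is preserved; only its magnitude is capped.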

0.1.7

This is the first package release.

Install the package:

```sh
pip install mlx-kan
```


Example usage in Python:

```python
from mlx_kan.kan import KAN

# Initialize and use KAN
kan_model = KAN()
```


If you want to play around quickly:

```sh
python -m mlx_kan.quick_scripts.quick_train --help
```
