TorchKAN introduces a simplified KAN model and its variations, including KANvolver and KAL-Net, designed for high-performance image classification by leveraging polynomial transformations for enhanced feature detection.
It uses the nKAN.py script as a general structure. This project showcases the training, validation, and quantization of the KAN model using PyTorch with CUDA acceleration. The torchkan model is evaluated on the MNIST dataset, demonstrating significant accuracy improvements.
The KAN approach builds on Generalized Additive Models (GAMs), a family that has demonstrated promising outcomes since the 1980s. Inspired by a range of sources, this initial implementation of KAN in torchkan.py achieves over 97% accuracy with an evaluation time of 0.6 seconds; the quantized model further reduces this to under 0.55 seconds on the MNIST dataset within 8 epochs, utilizing an Nvidia RTX 4090 on Ubuntu 22.04.
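For orientation, here is a minimal sketch of the kind of training loop these numbers come from (8 epochs on MNIST). The nn.Sequential model below is a stand-in, not the repository's KAN; swap in the actual model from torchkan.py:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Standard MNIST preprocessing (mean/std are the usual MNIST constants).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_loader = DataLoader(
    datasets.MNIST(".", train=True, download=True, transform=transform),
    batch_size=64, shuffle=True,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(8):  # the results above are reported within 8 epochs
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
```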
Current Understanding: While there is considerable hype around KANs, it's important to recognize that learning weights for activation functions, and learnable activation functions themselves, are established ideas. The extent of interpretability, scalability, or efficiency they offer remains unclear. Quantizability, however, does not seem to be an issue: quantized evaluation of the base model costs only a ~0.6% drop in test performance.
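As a concrete reference point for that claim, post-training dynamic quantization in PyTorch converts Linear weights to int8 for CPU inference. This is a sketch of the general technique, not necessarily the exact quantization step used in the repository:

```python
import torch
import torch.nn as nn

# Stand-in float model; substitute the trained KAN model here.
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# Dynamic quantization: Linear weights become int8, activations are
# quantized on the fly at inference time. Runs on CPU.
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

qmodel.eval()
with torch.no_grad():
    logits = qmodel(torch.randn(1, 1, 28, 28))  # dummy MNIST-shaped input
print(logits.shape)  # torch.Size([1, 10])
```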
Note: The model is still under research, and exploration of its full potential is ongoing. Contributions, questions, and critiques are welcome; merge requests will be processed promptly when they include a clear outline of the issue, the solution, and its effectiveness.
Note: The PyPI pipeline is currently deprecated and will be stabilized following the release of Version 1.
The KANvolver model is a specialized neural network designed for classifying images from the MNIST dataset, achieving an accuracy of ~99.56% with a minimal error rate of 0.18%. It combines convolutional neural networks (CNNs) with polynomial feature expansions, effectively capturing both simple and complex patterns.
I am conducting a large-scale analysis to investigate how KANs can be made more interpretable.
Thanks to @cometscome for writing this version in Julia: https://github.com/cometscome/FluxKAN.jl
Convolutional Feature Extraction: The model begins with two convolutional layers, each paired with ReLU activation and max-pooling. The first layer employs 16 filters of size 3x3, while the second increases the feature maps to 32 channels.
Polynomial Feature Transformation: After feature extraction, the model applies polynomial transformations up to the n-th order to the flattened convolutional outputs, enhancing its ability to discern non-linear relationships.
How Monomials Work: In this model, monomials are polynomial powers of the input features. By computing monomials up to a specified order, the model captures non-linear interactions between the features, potentially leading to richer and more informative representations for downstream tasks.
For a given input image, the monomials of its flattened pixel values are computed and then used to adjust the output of linear layers before activation. This approach introduces an additional dimension of feature interaction, allowing the network to learn more complex patterns in the data.
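To make the mechanism concrete, here is a simplified sketch combining the two stages described above: a two-layer conv stem (16 then 32 channels, 3x3 kernels, ReLU and max-pooling) followed by monomial expansion of the flattened features up to a chosen order. Names such as ConvMonomialNet and order are illustrative, and the actual KANvolver wires monomials into its linear layers somewhat differently:

```python
import torch
import torch.nn as nn

class ConvMonomialNet(nn.Module):
    """Sketch: conv feature extraction followed by monomial expansion."""

    def __init__(self, order: int = 3, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.order = order
        feat_dim = 32 * 7 * 7  # 28x28 -> 14x14 -> 7x7 after two 2x2 poolings
        self.head = nn.Linear(feat_dim * order, num_classes)

    def forward(self, x):
        z = self.features(x).flatten(1)  # (B, feat_dim)
        # Monomials z, z^2, ..., z^order of each flattened feature.
        monomials = torch.cat([z ** k for k in range(1, self.order + 1)], dim=1)
        return self.head(monomials)

logits = ConvMonomialNet()(torch.randn(2, 1, 28, 28))  # -> shape (2, 10)
```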
The KANvolver model's ~99.5% accuracy on MNIST underscores its robustness in leveraging CNNs and polynomial expansions for effective digit classification. While it shows significant potential, the model remains open to further adaptation and exploration in broader image-processing challenges.
Note that KANvolver uses polynomials that are distinct from the original KANs [1].
KANs seem to handle noise better compared to MLPs for functional approximation. This requires further investigation.
To reproduce the results, use the nKAN.py script.
KAL_Net implements the Kolmogorov-Arnold Legendre Network (KAL-Net), a GAM architecture that uses Legendre polynomials to surpass traditional polynomial approximations such as the splines used in KANs. By caching polynomial computations with functools.lru_cache, the network avoids redundant computation, speeding up the forward pass. KAL_Net achieved a remarkable 97.8% accuracy on the MNIST dataset, showcasing its ability to handle complex patterns in image data.
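As a sketch of this idea, Legendre polynomials can be generated with the standard three-term recurrence while memoizing the recurrence coefficients with functools.lru_cache. The caching granularity here is illustrative; the repository's own caching may sit at a different level:

```python
import functools
import torch

@functools.lru_cache(maxsize=None)
def legendre_coeffs(n):
    # Coefficients of the recurrence (n+1) P_{n+1}(x) = (2n+1) x P_n(x) - n P_{n-1}(x);
    # lru_cache avoids recomputing them on every forward pass.
    return (2 * n + 1) / (n + 1), n / (n + 1)

def legendre_basis(x, max_order):
    # Stack P_0(x) .. P_max_order(x) along a new last dimension.
    polys = [torch.ones_like(x), x]
    for n in range(1, max_order):
        a, b = legendre_coeffs(n)
        polys.append(a * x * polys[-1] - b * polys[-2])
    return torch.stack(polys[: max_order + 1], dim=-1)

basis = legendre_basis(torch.linspace(-1, 1, 5), max_order=3)  # shape (5, 4)
```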
Ensure the following are installed on your system:
Tested on macOS and Linux.
Clone the torchkan repository and set up the project environment:
git clone https://github.com/1ssb/torchkan.git
cd torchkan
pip install -r requirements.txt
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
To monitor experiments and model performance with wandb:
Before running the training script, initialize wandb:
wandb login
Enter your API key when prompted to link your script executions to your wandb account.
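For reference, the usual wandb pattern once you are logged in looks like the sketch below; the project and entity names are placeholders, not the repository's configuration:

```python
import wandb

# Start a run; replace project/entity with your own values.
run = wandb.init(project="torchkan", entity="your-username")

for epoch in range(8):
    # ... training and validation happen here ...
    val_accuracy = 0.0  # placeholder metric
    wandb.log({"epoch": epoch, "val_accuracy": val_accuracy})

run.finish()
```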
Change the username in mnist.py to your own (the default is 1ssb), then run the training script:

python mnist.py
This script trains the model, validates it, quantizes it, and logs performance metrics using wandb.
For inquiries or support, please contact: Subhransu.Bhattacharjee@anu.edu.au
If this project is used in your research or referenced for baseline results, please use the following BibTeX entry:
@misc{torchkan,
  author = {Subhransu S. Bhattacharjee},
  title = {TorchKAN: Simplified KAN Model with Variations},
  year = {2024},
  howpublished = {\url{https://github.com/1ssb/torchkan/}}
}
Contributions are welcome. Please raise issues as needed. Maintained solely by @1ssb.