Just as Multi-Layer Perceptrons (MLPs) are grounded in the universal approximation theorem, the Kolmogorov-Arnold Network (KAN), introduced in a groundbreaking paper published in April 2024, is inspired by and built on the Kolmogorov-Arnold representation theorem.
Theoretical Foundation
These two fundamental theorems can be compared as follows:
- Universal Approximation Theorem: States that any continuous function on a compact set can be approximated arbitrarily well by a neural network with a finite number of neurons.
- Kolmogorov-Arnold Theorem: States that any continuous multivariate function can be written as a finite composition of continuous univariate functions and the binary operation of addition. Specifically:
$$f(x_1, \ldots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\left(\sum_{p=1}^{n} \phi_{q, p}(x_p)\right)$$
Where:
$$\phi_{q, p}: [0, 1] \longrightarrow \mathbb{R} \text{ and } \Phi_q: \mathbb{R} \longrightarrow \mathbb{R}$$
In concrete terms, the implementation of this paradigm uses splines or other function approximators in place of the theoretical $\phi$ and $\Phi$ functions.
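For intuition, here is a simple illustration (not from the KAN paper) of how a multivariate function can be built from univariate functions and addition alone: for positive inputs, a product factors as
$$x_1 x_2 = \exp\left(\ln x_1 + \ln x_2\right), \qquad x_1, x_2 > 0$$
The theorem guarantees that a decomposition of this general form, with $2n+1$ outer terms, exists for every continuous function on $[0, 1]^n$.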
Key Advantages Over MLPs
- Parameter Efficiency: Requires fewer parameters to achieve the same results compared to MLPs
- Interpretability: Highly interpretable results, unlike black-box MLPs
- Extractable Formulas: It is possible to extract reasonably concise formulas with non-linear elements, whereas MLPs have far too many terms for this to be practical
Parameter Count Comparison
Consider the following setup:
- $L$ = depth of both networks
- $N$ = width of each layer (number of neurons per layer)
- $G$ = number of grid intervals into which $[0, 1]$ is divided
Assuming two networks (MLP and KAN) with $L$ layers and equal width in each layer: $n_0 = n_1 = \ldots = n_L = N$
- An MLP has a total of $O(N^2L)$ parameters
- A KAN has a total of $O(N^2LG)$ parameters (when using B-splines, since each entry in the function matrix has on the order of $G$ weights, one per grid interval); a concrete count is sketched below
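To make the comparison concrete, here is a minimal counting sketch (the helper names and the example values of $N$, $L$, and $G$ are mine; biases and spline edge coefficients are ignored since only the order of growth matters):

def parametros_mlp(N, L):
    # Each of the L layers has an N x N weight matrix
    return L * N * N

def parametros_kan(N, L, G):
    # Each layer has N x N univariate functions, each with roughly G coefficients
    return L * N * N * G

# Illustrative values
print(parametros_mlp(64, 3))      # 12288
print(parametros_kan(64, 3, 10))  # 122880

Note that for equal width and depth a KAN has more parameters; the parameter-efficiency claim above refers to KANs typically needing a much smaller width $N$ than an MLP to reach the same accuracy.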
Technical Implementation
In an MLP, we have several layers, and in each layer we have an associated weight matrix that encodes the relationships between layer $L$ and layer $L-1$ values, after applying the corresponding activation function:
$$a^{(L)} = \sigma(W_{n_L, n_{L-1}} a^{(L-1)})$$
In comparison, we can define a KAN layer as a matrix of one-dimensional functions:
$$\Phi = \{\phi_{q, p}\}, \quad p = 1, 2, \ldots, n_{in}, \quad q = 1, 2, \ldots, n_{out}$$
Where these functions have trainable parameters (splines). In the original Kolmogorov-Arnold representation, the inner functions form a KAN layer of dimension $(n_{in}, n_{out}) = (n, 2n+1)$ while the outer functions form a layer of dimension $(n_{in}, n_{out}) = (2n+1, 1)$.
Between layers $l$ and $l+1$, there exist $n_l \times n_{l+1}$ functions.
Matrix Form
The matrix form can be written as:
$$ x_{l+1} = \begin{pmatrix} \phi_{l,1,1}(\cdot) & \phi_{l,1,2}(\cdot) & \cdots & \phi_{l,1,n_l}(\cdot) \\ \phi_{l,2,1}(\cdot) & \phi_{l,2,2}(\cdot) & \cdots & \phi_{l,2,n_l}(\cdot) \\ \vdots & \vdots & \ddots & \vdots \\ \phi_{l,n_{l+1},1}(\cdot) & \phi_{l,n_{l+1},2}(\cdot) & \cdots & \phi_{l,n_{l+1},n_l}(\cdot) \end{pmatrix} x_l $$
Which is equivalent to:
$$x_{l+1} = \Phi_l x_{l}$$
where $\Phi_l$ is the function matrix corresponding to layer $l$.
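Note that this "product" is not an ordinary matrix-vector multiplication: each entry of $\Phi_l$ is a univariate function that is evaluated at the corresponding component of $x_l$, and the results are summed along each row:
$$x_{l+1, q} = \sum_{p=1}^{n_l} \phi_{l, q, p}(x_{l, p}), \qquad q = 1, \ldots, n_{l+1}$$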
A general KAN is the composition of $L$ such layers. For a vector $x_0 \in \mathbb{R}^{n_0}$, the output of a KAN is:
$$ \text{KAN}(x) = (\Phi_{L-1} \circ \Phi_{L-2} \circ \cdots \circ \Phi_1 \circ \Phi_0)x $$
Approximation Theorem
Theorem (Approximation Theory):
Let $x = (x_1, x_2, \dots, x_n)$. Suppose that $f(x)$ admits a representation
$$ f = (\Phi_{L-1} \circ \Phi_{L-2} \circ \cdots \circ \Phi_1 \circ \Phi_0)x $$
where each $\Phi_{l,i,j}$ is $(k+1)$-times continuously differentiable. Then there exists a constant $C$, depending on $f$ and its representation, and there exist $k$-th order B-spline functions $\Phi_{l,i,j}^G$ such that for each $0 \leq m \leq k$, we have the bound:
$$ \|f - (\Phi_{L-1}^G \circ \Phi_{L-2}^G \circ \cdots \circ \Phi_1^G \circ \Phi_0^G)x\|_{C^m} \leq C G^{-k-1+m} $$
Where the norm is defined as:
$$\|g\|_{C^m} = \max_{|\beta| \leq m} \sup_{x \in [0, 1]^n} |D^{\beta} g(x)|$$
Conclusion: We can approximate the representative function of $f$ to any desired accuracy by refining the grid (the finer the grid, the better the approximation).
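As a concrete reading of the bound (writing $\hat{f}$ for the spline-parameterized KAN): with cubic B-splines ($k = 3$) and the sup-norm case $m = 0$, the error decays as
$$\|f - \hat{f}\|_{C^0} \leq C\, G^{-4}$$
so doubling the number of grid intervals reduces the error bound by a factor of $2^4 = 16$.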
Implementation with Taylor Polynomials
Instead of estimating the univariate functions with B-splines over a grid, we will approximate them using Taylor expansions. The approximation theorem has an analogous implication in this setting: the more accurately each univariate function is approximated, the better the overall approximation of $f$.
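Concretely, each univariate function is parameterized as a truncated power series whose coefficients $c_j$ are learned (these correspond to pesos_taylor in the code below):
$$\text{spline}(x) \approx \sum_{j=0}^{d-1} c_j x^{j}$$
where $d$ is the number of Taylor terms (grado_taylor).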
KAN Layer Definition
As can be seen in the paper, each layer consists of a matrix of functions, where the "matrix multiplication" is understood as evaluating each function at the corresponding input and summing the results. Each $\phi$ is defined as follows:
A base function $b(x)$ is included so that $\phi(x)$ is the sum of this base term and the spline, each multiplied by a learnable weight; in this way, the relative contribution of the two terms can be controlled during training:
$$ \phi(x) = w_b b(x) + w_s \text{spline}(x) $$
Where the base function is:
$$ b(x) = \text{silu}(x) = \frac{x}{1 + e^{-x}} $$
PyTorch Implementation
Here's the implementation of a KAN layer using Taylor polynomial expansion:
import torch
from torch import nn

def activacion_base(x):
    # Base function b(x) = silu(x) = x * sigmoid(x)
    return x * torch.sigmoid(x)

class capaKAN_taylor(nn.Module):
    def __init__(self, input_dim, output_dim, grado_taylor):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.grado_taylor = grado_taylor
        # Initialize parameters close to 0 to keep initial outputs small
        # w_s: spline weights, one per (output, input) pair
        self.pesos_funcion = nn.Parameter(0.001 * torch.randn(output_dim, input_dim))
        # w_b: base-function weights, one per (output, input) pair
        self.pesos_base = nn.Parameter(0.001 * torch.randn(output_dim, input_dim))
        # Taylor coefficients: one polynomial with grado_taylor terms per (output, input) pair
        self.pesos_taylor = nn.Parameter(0.001 * torch.randn(output_dim, grado_taylor, input_dim))

    def forward(self, x):
        if len(x.shape) == 1:
            # Treat a 1-D batch of scalars as shape (batch, 1)
            x = x.unsqueeze(1)
        # Powers x^0, x^1, ..., x^(grado_taylor - 1): shape (batch, input_dim, grado_taylor)
        exponents = torch.cat(
            [x.unsqueeze(-1) ** j for j in range(self.grado_taylor)],
            dim=-1
        )
        # spline_{i,j}(x_j) as a truncated Taylor polynomial: shape (output_dim, batch, input_dim)
        taylor_terms = torch.einsum('njg,igj->inj', exponents, self.pesos_taylor)
        # w_b * b(x_j), summed over inputs j: shape (batch, output_dim)
        base_terms = torch.einsum('ij,nj->ni', self.pesos_base, activacion_base(x))
        # w_s * spline(x_j), summed over inputs j: shape (batch, output_dim)
        spline_terms = torch.einsum('ij,inj->ni', self.pesos_funcion, taylor_terms)
        # phi(x) = w_b b(x) + w_s spline(x), accumulated over all inputs
        return base_terms + spline_terms
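As a quick sanity check (the shapes here are illustrative, not from the original experiments), a layer with 2 inputs and 4 outputs maps a batch of points to a batch of 4-dimensional outputs:

capa = capaKAN_taylor(input_dim=2, output_dim=4, grado_taylor=10)
salida = capa(torch.randn(8, 2))  # batch of 8 two-dimensional inputs
print(salida.shape)               # torch.Size([8, 4])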
Complete KAN Network
class KAN(nn.Module):
    def __init__(self, lista_estructura, grado_taylor):
        super().__init__()
        # lista_estructura = [n_0, n_1, ..., n_L]: width of each layer
        self.layers = nn.ModuleList([])
        for i in range(len(lista_estructura) - 1):
            self.layers.append(
                capaKAN_taylor(
                    lista_estructura[i],
                    lista_estructura[i + 1],
                    grado_taylor
                )
            )

    def forward(self, x):
        # Compose the KAN layers: x_{l+1} = Phi_l(x_l)
        for layer in self.layers:
            x = layer(x)
        return x
Experimental Results
The KAN was trained on a toy dataset to approximate the function $z = x^2 + y^2$. The network structure was $[2, 4, 4, 1]$ with 10 Taylor terms per function (grado_taylor = 10), trained for 200 epochs using the Adam optimizer with a cosine annealing learning rate schedule.
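A minimal sketch of this training setup is shown below; the batch size, learning rate, and sampling range of the synthetic data are assumptions, not values from the original experiment:

# Hypothetical reconstruction of the experiment described above
modelo = KAN(lista_estructura=[2, 4, 4, 1], grado_taylor=10)
optimizador = torch.optim.Adam(modelo.parameters(), lr=1e-2)  # learning rate assumed
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizador, T_max=200)
criterio = nn.MSELoss()

for epoca in range(200):
    xy = 2 * torch.rand(256, 2) - 1           # assumed: 256 points sampled in [-1, 1]^2
    z = (xy ** 2).sum(dim=1, keepdim=True)    # target z = x^2 + y^2
    perdida = criterio(modelo(xy), z)
    optimizador.zero_grad()
    perdida.backward()
    optimizador.step()
    scheduler.step()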
The results showed successful function approximation, with the mean squared error decreasing consistently throughout training; the gradient norms of each layer indicated convergence to a local minimum.
Applications and Future Directions
This type of structure is especially useful for discovering formulas in physical experiments: examining the trained splines reveals how the final function is constructed from its inputs. The estimator is therefore highly interpretable; it is only necessary to follow the spline evaluations to understand how each variable affects the final result.
Key application areas include:
- Physics: Discovering functional relationships in experimental data
- Scientific Computing: Interpretable function approximation
- Symbolic Regression: Extracting mathematical formulas from data
- Domain Science: Understanding complex multivariate relationships
Conclusion
Kolmogorov-Arnold Networks represent a paradigm shift in neural network architectures, offering improved interpretability and, in many settings, greater parameter efficiency than traditional MLPs. The implementation using Taylor polynomial expansions provides a practical and flexible approach to function approximation while remaining grounded in the Kolmogorov-Arnold representation theorem.
The code and experiment show that a KAN can successfully learn a multivariate function with relatively few parameters and greater transparency than conventional deep learning approaches.