Understanding Neural Networks with Geometry

A geometric study of how a simple neural network learns modular arithmetic from data.

Author: Sarthak Bagaria
Written: 21 Dec 2025

Neural Networks are taking over the world. They are being used in education, healthcare, robotics, scientific research, and almost every other field. Yet, for the most part, they stubbornly remain black boxes. We know they work, but we often don't know how they arrive at their answers. This isn't just an academic problem; it is a core hurdle for AI Alignment. If we can't interpret the internal logic of a model, how can we ensure it stays aligned with human values as it scales?

In this post, we are going to crack open a small network and look at its brain through the lens of Geometry. We'll see how the model learns to solve a math problem by literally building a circle in its mind.

The Experiment: Modular Addition

We are going to train a simple neural network to solve a modular addition problem:

$$c = (a + b) \bmod 67.$$

Embedding: To begin with, we one-hot encode each of our two input numbers as a 67-dimensional vector: we map each number $i$ to the vector whose $i$-th component is 1 and whose other components are 0, so each input is represented by an independent vector. We then apply a learnable linear transformation to the one-hot encoding using a $(67, 128)$ matrix, so the final embedding of each input is a 128-dimensional vector.
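As a concrete sketch of this step (NumPy, with a random matrix standing in for the learned weights):

```python
import numpy as np

p, d_embed = 67, 128
rng = np.random.default_rng(0)

# Learnable (67, 128) embedding matrix; random init stands in for training
W_embed = rng.normal(size=(p, d_embed))

def one_hot(i, p):
    """Map the number i to a p-dimensional vector with i-th component 1."""
    v = np.zeros(p)
    v[i] = 1.0
    return v

# Multiplying a one-hot vector by the matrix just selects a row,
# which is why embeddings are usually implemented as lookups.
assert np.allclose(one_hot(5, p) @ W_embed, W_embed[5])
```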

In the first approach, we stack (concatenate) the embeddings of the two inputs to get a 256-dimensional vector, which is passed to the rest of the network. Let's see how this model performs:
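For reference, a minimal NumPy sketch of this first architecture; the single hidden layer and its width are my assumptions, as the post does not specify the head:

```python
import numpy as np

p, d_embed, d_hidden = 67, 128, 256  # hidden width is an assumption
rng = np.random.default_rng(0)

W_embed = rng.normal(size=(p, d_embed)) * 0.02
W1 = rng.normal(size=(2 * d_embed, d_hidden)) * 0.02  # takes the stacked 256-dim vector
W2 = rng.normal(size=(d_hidden, p)) * 0.02            # logits over the 67 possible answers

def forward(a, b):
    """Embed both inputs, concatenate, and run a one-hidden-layer MLP."""
    x = np.concatenate([W_embed[a], W_embed[b]])  # shape (256,)
    h = np.maximum(x @ W1, 0.0)                   # ReLU hidden layer
    return h @ W2                                 # shape (67,)

assert forward(3, 64).shape == (p,)
```

Note that nothing in this architecture forces forward(a, b) to equal forward(b, a); the model has to learn the symmetry of addition from data.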

Did you notice the circles formed by the Principal Components of the embeddings when the model starts to generalize and achieve high accuracy on unseen cases? This is the model learning the symmetry of the problem. Let's delve a bit into this symmetry.

Readers of mathematics will recognize that for a prime number $p$, $\mathbb{Z}/p\mathbb{Z}$ is a cyclic group of order $p$, and an irreducible representation of this group over the complex numbers is

$$\rho_n(k) = e^{2\pi i n k / p},$$

and then modular addition is just the multiplication of the two representations:

$$\rho_n(a)\,\rho_n(b) = e^{2\pi i n (a + b) / p} = \rho_n\big((a + b) \bmod p\big).$$

For others, the part that is important for us here is that we can represent each number $k$ as a 2-dimensional vector:

$$k \mapsto \left(\cos\frac{2\pi k}{p},\, \sin\frac{2\pi k}{p}\right),$$

and then addition by one is just rotation of this vector by $2\pi/p$. Notice that after rotating $p$ times, the vector comes back to its original position, which is consistent with modular addition. This mapping has the additional nice property of continuity, which neural networks tend to like (the values don't jump abruptly around $p$).
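This picture is easy to verify numerically (pure Python, frequency-1 representation):

```python
import math

p = 67

def embed(k, p):
    """Map k to a point on the unit circle: (cos(2*pi*k/p), sin(2*pi*k/p))."""
    theta = 2 * math.pi * k / p
    return (math.cos(theta), math.sin(theta))

def rotate(v, angle):
    """Rotate a 2-d vector counter-clockwise by the given angle."""
    c, s = math.cos(angle), math.sin(angle)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])

# Adding 1 (mod p) is rotation by 2*pi/p ...
v = rotate(embed(12, p), 2 * math.pi / p)
w = embed(13, p)
assert abs(v[0] - w[0]) < 1e-9 and abs(v[1] - w[1]) < 1e-9

# ... and p rotations bring the vector back to where it started.
v = embed(12, p)
for _ in range(p):
    v = rotate(v, 2 * math.pi / p)
assert abs(v[0] - embed(12, p)[0]) < 1e-9
```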

Our plots show that the neural network has learned this 2-dimensional representation of the inputs from the symmetry of the problem just by training on the dataset for several iterations. By exploiting this symmetry it was able to generalize well.

You might wonder: why does the model bother learning this elegant geometric symmetry long after it has already achieved 100% training accuracy? This phenomenon is often called "Grokking". It suggests that even after the model finds a way to solve the problem, it continues to compress its internal logic. A compelling explanation for this comes from Singular Learning Theory (SLT).

The Mathematics of Grokking

In classical statistics, we often assume models are regular (the mapping between parameters and functions is one-to-one). However, neural networks are singular: many different sets of weights can result in the exact same mathematical function. This creates a complex loss landscape where the minima aren't just points, but interconnected valleys and ridges.

In SLT, we describe the preference for a model state not just by its error, but by its marginal likelihood

$$Z_n = \int e^{-n L_n(w)}\, \varphi(w)\, dw,$$

where $L_n(w)$ is the empirical loss at parameters $w$ and $\varphi(w)$ is the prior over parameters. This represents the total volume of parameter space that produces the same low-error behavior.

Sumio Watanabe proved that for a large number of samples $n$, the log marginal likelihood (often called the Free Energy) follows a specific asymptotic expansion:

$$F_n = -\log Z_n = n L_n(w_0) + \lambda \log n + O(\log \log n),$$

where:

- $w_0$ is a parameter that minimizes the loss $L_n$,
- $\lambda$ is the real log canonical threshold (RLCT), a positive rational number measuring the effective complexity of the solution.

Think of $\lambda$ as a measure of pointiness. A large $\lambda$ means the solution is a very sharp, narrow needle in the parameter space. A small $\lambda$ means the solution is a wide, flat valley.

The volume of parameters that achieve an error below $\epsilon$ scales as:

$$V(\epsilon) \propto \epsilon^{\lambda}.$$

Even though the model could solve modular arithmetic by memorizing every single pair (a high-complexity solution with a large $\lambda$), that solution is mathematically thin. The circular representation is highly structured and symmetric, which makes it simpler in the eyes of SLT, resulting in a smaller $\lambda$.

Because the low-$\lambda$ region (the circle) has a much larger functional volume, stochastic optimizers like SGD are statistically sucked into these wider valleys over time. This is the transition we see in the Principal Component Analysis (PCA) plots: the model is moving from a narrow, complex memorization peak to the broad, simple geometric valley of the circle.
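A toy illustration of this scaling (my own one-parameter example, not from the post): for a sharp loss $L(w) = w^2$ we have $\lambda = 1/2$ and $V(\epsilon) = 2\epsilon^{1/2}$, while for a flatter loss $L(w) = w^4$ we have $\lambda = 1/4$ and $V(\epsilon) = 2\epsilon^{1/4}$, a far larger volume at small $\epsilon$:

```python
# Toy check: volume of {w : L(w) < eps} for a sharp vs a flat minimum.
def volume(loss, eps, lo=-1.0, hi=1.0, n=200_000):
    """Estimate the length of the sub-eps region by grid counting."""
    step = (hi - lo) / n
    return sum(step for i in range(n) if loss(lo + (i + 0.5) * step) < eps)

eps = 1e-4
sharp = volume(lambda w: w**2, eps)  # ~ 2 * eps**0.5,  lambda = 1/2
flat = volume(lambda w: w**4, eps)   # ~ 2 * eps**0.25, lambda = 1/4

assert abs(sharp - 2 * eps**0.5) < 1e-3
assert abs(flat - 2 * eps**0.25) < 1e-3
assert flat > 5 * sharp  # the flatter minimum has far more low-error volume
```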

Assisting The Model With Symmetry

We saw above that the model learned the symmetry of the problem from the data and used it to generalize. But is there a way we can encode symmetry in the model architecture directly, so it doesn't have to learn it from scratch? Before we embark on this, I should point out that even though we understand the symmetries quite well in our toy example here, it is usually quite difficult to figure out the symmetries in the kind of data that data scientists usually deal with.

Instead of encoding all the symmetries of our problem, which would make the neural network somewhat redundant, let's start with a symmetry that is indeed present in several datasets: invariance with respect to permutation of the arguments.

To embed this symmetry in the model architecture, we will now add the two 128-dimensional embedding vectors of the inputs, rather than stacking them as in our previous approach. The input to the next layer accordingly shrinks from 256 dimensions to 128 dimensions.
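The corresponding change to a NumPy sketch of the model is one line: the embeddings are summed rather than concatenated, so the first weight matrix now takes a 128-dimensional input (hidden-layer details are again my assumptions):

```python
import numpy as np

p, d_embed, d_hidden = 67, 128, 256
rng = np.random.default_rng(0)

W_embed = rng.normal(size=(p, d_embed)) * 0.02
W1 = rng.normal(size=(d_embed, d_hidden)) * 0.02  # input is now 128-dim, not 256
W2 = rng.normal(size=(d_hidden, p)) * 0.02

def forward(a, b):
    """Sum the embeddings instead of concatenating them."""
    x = W_embed[a] + W_embed[b]   # permutation-invariant by construction
    h = np.maximum(x @ W1, 0.0)
    return h @ W2

# Swapping the inputs cannot change the output:
assert np.allclose(forward(3, 64), forward(64, 3))
```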

Let's first see the results of this approach:

We see that the model generalized much faster this time and learned the circular representation of the inputs again.

Those who have read the Deep Sets paper will notice that, in general, incorporating the permutation symmetry requires the embeddings to be passed through an additional Multi-Layer Perceptron before they are added, but I have not done that. I have instead sneakily exploited another neat property of our circular embeddings:

$$(\cos\theta_a, \sin\theta_a) + (\cos\theta_b, \sin\theta_b) = 2\cos\left(\frac{\theta_a - \theta_b}{2}\right)\left(\cos\frac{\theta_a + \theta_b}{2},\, \sin\frac{\theta_a + \theta_b}{2}\right),$$

so the sum of the embeddings is a vector pointing along the half-angle $(\theta_a + \theta_b)/2$: its direction is determined entirely by the sum of the input angles (doubling the direction angle recovers $\theta_a + \theta_b$), which is all the MLP needs to read and label.
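A quick numeric check (pure Python) that the direction of the summed embeddings really does determine $(a + b) \bmod p$; the decoding step that doubles the half-angle is my illustration, not necessarily what the trained MLP does:

```python
import math

p = 67

def embed(k):
    t = 2 * math.pi * k / p
    return (math.cos(t), math.sin(t))

def decoded_sum(a, b):
    """Recover (a + b) mod p from the direction of embed(a) + embed(b)."""
    x = embed(a)[0] + embed(b)[0]
    y = embed(a)[1] + embed(b)[1]
    half_angle = math.atan2(y, x)   # (theta_a + theta_b)/2, up to a shift of pi
    total = 2 * half_angle          # doubling removes the pi ambiguity mod 2*pi
    return round(total * p / (2 * math.pi)) % p

# The decoding works for every pair of inputs:
assert all(decoded_sum(a, b) == (a + b) % p
           for a in range(p) for b in range(p))
```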

Extending to Multiplication

What happens if we want to train the model to learn modular multiplication instead of addition?

This looks very much like the chart for modular addition, with the model learning a circular representation of the inputs. You may think: "Okay, maybe the circular representations are still good, since the network can still learn to multiply using them, but the addition symmetry would surely break." Let's try with our architecture that adds the circular embeddings.

Again our model works as well as in the case of modular addition. Let's see why this is happening. The key here is group isomorphism. Readers of mathematics will notice that the nonzero residues under multiplication now form a cyclic group of order $p - 1$, written $(\mathbb{Z}/p\mathbb{Z})^\times$. For others, let's construct a table of $2^k \bmod 67$:

 k   2^k mod 67     k   2^k mod 67
 1        2        34       65
 2        4        35       63
 3        8        36       59
 4       16        37       51
 5       32        38       35
 6       64        39        3
 7       61        40        6
 8       55        41       12
 9       43        42       24
10       19        43       48
11       38        44       29
12        9        45       58
13       18        46       49
14       36        47       31
15        5        48       62
16       10        49       57
17       20        50       47
18       40        51       27
19       13        52       54
20       26        53       41
21       52        54       15
22       37        55       30
23        7        56       60
24       14        57       53
25       28        58       39
26       56        59       11
27       45        60       22
28       23        61       44
29       46        62       21
30       25        63       42
31       50        64       17
32       33        65       34
33       66        66        1

We see that repeated multiplication by 2 generates all elements from 1 to 66; that is, 2 is a generator of the multiplicative group. The embedding can learn to map this multiplication problem to the addition problem like this:

$$a \cdot b = 2^{k_a} \cdot 2^{k_b} = 2^{(k_a + k_b) \bmod 66} \pmod{67}, \qquad a = 2^{k_a},\; b = 2^{k_b},$$

so multiplying numbers corresponds to adding their exponents modulo 66.

So the circular representation with the symmetry of addition still works. The neural network will learn to re-label the output to go back from the additive group to the multiplicative group.
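This re-labelling is easy to verify directly (pure Python): build the discrete-log-base-2 table modulo 67 and check that multiplication mod 67 turns into addition of exponents mod 66:

```python
p = 67

# Discrete log table: dlog[2^k mod p] = k, for k = 0..p-2
dlog = {}
x = 1
for k in range(p - 1):
    dlog[x] = k
    x = (x * 2) % p

# 2 is a generator: every nonzero residue appears exactly once
assert sorted(dlog) == list(range(1, p))

# Multiplication mod p is addition of exponents mod p-1
assert all(dlog[(a * b) % p] == (dlog[a] + dlog[b]) % (p - 1)
           for a in range(1, p) for b in range(1, p))
```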

Apart from the reduction from 67 vectors to 66 vectors in the circular representation (zero has no multiplicative inverse and is excluded), keen observers might have noticed clusters forming in the circular representation of the multiplicative group at certain points during model training. This is a consequence of the subgroup structure of a cyclic group of non-prime order: the cyclic group of order 66 contains a cyclic subgroup of size $k$ for every $k$ that divides 66, namely 2, 3, 6, 11, 22 and 33 in our case. We may see these many clusters form in our representations as the model explores the symmetries of the problem. An example of a subgroup is

$$\{1, 66\} = \{2^0, 2^{33}\},$$

because

$$66 \times 66 = 4356 \equiv 1 \pmod{67}.$$

Further, $\{1, 37, 29\} = \{2^0, 2^{22}, 2^{44}\}$ is a subgroup of order 3, since its exponents are the multiples of $66/3 = 22$.
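These subgroups are straightforward to enumerate (pure Python): for each divisor $d$ of 66, the powers of $2^{66/d}$ form the unique subgroup of order $d$:

```python
p = 67  # the multiplicative group has order p - 1 = 66

def subgroup(d):
    """The unique subgroup of order d (d must divide 66): powers of 2**(66//d)."""
    g = pow(2, (p - 1) // d, p)   # generator of the order-d subgroup
    elems, x = [], 1
    for _ in range(d):
        elems.append(x)
        x = (x * g) % p
    return sorted(elems)

assert subgroup(2) == [1, 66]
assert subgroup(3) == [1, 29, 37]   # matches 2^0, 2^44, 2^22 from the table
for d in (2, 3, 6, 11, 22, 33):
    s = subgroup(d)
    assert len(set(s)) == d
    # closed under multiplication mod p
    assert all((a * b) % p in s for a in s for b in s)
```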

This blog post is inspired by the exciting developments in Singular Learning Theory and Geometric Deep Learning.

The code for the analysis is available here.