
Simple, Beautiful Neural Networks

Truth is ever found in simplicity, and not in the multiplicity and confusion of things.

The spirit of Isaac Newton's words resonates with me, especially in a time of large language and vision models. With hundred-billion- or trillion-parameter networks, simplicity is not something that comes to mind. In fact, the complexity of wrangling and researching such models can make you wonder how anyone ever thought of this. That's why I want to go back to the basics and illustrate the intuition behind neural networks. I want to look at them in a way where anyone feels like they could have invented the idea. Finally, I want to show the beauty of very small neural networks applied to trivial problems, because the truth lies closer to the surface in their weights than in the trillions of parameters of internet-scale models.

I'll start off with the simplest example: an artificial neural network with a single neuron. The problem is to model the quadratic function, $y=x^2$, so a dataset might look like $x = \begin{bmatrix} 1 & 2 \end{bmatrix}$ and $y=\begin{bmatrix} 1 & 4 \end{bmatrix}$. Since a single neuron can be defined as $\hat{y} = w x + b$, we can rewrite this more specifically as

$$\begin{bmatrix} \hat{y_1} \end{bmatrix} = \begin{bmatrix} w_1 \end{bmatrix} \begin{bmatrix} x_1 \end{bmatrix} + \begin{bmatrix} b_{1} \end{bmatrix}.$$

We will use a popular activation function, ReLU, to introduce nonlinearity into our otherwise linear model. In this case, $\sigma(n) = \max(0, n)$, so we can rewrite the neuron as a composite function.

$$z = \sigma(\hat{y}) = \max(0, w x + b).$$
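To make this concrete, here is a minimal Python sketch of the single neuron; the names `relu` and `neuron` are my own illustrative choices, not a standard API.

```python
def relu(n):
    """ReLU activation: passes positive values through, zeroes out the rest."""
    return max(0.0, n)

def neuron(x, w, b):
    """A single neuron: z = relu(w * x + b)."""
    return relu(w * x + b)
```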

To learn from examples, we must do a forward-pass to predict with the current weights of our neural network and see how far the prediction falls from the example. Then, we do a backward-pass to update the weights of our model to close the gap between the prediction and the example. The last piece we need is the loss function, which gives us a measure of our predictions relative to the examples. In this case, we will use mean-squared error ($MSE$), where

$$MSE = \frac{1}{N} \sum_{i=1}^{N} (z_i - y_i)^2.$$
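In code, the loss is just as short. This sketch assumes the predictions and targets are passed as equal-length lists.

```python
def mse(predictions, targets):
    """Mean-squared error averaged over N prediction/target pairs."""
    n = len(predictions)
    return sum((z - y) ** 2 for z, y in zip(predictions, targets)) / n
```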

If we randomly initialize the parameters, taking $w_1 = 3$ and $b_1 = 7$, then we can get our first prediction with a forward-pass. Evaluating at the example $x = 2$,

$$z = \max(0, 3(2)+7) = 13.$$

As far as predictions go, getting $13$ when we expect $4$ is pretty bad. Plugging into $MSE$ confirms as much.

$$MSE = (13 - 4)^2 = 81.$$
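These numbers are easy to verify with a few lines of Python; this is plain arithmetic rather than any framework.

```python
w, b = 3.0, 7.0          # randomly initialized parameters
x, y = 2.0, 4.0          # one example from the dataset

z = max(0.0, w * x + b)  # forward-pass: 13.0
loss = (z - y) ** 2      # MSE with a single example: 81.0
```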

Now, we need to perform a backward-pass to update the model and improve the prediction. We will use gradient descent to update the parameters to minimize the $MSE$, and backpropagation to calculate the error derivatives, so that the error at the output can be propagated back through the neural network toward the input. For this example, demonstrating the backward-pass is easy because there are only two parameters in the network. For clarity, we'll write the loss function as $L = MSE$. The idea is to wiggle the parameters in such a way as to lower the value of the loss function when evaluating examples. To do so, we need the derivative of the loss function with respect to the weight and the bias. Assuming $wx + b > 0$, so that the ReLU passes its input through, then

$$ \frac{\partial L}{\partial w} = 2(wx + b - y)\,x$$ and $$ \frac{\partial L}{\partial b} = 2(wx + b - y).$$

The intuition of gradient descent is that if we know the rate of change of the loss function with respect to the weight and bias parameters, then we can minimize the loss by stepping in the negative direction of the derivative. To find the new weight and bias, we use the gradient descent update rule,

$$ w_{n+1} = w_n - \gamma \frac{\partial L}{\partial w_n} $$ and $$ b_{n+1} = b_n - \gamma \frac{\partial L}{\partial b_n}.$$
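Here is a sketch of the gradients and the update rule in Python for the single-example case, under the same assumption that $wx + b > 0$ so the ReLU passes the value through; `gradients` and `sgd_step` are hypothetical helper names.

```python
def gradients(w, b, x, y):
    """Gradients of L = (w*x + b - y)^2, assuming w*x + b > 0."""
    err = w * x + b - y
    return 2 * err * x, 2 * err   # dL/dw, dL/db

def sgd_step(w, b, x, y, gamma=1.0):
    """One gradient descent update with step size gamma."""
    dw, db = gradients(w, b, x, y)
    return w - gamma * dw, b - gamma * db
```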

The $\gamma$ term is an arbitrary step size (the learning rate) that you can choose to modulate the effect of the gradient on the parameter update. In our case, we go with $\gamma = 1$. Plugging everything in, with $w_1 x + b_1 - y = 3(2) + 7 - 4 = 9$, we get

$$w_2 = w_1 - \gamma \cdot 2(9)(2) = 3 - 36 = -33$$ and $$b_2 = b_1 - \gamma \cdot 2(9) = 7 - 18 = -11.$$

If we re-evaluate $z$ with the updated parameters, we get

$$z = \max(0, (-33)(2) - 11) = \max(0, -77) = 0.$$
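One update step in code reproduces these numbers.

```python
w, b = 3.0, 7.0
x, y = 2.0, 4.0

err = w * x + b - y      # 9.0
w = w - 2 * err * x      # 3 - 36 = -33.0
b = b - 2 * err          # 7 - 18 = -11.0

z = max(0.0, w * x + b)  # max(0, -77) = 0.0
```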

A prediction of $0$ ends up being a much better guess than the initial prediction of $13$ for $x = 2$. Clearly, gradient descent worked, though ReLU is doing a lot of heavy lifting in this example. It's not that the new weight and bias parameters are a better representation of a quadratic function or the examples, but that the ReLU is filtering out negative values. In this case, arbitrarily large negative $w$ and $b$ lead to a better prediction than our initial positive, random parameters. I like this example because it showcases the gating effect that activation functions can have. Additionally, it makes it easier to imagine how different nonlinear functions may gate or filter out information that is passed through the neuron.

These steps are the overall recipe for training neural networks: finding a good configuration of weight and bias parameters that minimizes the error between predictions and examples. The next set of questions may be about how convolutional or transformer neural networks were conceived. Both architectures arose from the need to efficiently model specific types of data. In the case of convolutional neural networks, it was visual data with spatial redundancies. For transformers, it was text data where each part of the input depends on other parts of the input. I'll cover both convolutions and transformers in other blog posts.
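To tie the pieces together, here is a minimal sketch of a training loop on the two-point dataset. The structure of the loop is the recipe above; the step size and epoch count are arbitrary choices for illustration, with $\gamma = 0.1$ instead of $1$ so the updates don't overshoot.

```python
def relu(n):
    return max(0.0, n)

# Two-point dataset sampled from y = x^2.
xs = [1.0, 2.0]
ys = [1.0, 4.0]

w, b = 3.0, 7.0   # same starting parameters as above
gamma = 0.1       # smaller step size than gamma = 1 to avoid overshooting

for epoch in range(1000):
    dw, db = 0.0, 0.0
    for x, y in zip(xs, ys):
        pre = w * x + b
        z = relu(pre)
        # The ReLU gates the gradient: error only flows back when pre > 0.
        if pre > 0:
            dw += 2 * (z - y) * x / len(xs)
            db += 2 * (z - y) / len(xs)
    w -= gamma * dw
    b -= gamma * db

print(w, b)  # converges toward w = 3, b = -2
```

On these two points, the loop converges toward $w = 3$ and $b = -2$, which fits both examples exactly: a reminder that two samples of a quadratic can be matched perfectly by a line, and that tiny datasets can make a simple model look deceptively good.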