Breno

A deep network, PyTorch style

The progress here happens along two perspectives. First, we are going to transform the code for the deep network definition and training loop into a structure that resembles PyTorch's style. That style is the language the AI community understands and implements, so gaining familiarity with it, while still paying attention to the implementation, is the first perspective. The second perspective is depth: our MLP in previous posts had just one hidden layer, and now we are going to handle a deeper network. The previous implementations had about 3k parameters in total; this network will have close to 47k.

Let's go!

The imports are the same.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

Simply PyTorch as the core and matplotlib for visualizations.

The next step is the neural net structure. Instead of having raw weight matrices and manual operations scattered across the training loop, we encapsulate each operation into a class. Each class has a __call__ method for the forward pass and a parameters method that returns the trainable tensors. This is exactly how torch.nn.Module works — we are building our own minimal version.

class Linear:

    def __init__(self, fan_in, fan_out, bias=True):
        self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x):
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])

The Linear class already incorporates the 1/sqrt(fan_in) scaling at initialization (the division by fan_in**0.5). This is the baseline to which Kaiming initialization adds a gain factor. Notice self.out: we store the output so that we can inspect the activation statistics later. This is a diagnostic pattern, not something you would normally do in production.
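To see what the 1/sqrt(fan_in) scaling buys us, here is a quick standalone sanity check (the dimensions and batch size are made up for illustration): a Linear layer initialized this way roughly preserves the standard deviation of a unit-Gaussian input.

```python
import torch

torch.manual_seed(42)

fan_in, fan_out = 100, 100
# Same initialization as the Linear class above.
weight = torch.randn((fan_in, fan_out)) / fan_in**0.5

x = torch.randn(1000, fan_in)  # unit-variance input batch
y = x @ weight

print(x.std().item())  # ~1.0
print(y.std().item())  # also ~1.0, thanks to the scaling
```

Without the division, the output std would grow to roughly sqrt(fan_in), which compounds quickly across a deep stack.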

class BatchNorm1d:

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # Trainable parameters
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # Running statistics (not trained by backprop)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        if self.training:
            xmean = x.mean(0, keepdim=True)
            xvar = x.var(0, keepdim=True)
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        # Update running stats
        if self.training:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

The BatchNorm1d class is the most complex. It has two modes: training (uses batch statistics) and inference (uses running statistics). The gamma and beta are the learnable scale and shift — they allow the network to undo the normalization if that is what the data requires. The running mean and variance are updated with exponential moving averages during training, so that at inference time we have stable statistics that do not depend on a batch.
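The core of the training-mode path can be sketched in isolation (a minimal sketch with a made-up batch, not the full class): whatever the input's mean and spread, the normalized output has roughly zero mean and unit variance per feature.

```python
import torch

torch.manual_seed(0)
# A batch with mean ~5 and std ~3, far from normalized.
x = torch.randn(32, 4) * 3.0 + 5.0

# Training-mode normalization: use the batch's own statistics.
xmean = x.mean(0, keepdim=True)
xvar = x.var(0, keepdim=True)
xhat = (x - xmean) / torch.sqrt(xvar + 1e-5)

print(xhat.mean().item())  # ~0
print(xhat.std().item())   # ~1
```

The gamma and beta of the full class then rescale and shift xhat, and the running statistics are what stand in for xmean and xvar at inference time.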

class Tanh:

    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out

    def parameters(self):
        return []

The Tanh class is trivially simple. But wrapping it in a class means it can sit in a list alongside Linear and BatchNorm1d, which makes the architecture composable.

Now the deep network definition becomes clean:

n_embd = 10
n_hidden = 100

C = torch.randn((vocab_size, n_embd))

layers = [
    Linear(n_embd * block_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(           n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(           n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size),
]

Five hidden layers of 100 neurons, each followed by batch normalization and tanh. The output layer has its own BatchNorm1d but no Tanh — because the logits go directly into the cross-entropy loss.

Two initialization details:

with torch.no_grad():
    layers[-1].gamma *= 0.1          # Make the output layer less confident.
    for layer in layers[:-1]:
        if isinstance(layer, Linear):
            layer.weight *= 5/3       # Kaiming gain for tanh.

The last layer's gamma is scaled down to 0.1, which means the initial logits will be small and the predicted probabilities will be close to uniform — exactly what we want before the network has learned anything. The hidden layer weights get the 5/3 gain to compensate for tanh contraction.
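The effect on the initial loss is easy to check numerically (a sketch assuming the vocab_size of 27 from the earlier posts; the logit scales are illustrative): small logits give a cross-entropy close to ln(27) ≈ 3.30, while large random logits mean the network starts out confidently wrong.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1)
vocab_size = 27

# Large random logits: confident but arbitrary predictions.
big_logits = torch.randn(32, vocab_size) * 10.0
# Small logits, as after scaling the last gamma by 0.1.
small_logits = torch.randn(32, vocab_size) * 0.1

targets = torch.randint(0, vocab_size, (32,))
big_loss = F.cross_entropy(big_logits, targets).item()
small_loss = F.cross_entropy(small_logits, targets).item()

print(big_loss)    # much larger than ln(27)
print(small_loss)  # close to ln(27) ~ 3.30
```

Starting near the uniform-prediction loss avoids wasting the first training steps just squashing overconfident logits.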

The parameter collection and the training loop follow the same composable pattern:

parameters = [C] + [p for layer in layers for p in layer.parameters()]
for p in parameters:
    p.requires_grad = True

for i in range(max_steps):
    # Mini-batch.
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))
    Xb, Yb = Xtr[ix], Ytr[ix]

    # Forward pass.
    emb = C[Xb]
    x = emb.view(emb.shape[0], -1)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, Yb)

    # Backward pass.
    for p in parameters:
        p.grad = None
    loss.backward()

    # Update.
    lr = 0.1 if i < 100000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad

The forward pass is now a for loop over layers. This is the key structural change — the network is a sequence of composable operations, not a monolithic block of matrix multiplications. Adding a layer, removing batch normalization, swapping Tanh for ReLU — all of these become single-line changes instead of rewiring the entire training loop.

This is not just an aesthetic improvement. The composability means that the diagnostic visualizations from the previous posts (activation histograms, gradient statistics, update ratios) can now be computed generically over layer.out for any layer in the list, regardless of what that layer does internally. The architecture and the diagnostics are decoupled.
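That generic diagnostic loop can be sketched as follows (with minimal stand-ins for the classes above, and arbitrary dimensions, just to show the pattern): because every layer stashes self.out, one loop collects statistics regardless of layer type.

```python
import torch

# Minimal stand-ins for the classes defined earlier.
class Linear:
    def __init__(self, fan_in, fan_out):
        self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5
    def __call__(self, x):
        self.out = x @ self.weight
        return self.out

class Tanh:
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out

torch.manual_seed(0)
layers = [Linear(30, 100), Tanh(), Linear(100, 100), Tanh()]

x = torch.randn(64, 30)
for layer in layers:
    x = layer(x)

# Generic diagnostics: one loop, any layer type.
for i, layer in enumerate(layers):
    print(f"layer {i} ({layer.__class__.__name__}): "
          f"mean {layer.out.mean():+.2f}, std {layer.out.std():.2f}")
```

The activation histograms from the previous posts are exactly this loop with plt.hist(layer.out) in place of the print.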

The total parameter count comes to about 47k — roughly 15 times our original single-layer network.
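The tally can be done by hand (assuming vocab_size = 27 and block_size = 3 from the earlier posts in this series):

```python
n_embd, n_hidden, vocab_size, block_size = 10, 100, 27, 3

emb = vocab_size * n_embd                # C: 27*10 = 270
first = n_embd * block_size * n_hidden   # first Linear: 30*100 = 3000
hidden = 4 * n_hidden * n_hidden         # four 100x100 Linears = 40000
last = n_hidden * vocab_size             # output Linear: 100*27 = 2700
bn = 5 * 2 * n_hidden + 2 * vocab_size   # gamma+beta per BatchNorm1d = 1054

total = emb + first + hidden + last + bn
print(total)  # 47024
```

Note that the Linear layers carry almost all of the parameters; the six BatchNorm1d layers add only about a thousand.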