A deep network, PyTorch-style
We are now going to transform the code for the deep-network definition and training loop into a structure that resembles PyTorch. This serves two purposes. The first is familiarity with the PyTorch style, the language that the AI community speaks and implements, while still paying attention to the underlying implementation. The second is depth: the MLP in the previous posts had a single hidden layer and roughly 3k parameters; the network here will be deeper, with close to 47k parameters.
Let's go!
The imports are the same.
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
Simply PyTorch as the core and matplotlib for visualizations.
The next step is the neural net structure. Instead of having raw weight matrices and manual operations scattered across the training loop, we encapsulate each operation into a class. Each class has a __call__ method for the forward pass and a parameters method that returns the trainable tensors. This is exactly how torch.nn.Module works — we are building our own minimal version.
class Linear:

    def __init__(self, fan_in, fan_out, bias=True):
        self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x):
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])
The Linear class already incorporates the fan-in scaling at initialization: the weights are divided by sqrt(fan_in). This is the baseline on top of which Kaiming initialization adds a gain factor. Notice self.out: we store the output so that we can inspect the activation statistics later. This is a diagnostic pattern, not something you would normally do in production.
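A quick sanity check of why the 1/sqrt(fan_in) division matters (the shapes below are arbitrary, chosen only to make the effect visible):

```python
import torch

torch.manual_seed(42)

fan_in, fan_out = 1000, 200
x = torch.randn(500, fan_in)  # a batch of roughly unit-variance inputs

# Unscaled weights: each output is a sum of fan_in terms,
# so its standard deviation grows like sqrt(fan_in).
w_raw = torch.randn(fan_in, fan_out)
print((x @ w_raw).std())      # roughly sqrt(1000) ~ 31.6

# Scaled weights: the output standard deviation stays near 1.
w_scaled = w_raw / fan_in**0.5
print((x @ w_scaled).std())   # roughly 1.0
```

Without the scaling, activations would blow up by a factor of sqrt(fan_in) at every layer, which is exactly what the deeper network cannot afford.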
class BatchNorm1d:

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # Trainable parameters
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # Running statistics (not trained by backprop)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        if self.training:
            xmean = x.mean(0, keepdim=True)
            xvar = x.var(0, keepdim=True)
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        self.out = self.gamma * xhat + self.beta
        # Update running stats
        if self.training:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]
The BatchNorm1d class is the most complex. It has two modes: training (uses batch statistics) and inference (uses running statistics). The gamma and beta are the learnable scale and shift — they allow the network to undo the normalization if that is what the data requires. The running mean and variance are updated with exponential moving averages during training, so that at inference time we have stable statistics that do not depend on a batch.
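The exponential moving average can be checked in isolation. A minimal sketch of the running-mean update above, with made-up data statistics (mean 5, std 3):

```python
import torch

torch.manual_seed(0)
momentum = 0.1
running_mean = torch.zeros(4)

# The same EMA update that BatchNorm1d applies during training.
for _ in range(200):
    batch = torch.randn(32, 4) * 3.0 + 5.0  # data with mean ~5, std ~3
    running_mean = (1 - momentum) * running_mean + momentum * batch.mean(0)

print(running_mean)  # each entry close to 5
```

After enough batches the running estimate forgets its zero initialization (the initial value is weighted by 0.9^200, which is negligible) and tracks the true data mean, which is what makes it usable at inference time.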
class Tanh:

    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out

    def parameters(self):
        return []
The Tanh class is trivially simple. But wrapping it in a class means it can sit in a list alongside Linear and BatchNorm1d, which makes the architecture composable.
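Because the interface is just __call__ plus parameters, any new nonlinearity drops straight into the layer list. As an illustration (not part of the original code), a ReLU written in the same style would be:

```python
import torch

class ReLU:
    def __call__(self, x):
        self.out = torch.relu(x)  # store the output for later diagnostics
        return self.out
    def parameters(self):
        return []  # no trainable tensors

relu = ReLU()
print(relu(torch.tensor([-1.0, 0.0, 2.0])))  # tensor([0., 0., 2.])
```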
Now the deep network definition becomes clean:
n_embd = 10
n_hidden = 100
C = torch.randn((vocab_size, n_embd))
layers = [
    Linear(n_embd * block_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(          n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(          n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(          n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(          n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
    Linear(          n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size),
]
Five hidden layers of 100 neurons, each followed by batch normalization and tanh. The output layer has its own BatchNorm1d but no Tanh — because the logits go directly into the cross-entropy loss.
Two initialization details:
with torch.no_grad():
    layers[-1].gamma *= 0.1  # Make the output layer less confident.
    for layer in layers[:-1]:
        if isinstance(layer, Linear):
            layer.weight *= 5/3  # Kaiming gain for tanh.
The last layer's gamma is scaled down to 0.1, which means the initial logits will be small and the predicted probabilities will be close to uniform — exactly what we want before the network has learned anything. The hidden layer weights get the gain to compensate for tanh contraction.
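The "close to uniform" claim can be checked numerically: with small logits, the cross-entropy loss should sit near -log(1/vocab_size). Assuming vocab_size = 27, as in the earlier posts:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 27  # assumption carried over from the earlier posts

# Small logits, as produced by the down-scaled gamma at initialization.
logits = 0.1 * torch.randn(32, vocab_size)
targets = torch.randint(0, vocab_size, (32,))

loss = F.cross_entropy(logits, targets)
uniform_loss = -torch.log(torch.tensor(1.0 / vocab_size))
print(loss, uniform_loss)  # both close to 3.2958
```

If the initial loss came out much higher than log(27), that would signal confidently wrong initial predictions, which is exactly the failure mode the gamma scaling avoids.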
The parameter collection and the training loop follow the same composable pattern:
parameters = [C] + [p for layer in layers for p in layer.parameters()]
for p in parameters:
    p.requires_grad = True

for i in range(max_steps):
    # Mini-batch.
    ix = torch.randint(0, Xtr.shape[0], (batch_size,))
    Xb, Yb = Xtr[ix], Ytr[ix]
    # Forward pass.
    emb = C[Xb]
    x = emb.view(emb.shape[0], -1)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, Yb)
    # Backward pass.
    for p in parameters:
        p.grad = None
    loss.backward()
    # Update.
    lr = 0.1 if i < 100000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad
The forward pass is now a for loop over layers. This is the key structural change — the network is a sequence of composable operations, not a monolithic block of matrix multiplications. Adding a layer, removing batch normalization, swapping Tanh for ReLU — all of these become single-line changes instead of rewiring the entire training loop.
This is not just an aesthetic improvement. The composability means that the diagnostic visualizations from the previous posts (activation histograms, gradient statistics, update ratios) can now be computed generically over layer.out for any layer in the list, regardless of what that layer does internally. The architecture and the diagnostics are decoupled.
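A minimal sketch of such a generic diagnostic, here measuring tanh saturation. The tensors below stand in for the layer.out values captured after a forward pass, and the 0.97 saturation threshold is a common choice rather than anything prescribed by the code above:

```python
import torch

torch.manual_seed(1)

# Stand-ins for layer.out of two Tanh layers: one fed well-scaled
# pre-activations (std ~1), one fed overly large ones (std ~3).
outs = [torch.tanh(torch.randn(32, 100)),
        torch.tanh(3 * torch.randn(32, 100))]

for i, t in enumerate(outs):
    saturated = (t.abs() > 0.97).float().mean().item()
    print(f"layer {i}: mean {t.mean().item():+.2f}, "
          f"std {t.std().item():.2f}, saturated {saturated:.1%}")
```

The second layer reports a far larger saturated fraction, the tell-tale sign of vanishing gradients through tanh, and the loop never needed to know what kind of layer produced the tensor.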
The total parameter count comes to about 47k — roughly 15 times our original single-layer network.
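The 47k figure follows from back-of-the-envelope arithmetic. Assuming vocab_size = 27 and block_size = 3 from the earlier posts:

```python
n_embd, n_hidden = 10, 100
vocab_size, block_size = 27, 3  # assumptions carried over from earlier posts

emb = vocab_size * n_embd               # embedding table C: 270
first = n_embd * block_size * n_hidden  # first Linear: 30 x 100 = 3,000
hidden = 4 * n_hidden * n_hidden        # four 100 x 100 Linears = 40,000
out = n_hidden * vocab_size             # output Linear: 100 x 27 = 2,700
bn = 5 * 2 * n_hidden + 2 * vocab_size  # gammas and betas = 1,054

total = emb + first + hidden + out + bn
print(total)  # 47024
```

The four middle 100 x 100 layers dominate the count, which is typical: parameter budgets in deep MLPs are driven by the hidden-to-hidden matrices.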