The weights: a study in fixing tanh saturation
The weights are what constitute a neural network, and the values they take on are a crucial aspect of representation learning. Let's implement the model line by line, discussing what each choice means for learning.
The problem we are addressing here is tanh saturation. When the pre-activation values — the inputs to tanh — are too large in magnitude, the output gets pinned at -1 or +1. In those flat regions, the gradient is essentially zero, and the neuron stops learning.
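To make the flat regions concrete, here is a small standalone sketch (plain Python, not part of the model) evaluating tanh and its local gradient 1 - tanh(x)^2 at a few sample inputs:

```python
import math

# tanh output and its local gradient 1 - tanh(x)^2 at a few sample inputs
for x in [0.5, 2.0, 5.0, 15.0]:
    t = math.tanh(x)
    grad = 1 - t ** 2
    print(f"x = {x:5.1f}   tanh(x) = {t:.6f}   gradient = {grad:.2e}")
```

Already at x = 5 the local gradient is below 2e-4; at x = 15 it is effectively zero.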
Let's build it step by step.
The setup.
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
We start with a character-level language model. The embedding matrix C maps each of the 27 characters (26 letters + the boundary token) into a 10-dimensional vector. The hidden layer has 200 neurons. The context window is 3 characters.
n_embd = 10
n_hidden = 200
block_size = 3
vocab_size = 27
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
The problematic weights.
Now, the weight matrix for the hidden layer. This is where the problem starts:
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g) * 0.01
W1 has shape (30, 200). Each entry is drawn from the standard normal distribution N(0, 1). The input to this layer — the concatenated embedding — has 30 dimensions. When we compute the pre-activation embcat @ W1 + b1, we are summing 30 random products. By the central limit theorem the output is approximately Gaussian, and since the 30 unit variances add, it has standard deviation about sqrt(30) ≈ 5.5.
That means the pre-activation values will typically range from roughly -15 to +15. And tanh(15) is within 10^-12 of 1, which is as saturated as it gets.
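A quick standalone simulation (shapes chosen to match the layer above, with synthetic unit-variance inputs in place of the real embcat) confirms the sqrt(30) growth:

```python
import torch

torch.manual_seed(0)         # seed chosen for reproducibility, not from the original
x = torch.randn(10000, 30)   # stand-in for embcat: unit-variance inputs
W = torch.randn(30, 200)     # unscaled weights, like W1 above
out = x @ W
print(out.std().item())      # close to sqrt(30) ~ 5.48
```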
# Forward pass.
emb = C[Xb] # (batch, 3, 10).
embcat = emb.view(emb.shape[0], -1) # (batch, 30).
hpreact = embcat @ W1 + b1 # (batch, 200).
Let's look at the pre-activation values.
print(hpreact.min().item(), hpreact.max().item())
# Something like -18.3, 17.9.
These values are enormous. Now apply tanh.
h = torch.tanh(hpreact) # (batch, 200) — activation
And inspect the result.
print(h.mean().item(), h.std().item())
# mean 0.0, std 0.99.
A standard deviation of 0.99 for a function bounded between -1 and +1 means almost everything is at the extremes. The histogram would show two spikes, at -1 and +1, with almost nothing in between.
Why this kills learning.
The gradient of tanh is d/dx tanh(x) = 1 - tanh(x)^2.
When tanh(x) ≈ ±1, that factor is ≈ 0. During backpropagation, the gradient flowing through a saturated neuron is multiplied by this near-zero factor. The signal dies. The weight connected to that neuron receives essentially no gradient, so it does not update. The neuron is effectively dead — not permanently, but stuck until some other change in the network happens to push its pre-activation back into the active region.
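As a standalone sanity check, autograd agrees with the analytic 1 - tanh(x)^2 factor:

```python
import torch

x = torch.randn(1000, requires_grad=True)
torch.tanh(x).sum().backward()
# autograd's gradient of tanh matches the analytic formula 1 - tanh(x)^2
assert torch.allclose(x.grad, 1 - torch.tanh(x) ** 2, atol=1e-6)
print("autograd and analytic gradients agree")
```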
# The gradient factor at each neuron
grad_factor = 1 - h**2
print(f"Fraction with gradient < 0.01: {(grad_factor < 0.01).float().mean().item():.2%}")
# Something like 89%.
89% of the neurons have gradient factors smaller than 0.01. Only the remaining 11% of the network is actually learning. This is catastrophic for a deep network: if a similar fraction saturates at every layer, the surviving gradient signal shrinks multiplicatively with depth.
The fix: scale the weights.
The solution is to reduce the magnitude of W1 so that the pre-activation stays in the linear region of tanh (roughly -1 to +1):
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * 0.2
Multiplying by 0.2 shrinks the weight values. Now the pre-activation standard deviation drops from about 5.5 to about 1.1, and the tanh output uses the full range without saturating.
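Repeating the earlier simulation with the 0.2 factor (same hypothetical shapes and synthetic inputs) shows the shrink:

```python
import torch

torch.manual_seed(0)             # seed chosen for reproducibility
x = torch.randn(10000, 30)       # stand-in for embcat
W = torch.randn(30, 200) * 0.2   # scaled weights
out = x @ W
print(out.std().item())          # about 0.2 * sqrt(30) ~ 1.10
```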
But 0.2 is a magic number. We found it by trial and error. What we really want is a principled rule.
The principled fix: Kaiming initialization.
The variance analysis tells us that if the input has n_in dimensions and each weight has variance sigma^2, then the output has variance n_in * sigma^2. To keep the output variance equal to the input variance (both equal to 1), we need
sigma = 1 / sqrt(n_in).
For tanh, we also need a gain factor because tanh is contractive — it reduces the standard deviation of its input. PyTorch uses a gain of 5/3 for tanh, so the scale becomes
sigma = (5/3) / sqrt(n_in).
In code.
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / (n_embd * block_size)**0.5
With n_embd * block_size = 30, this gives (5/3) / sqrt(30) ≈ 0.3. The pre-activation values now have standard deviation close to 1. The tanh output has a healthy spread — some saturation (around 5%), which is normal and even desirable, but not the 89% we had before.
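Two standalone checks: PyTorch exposes the tanh gain via torch.nn.init.calculate_gain, and we can measure the saturated fraction directly under the Kaiming-style scale (synthetic unit-variance inputs, not the real embcat):

```python
import torch

gain = torch.nn.init.calculate_gain('tanh')
print(gain)  # 5/3 ~ 1.667

torch.manual_seed(0)  # seed chosen for reproducibility
fan_in = 30
x = torch.randn(10000, fan_in)
W = torch.randn(fan_in, 200) * gain / fan_in ** 0.5
h = torch.tanh(x @ W)
saturated = (h.abs() > 0.99).float().mean().item()
print(f"saturated fraction: {saturated:.1%}")  # an order of magnitude below the ~89% from before
```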
The output layer.
The output layer needs a different treatment.
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0
We scale W2 by 0.01, which makes the initial logits very small. Small logits mean the softmax output is close to a uniform distribution — probability 1/27 for each character. This is exactly what we want at initialization: the network starts with no confidence, and the initial loss is close to -log(1/27) = log(27) ≈ 3.30, exactly what an uninformed model should score.
Without this scaling, the initial logits can be large, producing a confident but wrong distribution. The loss starts high (we see values like 27 instead of 3.3 — the "hockey stick" in the loss curve), and the first many training steps are wasted just learning to be less confident.
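The effect is easy to demonstrate in isolation, with random labels and a hypothetical batch of 32:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # seed chosen for reproducibility
targets = torch.randint(0, 27, (32,))

small_logits = torch.randn(32, 27) * 0.01   # scaled down, as with W2 * 0.01
large_logits = torch.randn(32, 27) * 10.0   # unscaled: confident but wrong
print(F.cross_entropy(small_logits, targets).item())  # close to log(27) ~ 3.30
print(F.cross_entropy(large_logits, targets).item())  # far larger
```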
b2 = torch.randn(vocab_size, generator=g) * 0
The bias b2 is set to zero. There is no reason for the network to prefer any character over another before it has seen the data.
Putting it together.
# Good initialization
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / (n_embd * block_size)**0.5
b1 = torch.randn(n_hidden, generator=g) * 0.01
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0
parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True
Every weight has a deliberate scale. Nothing is left to chance. The embedding C is standard normal (it will be learned). W1 is Kaiming-scaled for tanh. b1 is small. W2 is small to keep logits near zero. b2 is zero.
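One sanity check worth printing once, reproducing the shapes above: the parameter count these choices imply (27·10 + 30·200 + 200 + 200·27 + 27):

```python
import torch

vocab_size, n_embd, block_size, n_hidden = 27, 10, 3, 200
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / (n_embd * block_size)**0.5
b1 = torch.randn(n_hidden, generator=g) * 0.01
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0
parameters = [C, W1, b1, W2, b2]
print(sum(p.nelement() for p in parameters))  # 11897
```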
The training loop is then straightforward.
for i in range(200000):
    ix = torch.randint(0, Xtr.shape[0], (32,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix]
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)
    for p in parameters:
        p.grad = None
    loss.backward()
    lr = 0.1 if i < 100000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad
Each line in the forward pass is a direct consequence of the choices above: embed, concatenate, linear transform, squash through tanh, project to logits, compute loss. And each weight was initialized with a specific purpose: to keep the signal flowing through the network at a healthy magnitude from step zero.
The lesson is simple. The weights are not just numbers to optimize — their initial values determine whether the optimization can even begin. A neural network with saturated activations is a network that cannot hear its own gradients. Fixing the saturation is not an optimization trick; it is a prerequisite for learning to happen at all.