Breno

The weights: a study on fixing saturated tanh

The weights are what constitute a neural network, and the values they take on are a crucial aspect of representation learning. Let's implement the network line by line, discussing what each weight's initialization does for learning.

The problem we are addressing here is tanh saturation. When the pre-activation values — the inputs to tanh — are too large in magnitude, the output gets pinned at −1 or +1. In those flat regions, the gradient is essentially zero, and the neuron stops learning.

Let's build it step by step.

The setup.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

We start with a character-level language model. The embedding matrix C maps each of the 27 characters (26 letters + the boundary token) into a 10-dimensional vector. The hidden layer has 200 neurons. The context window is 3 characters.

n_embd = 10
n_hidden = 200
block_size = 3
vocab_size = 27

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)

The problematic weights.

Now, the weight matrix for the hidden layer. This is where the problem starts:

W1 = torch.randn((n_embd * block_size, n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g) * 0.01

W1 has shape (30, 200). Each entry is drawn from 𝒩(0,1). The input to this layer — the concatenated embedding — has 30 dimensions. When we compute the pre-activation embcat @ W1 + b1, we are summing 30 random products. By the central limit theorem, the output will have standard deviation √30 ≈ 5.5.

That means the pre-activation values will typically range from −15 to +15. And tanh(15) = 0.9999999..., which is as saturated as it gets.
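This growth is easy to verify empirically. A minimal sketch, with synthetic unit-variance inputs and the same shapes as the model:

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000, 30)   # unit-variance inputs, like the concatenated embeddings
W = torch.randn(30, 200)      # weights drawn from N(0, 1)
y = x @ W
print(y.std().item())         # close to sqrt(30) ~= 5.48
```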

# Forward pass.
emb = C[Xb]                          # (batch, 3, 10).
embcat = emb.view(emb.shape[0], -1)  # (batch, 30).
hpreact = embcat @ W1 + b1           # (batch, 200).

Let's look at the pre-activation values.

print(hpreact.min().item(), hpreact.max().item())
# Something like -18.3, 17.9.

These values are enormous. Now apply tanh.

h = torch.tanh(hpreact)  # (batch, 200) — activation

And inspect the result.

print(h.mean().item(), h.std().item())
# mean 0.0, std 0.99.

A standard deviation of 0.99 for a function bounded between −1 and +1 means almost everything is at the extremes. The histogram would show two spikes, at −1 and at +1, with almost nothing in between.
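We can count the pile-up in the tails directly. A synthetic sketch, assuming pre-activations with standard deviation about 5.5, as in the model above:

```python
import torch

torch.manual_seed(0)
hpre = torch.randn(10_000, 200) * 5.5   # synthetic pre-activations, std ~5.5
h = torch.tanh(hpre)
frac_tail = (h.abs() > 0.99).float().mean().item()
print(f"fraction with |h| > 0.99: {frac_tail:.2%}")
```

The exact fraction varies with the chosen threshold, but the bulk of the activations sit within a hair of the asymptotes.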

Why this kills learning.

The gradient of tanh is:

d/dx tanh(x) = 1 − tanh²(x)

When tanh(x) ≈ ±1, the gradient is 1 − 1 = 0. During backpropagation, the gradient flowing through a saturated neuron is multiplied by this near-zero factor. The signal dies. The weight connected to that neuron receives essentially no gradient, so it does not update. The neuron is effectively dead — not permanently, but stuck until some other change in the network happens to push its pre-activation back into the active region.

# The gradient factor at each neuron
grad_factor = 1 - h**2
print(f"Fraction with gradient < 0.01: {(grad_factor < 0.01).float().mean().item():.2%}")
# Something like 89%.

89% of the neurons have gradients smaller than 0.01. Only about 11% of the network is actually learning. This is catastrophic for a deep network — if this happens at every layer, the compound effect is exponential.
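The compounding can be sketched with a toy stack of tanh layers, all with unscaled 𝒩(0,1) weights (the depth of 5 and width of 100 are illustrative, not from the model above). The average gradient factor 1 − h² stays tiny at every layer, so the product across layers collapses:

```python
import torch

torch.manual_seed(0)
h = torch.randn(1_000, 100)
product = 1.0
for layer in range(5):
    W = torch.randn(100, 100)          # N(0,1), no scaling
    h = torch.tanh(h @ W)
    factor = (1 - h**2).mean().item()  # average tanh gradient at this layer
    product *= factor
    print(f"layer {layer}: mean gradient factor {factor:.3f}")
print(f"product across 5 layers: {product:.2e}")
```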

The fix: scale the weights.

The solution is to reduce the magnitude of W1 so that the pre-activation stays in the linear region of tanh (roughly −1.5 to +1.5):

W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * 0.2

Multiplying by 0.2 shrinks the weight values. Now the pre-activation standard deviation drops from ~5.5 to ~1.1, and the tanh output uses the full range without saturating.
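A quick before/after comparison on synthetic unit-variance inputs (dimensions match the model; the saturation threshold of 0.99 is a choice):

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000, 30)
W_raw    = torch.randn(30, 200)         # std 1
W_scaled = torch.randn(30, 200) * 0.2   # shrunk by 0.2
for name, W in [("raw", W_raw), ("scaled", W_scaled)]:
    h = torch.tanh(x @ W)
    sat = (h.abs() > 0.99).float().mean().item()
    print(f"{name}: pre-act std {(x @ W).std().item():.2f}, saturated {sat:.2%}")
```

The saturated fraction drops from the majority of the layer to a few percent.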

But 0.2 is a magic number. We found it by trial and error. What we really want is a principled rule.

The principled fix: Kaiming initialization.

The variance analysis tells us that if the input has n dimensions and each weight has variance σ_w², then the output has variance n·σ_w². To keep the output variance equal to the input variance of 1, we need:

σ_w = 1/√n_in

For tanh, we also need a gain factor because tanh is contractive — it reduces the standard deviation of its input. PyTorch uses a gain of 5/3 for tanh.
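We don't have to hard-code this gain: PyTorch exposes it through torch.nn.init.calculate_gain, which we can combine with the fan-in by hand:

```python
import torch

gain = torch.nn.init.calculate_gain('tanh')
print(gain)                   # 5/3 ~= 1.667
fan_in = 30                   # n_embd * block_size in the model above
std = gain / fan_in**0.5
print(round(std, 3))          # the target weight std
```

(Note that torch.nn.init.kaiming_normal_ infers fan-in from the tensor's shape assuming the nn.Linear (out, in) layout, which is transposed relative to the W1 used here, so computing the scale manually is the safer route for this code.)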

σ_w = (5/3)/√n_in

In code:

W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / (n_embd * block_size)**0.5

With n_embd * block_size = 30, this gives σ_w = (5/3)/√30 ≈ 0.30. The pre-activation standard deviation is now about 5/3, which tanh contracts back toward 1, so the output has a healthy spread — some saturation (a few percent), which is normal and even desirable, but not the 89% we had before.
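Checking this numerically with synthetic unit-variance inputs (the exact saturation fraction varies from run to run and depends on the chosen threshold):

```python
import torch

torch.manual_seed(0)
fan_in = 30
x = torch.randn(10_000, fan_in)
W = torch.randn(fan_in, 200) * (5/3) / fan_in**0.5   # Kaiming scale for tanh
h = torch.tanh(x @ W)
sat = (h.abs() > 0.99).float().mean().item()
print(f"pre-act std: {(x @ W).std().item():.2f}")    # a bit above 1 (about 5/3)
print(f"h std: {h.std().item():.2f}")
print(f"saturated: {sat:.2%}")
```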

The output layer.

The output layer needs a different treatment.

W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0

We scale W2 by 0.01, which makes the initial logits very small. Small logits mean the softmax output is close to a uniform distribution — 1/27 ≈ 0.037 for each character. This is exactly what we want at initialization: the network starts with no confidence, and the initial loss is close to −log(1/27) ≈ 3.30, the loss of a uniform prediction.

Without this scaling, the initial logits can be large, producing a confident but wrong distribution. The loss starts high (we see values like 27 instead of 3.3 — the "hockey stick" in the loss curve), and the first many training steps are wasted just learning to be less confident.
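The effect of logit scale on the initial loss can be sketched with random logits and random targets (the 0.01 and 10.0 scales here are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
y = torch.randint(0, 27, (1000,))        # random targets
small = torch.randn(1000, 27) * 0.01     # near-zero logits -> near-uniform softmax
large = torch.randn(1000, 27) * 10.0     # confident random logits
print(F.cross_entropy(small, y).item())  # close to log(27) ~= 3.30
print(F.cross_entropy(large, y).item())  # much larger
```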

b2 = torch.randn(vocab_size, generator=g) * 0

The bias b2 is set to zero. There is no reason for the network to prefer any character over another before it has seen the data.

Putting it together.

# Good initialization
C  = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / (n_embd * block_size)**0.5
b1 = torch.randn(n_hidden, generator=g) * 0.01
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
b2 = torch.randn(vocab_size, generator=g) * 0

parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True

Every weight has a deliberate scale. Nothing is left to chance. The embedding C is standard normal (it will be learned). W1 is Kaiming-scaled for tanh. b1 is small. W2 is small to keep logits near zero. b2 is zero.

The training loop is then straightforward.

for i in range(200000):
    ix = torch.randint(0, Xtr.shape[0], (32,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix]

    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)

    for p in parameters:
        p.grad = None
    loss.backward()

    lr = 0.1 if i < 100000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad

Each line in the forward pass is a direct consequence of the choices above: embed, concatenate, linear transform, squash through tanh, project to logits, compute loss. And each weight was initialized with a specific purpose: to keep the signal flowing through the network at a healthy magnitude from step zero.

The lesson is simple. The weights are not just numbers to optimize — their initial values determine whether the optimization can even begin. A neural network with saturated activations is a network that cannot hear its own gradients. Fixing the saturation is not an optimization trick; it is a prerequisite for learning to happen at all.