Breno

Training a drifting model

Training a drifting model is about approximating the data distribution.

What does that mean? A data distribution is the pattern hidden in a collection of examples. If you have a thousand clean photos — no rain, no noise, just clear scenes — the data distribution is the invisible shape that describes what "clean photos" look like. All the possible clean photos, not just the thousand you have.

A drifting model is a neural network f that takes random noise ε (think of static on a TV) and transforms it into something that looks like real data — in our case, a clean image. The word drifting comes from how it learns: the outputs drift, little by little, toward clean images during training.

Every time the optimizer takes a step (every SGD update), the network changes from $f_i$ to $f_{i+1}$. For the same noise input ε, the output changes too:

$$x_{i+1} = f_{i+1}(\epsilon) = x_i + \Delta x_i$$
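To make Δx concrete, here is a minimal sketch that fixes one noise input ε, takes a single SGD step, and measures how far the output moved. The toy linear generator `f` and the placeholder loss (pulling the output toward (1, 1)) are assumptions for illustration only. They are not the drifting objective.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy generator f: one linear layer mapping 2D noise to a 2D "image".
f = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(f.parameters(), lr=0.1)

eps = torch.randn(1, 2)   # one fixed noise input ε
x_i = f(eps).detach()     # x_i = f_i(ε), the output before the step

# One SGD update with an arbitrary placeholder loss (not the paper's loss).
loss = ((f(eps) - torch.tensor([[1.0, 1.0]])) ** 2).sum()
opt.zero_grad()
loss.backward()
opt.step()                # the network is now f_{i+1}

x_next = f(eps).detach()  # x_{i+1} = f_{i+1}(ε): same noise, updated network
delta_x = x_next - x_i    # the drift Δx_i for this ε
print(delta_x)
```

The noise never changes; only the weights do. Δx is simply the side effect of that weight update, seen in output space.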

That Δx (read "delta x") is the drift. It is a small displacement — how much the generated sample moved at this step. At the beginning of training, the network produces blurry, messy outputs — like a photo still covered in rain streaks. Each Δx moves the output a little closer to a clean image. The question is: where should Δx point?

Section 3.2 of Deng et al. (2026) answers exactly this question.

The drifting field.

A field, in this context, is a rule that assigns a vector (a direction and a magnitude) to every point in space. Think of a weather map showing wind: at every location, there is an arrow saying "the wind blows this way, this fast." The drifting field V(x) is similar — at every generated sample x, it says "move this way, this much."

The field has two parts:

$$V(x) = V^{+}(x) - V^{-}(x).$$

$V^+(x)$ is attraction. It points from the generated sample toward the real data. Imagine each generated output is a rainy photo, and the clean images are the targets. $V^+$ is the pull toward clarity, toward the clean versions.

$V^-(x)$ is repulsion. It points from the generated sample toward the other generated samples. If all generated outputs looked the same (say, the same generic blurry patch), they would all collapse to one point. The repulsion prevents that. It pushes generated outputs away from each other, so they learn to produce diverse clean images, not just one.

The full field $V = V^+ - V^-$ means: be attracted to clean images, be repelled from other generated outputs. The minus sign in front of $V^-$ turns the "direction toward others" into "direction away from others."

Let's see this with numbers. To keep it simple, imagine we work in a tiny 2D space where each "image" is just two numbers. We have three clean images (the targets) and three generated outputs (the network's current attempts).

import torch

clean = torch.tensor([[1.0, 1.0], [-1.0, 1.0], [0.0, -1.0]])  # clean images (the targets).
gen   = torch.tensor([[0.5, 0.0], [-0.5, 0.0], [0.0, 0.5]])   # generated outputs (still rainy).

We pick one generated output, gen[0], sitting at (0.5,0) — think of it as a photo where some rain has been removed but not all. We ask: what does the field tell it to do?

First, the attraction. We compute the direction from our generated output to each clean image.

x = gen[0]            # our generated output at (0.5, 0.0).
diff = clean - x      # directions toward each clean image.

diff is now three vectors — one pointing toward each clean image. But not all clean images should pull equally. A nearby clean image should pull harder than a faraway one. This is where the kernel comes in.

A kernel is a function that measures how strongly two points interact based on their distance. The paper uses the following.

$$k(x, y) = \exp\!\left(-\frac{\lVert x - y \rVert}{\tau}\right).$$

The symbol exp means the exponential function (e raised to a power). When the distance $\lVert x - y \rVert$ is small, the kernel is close to 1 (strong interaction). When the distance is large, the kernel is close to 0 (weak interaction). The parameter τ (tau) controls the range: how far the interaction reaches.

Think of τ as the radius of attention. A small τ means each generated output only listens to the very closest clean image. A large τ means it feels the pull of many clean images at once.
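A quick way to feel the effect of τ is to evaluate the kernel at a few fixed distances and normalize the results. In this sketch, the three distances and the three τ values are arbitrary choices for illustration.

```python
import torch

def kernel(dist: torch.Tensor, tau: float) -> torch.Tensor:
    """Exponential kernel exp(-dist / tau), applied to precomputed distances ||x - y||."""
    return torch.exp(-dist / tau)

dist = torch.tensor([0.5, 1.0, 2.0])  # distances to a near, a medium, and a far point

weights = {}
for tau in (0.1, 0.5, 2.0):
    k = kernel(dist, tau)
    weights[tau] = k / k.sum()        # normalized so the three weights sum to 1
    print(f"tau={tau}: {[round(w, 3) for w in weights[tau].tolist()]}")
```

With τ = 0.1, more than 99% of the weight goes to the nearest point. With τ = 2.0, even the farthest point keeps about a fifth of it.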

tau = 0.5
dist = diff.norm(dim=-1, keepdim=True)  # distance to each clean image.
k = torch.exp(-dist / tau)              # kernel: close = strong, far = weak.

Now we normalize the kernel weights so they sum to 1, like probabilities.

k_norm = k / k.sum(dim=0)
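Dividing exponentials by their sum is exactly a softmax. This means the same weights can be computed as a softmax over −dist/τ, which also avoids underflow when distances are large. The sketch below checks this equivalence on the article's example point (0.5, 0).

```python
import torch

clean = torch.tensor([[1.0, 1.0], [-1.0, 1.0], [0.0, -1.0]])
x = torch.tensor([0.5, 0.0])
tau = 0.5

dist = (clean - x).norm(dim=-1, keepdim=True)  # shape (3, 1)
k = torch.exp(-dist / tau)
k_norm = k / k.sum(dim=0)                      # explicit normalization, as above

k_softmax = torch.softmax(-dist / tau, dim=0)  # same weights via softmax
print(k_norm.squeeze(-1), k_softmax.squeeze(-1))
```

Both lines give weights of roughly 0.44, 0.11, and 0.44. The first and third clean images are equally far from x, so they receive equal weight.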

And the attraction is the weighted average direction.

V_plus = (k_norm * diff).sum(dim=0)

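V_plus is a single vector: the kernel-weighted average of the directions toward the clean images. In general, it points mostly toward the nearest clean images, the ones our still-rainy output is closest to becoming, because the kernel lets close targets dominate the pull. This tiny example has a twist. Two clean images, (1, 1) and (0, -1), sit at exactly the same distance from x (about 1.12), so they get equal weights and their pulls cancel. What remains is the weak pull of the farthest one, (-1, 1), which gives V_plus ≈ (-0.17, 0.11).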

The repulsion is the same idea, but computed over the other generated outputs.

others = gen[1:]                          # the other generated outputs.
diff_neg = others - x                     # directions toward them.
dist_neg = diff_neg.norm(dim=-1, keepdim=True)
k_neg = torch.exp(-dist_neg / tau)
k_neg_norm = k_neg / k_neg.sum(dim=0)

V_minus = (k_neg_norm * diff_neg).sum(dim=0)

V_minus points toward the other generated outputs. The full field subtracts it.

V = V_plus - V_minus

So, $V^+$ says "go toward the clean images," and subtracting $V^-$ says "but also move away from the other generated outputs." Without the repulsion, all generated samples might converge to the same average clean image: a single blurry compromise. The repulsion forces diversity, so each output finds its own clean target.
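The walkthrough handles one generated output at a time. The sketch below computes the same simplified field for every generated output in a single batched call. It implements the field as described in this article, not the paper's exact implementation. The name `compute_drift` and its signature are hypothetical and chosen to match the placeholder in the training snippet. Giving each sample zero weight on itself mirrors the `gen[1:]` step above, and the function assumes at least two generated outputs.

```python
import torch

def compute_drift(gen: torch.Tensor, clean: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: simplified field V = V+ - V- for every row of gen.

    gen: (N, D) generated outputs with N >= 2; clean: (M, D) clean images.
    """
    # Attraction: diff_pos[i, j] = clean[j] - gen[i], shape (N, M, D).
    diff_pos = clean[None, :, :] - gen[:, None, :]
    w_pos = torch.softmax(-diff_pos.norm(dim=-1) / tau, dim=1)  # normalized kernel, (N, M)
    v_plus = (w_pos[..., None] * diff_pos).sum(dim=1)            # (N, D)

    # Repulsion: diff_neg[i, j] = gen[j] - gen[i], shape (N, N, D).
    diff_neg = gen[None, :, :] - gen[:, None, :]
    logits_neg = -diff_neg.norm(dim=-1) / tau
    # Give each sample zero weight on itself, like using gen[1:] for gen[0].
    self_mask = torch.eye(gen.shape[0], dtype=torch.bool)
    w_neg = torch.softmax(logits_neg.masked_fill(self_mask, float("-inf")), dim=1)
    v_minus = (w_neg[..., None] * diff_neg).sum(dim=1)           # (N, D)

    return v_plus - v_minus

clean = torch.tensor([[1.0, 1.0], [-1.0, 1.0], [0.0, -1.0]])
gen = torch.tensor([[0.5, 0.0], [-0.5, 0.0], [0.0, 0.5]])
V_all = compute_drift(gen, clean)  # one field vector per generated output
print(V_all[0])                    # the field at gen[0]
```

For gen[0], this reproduces the step-by-step computation: V ≈ (0.51, -0.21). That is a push to the right and slightly down, away from the other two generated outputs, which sit to its left and above. For this point, the repulsion term (length about 0.75) is much larger than the attraction term (about 0.20).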

Why it stops when the job is done.

The field is anti-symmetric. That is a fancy word for a simple idea: if you swap the roles of clean images and generated outputs, the field reverses direction. Mathematically, writing $p$ for the data distribution and $q$ for the generated distribution:

$$V_{p,q}(x) = -V_{q,p}(x).$$

This has an important consequence. When the generated distribution matches the data distribution — when the generated outputs are indistinguishable from the clean images — the attraction and the repulsion are computed over the same set of points. They become identical:

$$V^{+}(x) = V^{-}(x) \;\Rightarrow\; V(x) = V^{+}(x) - V^{-}(x) = 0.$$

The field is zero everywhere. The outputs have no reason to move. This is the equilibrium, the resting point of the system. It means training is done: the generated images are as clean as the real ones.
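Both properties can be checked numerically. In this sketch, random points stand in for clean and generated samples, and `mean_shift` and `field` are hypothetical names. Both terms are computed over the full sample sets. The self-exclusion from the walkthrough is left out so that the two sets can be literally identical in the equilibrium case.

```python
import torch

def mean_shift(x: torch.Tensor, ys: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """Kernel-weighted average displacement from x toward the samples ys."""
    diff = ys - x                                        # (M, D)
    w = torch.softmax(-diff.norm(dim=-1) / tau, dim=0)   # (M,), sums to 1
    return (w[:, None] * diff).sum(dim=0)                # (D,)

def field(x: torch.Tensor, p_samples: torch.Tensor, q_samples: torch.Tensor) -> torch.Tensor:
    """V_{p,q}(x): pull toward p's samples minus pull toward q's samples."""
    return mean_shift(x, p_samples) - mean_shift(x, q_samples)

torch.manual_seed(0)
p = torch.randn(5, 2)  # stand-in for clean samples (data distribution p)
q = torch.randn(5, 2)  # stand-in for generated samples (distribution q)
x = torch.randn(2)     # any query point

v_pq = field(x, p, q)
v_qp = field(x, q, p)  # roles swapped: the field flips sign
v_eq = field(x, p, p)  # generated samples identical to the data: the field vanishes
print(v_pq, v_qp, v_eq)
```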

The training objective.

Now the question: how do we turn this field into something the optimizer can use? The paper uses a clever trick. They define a drifted target:

$$\text{target} = \operatorname{stopgrad}\big(x + V(x)\big).$$

In code:

# model is the generator f; compute_drift is the field computation walked through above.
x = model(noise)              # current generated output (still rainy).
V = compute_drift(x, clean)   # the field: "move toward clean, away from others".
target = (x + V).detach()     # where the output should be after drifting.

The word stopgrad (or .detach() in PyTorch) means: treat this value as a frozen number, do not compute gradients through it. The target is fixed — it is just a point in space that says "this is where your output should go next."

The loss is the distance between the current position and the target:

$$\mathcal{L} = \lVert x - \text{target} \rVert^2 = \lVert x - (x + V) \rVert^2 = \lVert V \rVert^2.$$

Look at what happened. In value, the x cancels out, and the loss becomes $\lVert V \rVert^2$, the squared length of the drifting field. When V is large, the loss is large: the generated images are still far from clean. When $V = 0$, the loss is zero: the images are clean.

import torch.nn.functional as F
loss = F.mse_loss(x, target)  # mean of V² over all entries: ||V||² divided by the number of entries.

But here is the subtle part. Because the target is frozen, the gradient of this loss does not flow through V at all. Instead, it updates the network weights so that the generated outputs move in the direction of V. The rainy images get a little cleaner.
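The sketch below checks both claims numerically: the loss value and the direction of the gradient. Random tensors stand in for the generated outputs and for the field (an assumption: there is no real model here). `x` is a leaf tensor, so its gradient can be read directly from `x.grad`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 2, requires_grad=True)  # stand-in for generated outputs
V = torch.randn(4, 2)                      # stand-in for the drifting field at x

target = (x + V).detach()                  # frozen drifted target
loss = F.mse_loss(x, target)               # mean over all 8 entries
loss.backward()

n = x.numel()
print(loss.item(), (V.pow(2).sum() / n).item())  # same value: ||V||² / n
print(x.grad)                                     # equals -2 V / n
```

Since `x.grad` points along −V, a gradient-descent step x ← x − η·x.grad moves each output along +V. In the real model, x depends on the network weights, and backpropagation carries this same vector into the weight update.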

step i:   output is rainy,      V says "remove more rain".
step i+1: output is less rainy, V is smaller now.
...
step N:   V ≈ 0, output is clean.

The optimizer (Adam, SGD) is the engine. The field V is the map. At each training step, the optimizer reads the map and pushes the outputs a little closer to clean images, one Δx at a time, until V vanishes and the rain is gone.
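Putting the pieces together, here is a minimal, untuned training-loop sketch. Every concrete choice is an assumption for illustration: the small MLP generator, three 2D "clean" points as the data, a batch of 64, τ = 0.5, Adam with lr = 1e-2, and 200 steps. The hypothetical `compute_drift` helper implements the simplified field from this article (with each sample excluded from its own repulsion), not the paper's exact one. With that self-exclusion, the toy field only approximates the zero-field equilibrium described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def compute_drift(gen: torch.Tensor, clean: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """Hypothetical simplified field V = V+ - V- for each row of gen (needs >= 2 rows)."""
    diff_pos = clean[None] - gen[:, None]                   # (N, M, D): toward clean images
    w_pos = torch.softmax(-diff_pos.norm(dim=-1) / tau, dim=1)
    diff_neg = gen[None] - gen[:, None]                     # (N, N, D): toward other outputs
    eye = torch.eye(gen.shape[0], dtype=torch.bool)
    w_neg = torch.softmax((-diff_neg.norm(dim=-1) / tau).masked_fill(eye, float("-inf")), dim=1)
    v_plus = (w_pos[..., None] * diff_pos).sum(dim=1)
    v_minus = (w_neg[..., None] * diff_neg).sum(dim=1)
    return v_plus - v_minus

torch.manual_seed(0)
clean = torch.tensor([[1.0, 1.0], [-1.0, 1.0], [0.0, -1.0]])         # toy data
model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 2))  # generator f
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for step in range(200):
    noise = torch.randn(64, 2)           # a fresh batch of ε
    x = model(noise)                     # generated outputs
    with torch.no_grad():
        V = compute_drift(x, clean)      # the map: where each output should move
    target = x.detach() + V              # frozen drifted target
    loss = F.mse_loss(x, target)         # value is the mean of V²
    opt.zero_grad()
    loss.backward()
    opt.step()                           # the engine: outputs move along V
    losses.append(loss.item())
```

Computing V under `torch.no_grad()` and adding it to `x.detach()` produces the same frozen target as `(x + V).detach()`, and it skips building a graph that would be discarded anyway.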

That is section 3.2. The drifting field tells each sample where to go. The stopgrad target turns it into a loss. And the optimizer makes the samples actually move, one Δx at a time, until the generated images match the clean data distribution.

Reference. Deng, M., Li, H., Li, T., Du, Y., & He, K. (2026). Generative Modeling via Drifting.