Gradient Clipping
Description
Gradient Clipping mitigates the exploding gradients problem by clipping the gradients during backpropagation so that they never exceed some threshold.
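Gradients can be clipped either by their overall norm or element-wise by value; both variants appear in the examples below. As a rough, framework-free sketch of the two rules (the function names and the threshold of 1.0 are illustrative, not part of any library API):
import math

def clip_by_norm(grad, max_norm=1.0):
    # Rescale the whole vector if its L2 norm exceeds max_norm (direction is preserved).
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / max(norm, 1e-12))
    return [g * scale for g in grad]

def clip_by_value(grad, clip_value=1.0):
    # Cap each component independently (direction may change).
    return [max(-clip_value, min(clip_value, g)) for g in grad]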
Info
This technique is most often used in recurrent neural networks, where batch normalization is tricky to apply.
Example
import torch.nn as nn

# Assumes model, optimizer, loss_fn, train_loader, device and n_epochs are already defined.
for epoch in range(n_epochs):
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        # Rescale all gradients in-place so their global L2 norm is at most 1.0
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
Info
If the original gradient vector is [0.9, 100.0], it points mostly in the direction of the second dimension; but once you clip it by norm (with max_norm=1.0), you get [0.00899964, 0.9999595], which preserves the vector's orientation but almost eliminates the first component.
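To check these numbers, here is a minimal standalone sketch that clips by norm a dummy parameter whose gradient is set by hand (rather than the model from the example above):
import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([0.9, 100.0])
nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(p.grad)  # tensor([0.0090, 1.0000]) -- same direction, first component almost gone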
Example
import torch.nn as nn

for epoch in range(n_epochs):
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        # Clip every gradient component in-place to the range [-1.0, 1.0]
        nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
        optimizer.step()
Info
If the original gradient vector is [0.9, 100.0], it points mostly in the direction of the second dimension; but once you clip it by value, you get [0.9, 1.0], which points roughly along the diagonal between the two axes, so the gradient's original orientation is not preserved.
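The same kind of standalone check for clipping by value, again using a dummy parameter with a hand-set gradient:
import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([0.9, 100.0])
nn.utils.clip_grad_value_([p], clip_value=1.0)
print(p.grad)  # tensor([0.9000, 1.0000]) -- each component capped, direction changed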