What binds together the loss function, the optimizer and the model in PyTorch?
To have a successful "training session", these three have to cooperate:
- model generates an output from the input data,
- the loss function says how good or bad that output is,
- and then the optimizer tunes the model, trying to create better output.
Thus, these 3 essential entities must work together.
However, my pre-AI, mostly OO-accustomed brain hits a wall every time I try to understand PyTorch example code. The reason is simple: in the examples I have seen so far, these entities never seem to cooperate. Typically, the loss function calculates a loss, but its result does not seem to be used anywhere. Here is an example code fragment I found on the internet:
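import torch.nn as nn
import torch.optim as optim

# create model
model = nn.Sequential(
    nn.Linear(60, 60),
    nn.ReLU(),
    nn.Linear(60, 30),
    nn.ReLU(),
    nn.Linear(30, 1),
    nn.Sigmoid()
)
# Train the model
# (loader is defined elsewhere: an iterable such as a torch.utils.data.DataLoader
# yielding (X_batch, y_batch) pairs, with float targets of shape (batch, 1))
n_epochs = 200
loss_fn = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.train()
for epoch in range(n_epochs):
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)          # forward pass
        loss = loss_fn(y_pred, y_batch)  # loss for this batch
        optimizer.zero_grad()            # reset gradients
        loss.backward()                  # backpropagate
        optimizer.step()                 # update parameters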
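The fragment relies on a loader defined elsewhere, so it cannot be run as-is. For experimenting with the questions below, here is a self-contained variant. It is a sketch under assumptions that are not in the original: synthetic data (512 random samples with 60 features, labelled 1 when the first feature is positive), a batch size of 32, and 20 epochs instead of 200 to keep it fast. The training loop itself is unchanged.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # make the run reproducible

# Synthetic stand-in data (assumption): label is 1 when the first feature is positive.
X = torch.randn(512, 60)
y = (X[:, 0] > 0).float().unsqueeze(1)  # shape (512, 1), float as BCELoss expects
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(
    nn.Linear(60, 60),
    nn.ReLU(),
    nn.Linear(60, 30),
    nn.ReLU(),
    nn.Linear(30, 1),
    nn.Sigmoid(),
)
loss_fn = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)


def full_loss():
    # Loss over the whole dataset, evaluated without recording a graph.
    with torch.no_grad():
        return loss_fn(model(X), y).item()


initial_loss = full_loss()
model.train()
for epoch in range(20):  # fewer epochs than the original 200, for speed
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
final_loss = full_loss()
print(f"loss before: {initial_loss:.3f}, after: {final_loss:.3f}")
```

The decreasing loss confirms that the three pieces do cooperate, even though no line passes the loss to the optimizer explicitly.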
As is clearly visible, the loss function (loss_fn) "does not know" about the optimizer, and it does not know about the model. It is called with the predicted and the real data, and it calculates a difference.
This result is indeed used in the loss.backward() line, but how?
How does the .backward() method know what to do? How does the optimizer object know about the loss results? The code line optimizer = optim.SGD(model.parameters(), lr=0.1) binds the model and the optimizer together, but the loss is "wired" to nothing.
My first impression of this code was that it simply could not work. Yet it works.
How?
Your suspicion is right: PyTorch relies heavily on (hidden) global state.
In your specific example, loss.backward() computes (all) gradients and accumulates them directly in the grad attribute of each parameter. You can verify this by printing the grad attributes of the parameters before and after the backward call, using the list comprehension [par.grad for par in model.parameters()] (note that these can be big matrices, so it might be useful to select only a few elements to reduce the amount of generated output).
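A minimal sketch of that check, using a tiny nn.Linear model, MSE loss and random data as stand-ins (assumptions, not the model from the question). It also shows how the loss is "wired": the loss tensor carries a grad_fn, the last node of the graph autograd recorded during the forward pass, and backward() follows that graph back to the parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)      # tiny stand-in for the model in the question
loss_fn = nn.MSELoss()       # stand-in loss
X = torch.randn(4, 3)        # random stand-in data
y = torch.randn(4, 1)

loss = loss_fn(model(X), y)

# The loss tensor is linked to the graph recorded during the forward pass:
# grad_fn is its last node, and following grad_fn.next_functions leads back
# to the model's parameters.
print(loss.grad_fn)

grads_before = [par.grad for par in model.parameters()]
print(grads_before)          # [None, None]: no gradients computed yet

loss.backward()

grads_after = [par.grad for par in model.parameters()]
print([g.shape for g in grads_after])  # [torch.Size([1, 3]), torch.Size([1])]
```

Each gradient has the same shape as the parameter it belongs to (weight and bias of the linear layer).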
Once the gradients have been stored in the grad attributes, the optimiser, which has access to (a subset of) the parameters by means of the references you pass upon construction, can use these gradients to perform the desired updates.
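The shared references can be checked directly. A minimal sketch, again with a tiny nn.Linear model and random data as stand-ins; it assumes plain SGD without momentum or weight decay, for which step() amounts to subtracting lr times .grad from each parameter in place.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(3, 1)                 # stand-in model
optimizer = optim.SGD(model.parameters(), lr=0.1)

# The optimizer keeps references to the very same tensor objects as the model.
opt_params = optimizer.param_groups[0]["params"]
same_objects = all(a is b for a, b in zip(opt_params, model.parameters()))
print(same_objects)  # True

X = torch.randn(4, 3)                   # random stand-in data
y = torch.randn(4, 1)
loss = nn.functional.mse_loss(model(X), y)
optimizer.zero_grad()
loss.backward()

# Predict the plain SGD update p - lr * p.grad by hand, then let step() do it.
with torch.no_grad():
    expected = [p - 0.1 * p.grad for p in model.parameters()]
optimizer.step()
matches = all(torch.allclose(p, e) for p, e in zip(model.parameters(), expected))
print(matches)  # True
```

Because step() modifies those shared tensors in place, the model sees the update without any explicit hand-over.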
Also, note my emphasis on accumulate: the computed gradients are added to whatever was stored in the grad attribute previously. This is why it is so important to call optimizer.zero_grad() (or model.zero_grad()) before starting gradient computations.
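The accumulation is easy to observe. A minimal sketch with a stand-in model and random data: calling backward() twice without zeroing doubles the gradient, and zero_grad() brings it back to the single-pass value. In PyTorch 2.0 and later, zero_grad() sets .grad to None by default (set_to_none=True) rather than filling it with zeros; either way the next backward() starts from scratch.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)      # stand-in model
X = torch.randn(4, 3)        # random stand-in data
y = torch.randn(4, 1)


def backward_once():
    # A fresh forward pass builds a new graph; backward() adds into .grad.
    loss = nn.functional.mse_loss(model(X), y)
    loss.backward()


backward_once()
first = model.weight.grad.clone()

backward_once()              # no zero_grad() in between
second = model.weight.grad.clone()
doubled = torch.allclose(second, 2 * first)
print(doubled)               # True: the same gradient was added twice

model.zero_grad()            # resets .grad (to None by default in PyTorch 2.0+)
backward_once()
reset_ok = torch.allclose(model.weight.grad, first)
print(reset_ok)              # True: back to a single pass's gradient
```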
I hope this gives you the desired insight into how this works.
PS: there is also an official PyTorch tutorial on the autograd system.
PPS: there is more global state in how autograd works, but I think that would be something for a different question.