What binds together the loss function, the optimizer and the model in PyTorch?


To have a successful training session, these have to cooperate:

  • the model generates an output from the input data,
  • the loss function says how good or bad that output is,
  • and then the optimizer tunes the model, trying to produce a better output.

Thus, these three essential entities must work together.

However, my pre-AI, mostly OO-accustomed brain regularly hits a wall whenever I try to understand some PyTorch example code. The reason is very simple: in the examples I have seen so far, they never appear to cooperate. The most typical case is that the loss function does calculate a loss, but its result does not seem to be used anywhere. Here is an example code fragment, found and downloaded from the internet:

import torch.nn as nn
import torch.optim as optim

# create model
model = nn.Sequential(
    nn.Linear(60, 60),
    nn.ReLU(),
    nn.Linear(60, 30),
    nn.ReLU(),
    nn.Linear(30, 1),
    nn.Sigmoid()
)

# Train the model
# (loader is a DataLoader yielding (X_batch, y_batch) pairs, defined elsewhere)
n_epochs = 200
loss_fn = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.train()
for epoch in range(n_epochs):
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

As is clearly visible, the loss_fn() call "does not know" about the optimizer, and it does not know about the model either. It is called with the predicted and the real data, and it calculates a difference.
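
To make that concrete, here is a minimal sketch (with made-up numbers) of calling the loss function on its own; from the outside it looks like an ordinary function of two tensors:

import torch
import torch.nn as nn

loss_fn = nn.BCELoss()

# made-up prediction and target values, just to show the call shape
y_pred = torch.tensor([0.8, 0.1])   # model outputs in (0, 1)
y_true = torch.tensor([1.0, 0.0])   # ground-truth labels

loss = loss_fn(y_pred, y_true)      # a single scalar tensor
print(loss)                         # roughly tensor(0.1643)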

This result does get used, in the loss.backward() line, but how?

How does the .backward() method know what to do? How does the optimizer object learn about the loss result? The line optimizer = optim.SGD(model.parameters(), lr=0.1) binds the model and the optimizer together, but the loss seems to be "wired" to nothing.
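
Here is a minimal sketch of what I can observe when poking at a toy model (sizes and data are made up): after loss.backward(), a .grad attribute shows up on the model's parameters, and optimizer.step() apparently works from that, though I don't understand the mechanism:

import torch
import torch.nn as nn
import torch.optim as optim

# a tiny made-up model, just to inspect what backward() does
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 3)
y = torch.randn(4, 1)

loss = loss_fn(model(x), y)

print(model.weight.grad)   # None - nothing has flowed back yet
loss.backward()            # fills .grad on every parameter involved in the loss
print(model.weight.grad)   # now a tensor with the same shape as the weight

optimizer.step()           # reads those .grad tensors and updates the parameters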

My first impression of this code was that it simply could not work. Yet it does work.

How?


1 comment thread

Global state managed by the library? (4 comments)

Perhaps the library manages state for you with global variables and singletons, similar to how some libraries have getContext and setContext functions, where every action is done through a static function that doesn't accept any context argument but instead uses a thread-local stored somewhere?
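
Roughly the shape I mean, as a hypothetical sketch (names are made up; I'm not claiming PyTorch does exactly this):

# hypothetical sketch of a global-context style API, not actual PyTorch code
_current_context = None

def set_context(ctx):
    global _current_context
    _current_context = ctx

def get_context():
    return _current_context

def do_something(data):
    # no context argument: the function reaches for the global instead
    ctx = get_context()
    return ctx.process(data)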

peterh‭ wrote 9 days ago

Possible, but the sources on the details are sparse and cryptic. Digging in the torch docs, they sometimes talk about some data propagation. I also think torch.nn.Module is actually some object hierarchy. But there is nothing concrete.

peterh‭ wrote 9 days ago

Andreas demands justice for humanity‭ OK, I have found the solution. Torch tensors are much more than the usual numpy/pandas things. Torch tensors have some... "shadow": a hidden data structure keeping track of their gradient. The gradient is another tensor showing how the given tensor alters the loss function. While you calculate the loss function, possibly in quite complex ways, torch tracks which tensor you computed from which one, and how. But the details are not very clear to me yet. I think it is best to leave this question open; maybe someone will write an answer before I learn the topic well enough to write a thorough self-answer.
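
A tiny sketch of the kind of tracking I mean (values made up, assuming this is the autograd mechanism):

import torch

# requires_grad=True asks torch to record the operations done on this tensor
w = torch.tensor([2.0, 3.0], requires_grad=True)
x = torch.tensor([1.0, 4.0])

y = (w * x).sum()          # each intermediate result remembers how it was made
print(y.grad_fn)           # e.g. <SumBackward0 object at ...>

y.backward()               # walk the recorded graph backwards
print(w.grad)              # tensor([1., 4.]) - dy/dw, the hidden "shadow" tensor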

Ah, hidden magic. That's lovely. A great option for a simple and concise API; much worse for helping somebody understand how things work.