train_history, valid_history = [], []

for i in range(n_epochs):
    # Shuffle before mini-batch split.
    indices = torch.randperm(x[0].size(0))
    x_ = torch.index_select(x[0], dim=0, index=indices)
    y_ = torch.index_select(y[0], dim=0, index=indices)
    # |x_| = (total_size, input_dim)
    # |y_| = (total_size, output_dim)

    x_ = x_.split(batch_size, dim=0)
    y_ = y_.split(batch_size, dim=0)
    # |x_[i]| = (batch_size, input_dim)
    # |y_[i]| = (batch_size, output_dim)

    train_loss, valid_loss = 0, 0
    y_hat = []

    for x_i, y_i in zip(x_, y_):
        # |x_i| = |x_[i]|
        # |y_i| = |y_[i]|
        y_hat_i = model(x_i)
        loss = F.mse_loss(y_hat_i, y_i)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += float(loss)

    train_loss = train_loss / len(x_)
    # Declare to PyTorch that it should not build the computation graph here.
    with torch.no_grad():
        # You don't need to shuffle the validation set.
        # Only splitting is needed.
        x_ = x[1].split(batch_size, dim=0)
        y_ = y[1].split(batch_size, dim=0)

        valid_loss = 0

        for x_i, y_i in zip(x_, y_):
            y_hat_i = model(x_i)
            loss = F.mse_loss(y_hat_i, y_i)

            # Accumulate as a Python float, as in the training loop.
            valid_loss += float(loss)

            y_hat += [y_hat_i]

    valid_loss = valid_loss / len(x_)
    # Log each loss to plot after training is done.
    train_history += [train_loss]
    valid_history += [valid_loss]

    if (i + 1) % print_interval == 0:
        print('Epoch %d: train_loss=%.4e  valid_loss=%.4e  lowest_loss=%.4e' % (
            i + 1,
            train_loss,
            valid_loss,
            lowest_loss,
        ))

    if valid_loss <= lowest_loss:
        lowest_loss = valid_loss
        lowest_epoch = i

        # 'state_dict()' returns the model weights as a key-value mapping.
        # Take a deep copy if the validation loss is the lowest ever.
        best_model = deepcopy(model.state_dict())
    else:
        if early_stop > 0 and lowest_epoch + early_stop < i + 1:
            print("There is no improvement during the last %d epochs." % early_stop)
            break
print("The best validation loss from epoch %d: %.4e" % (lowest_epoch + 1, lowest_loss))
# Load best epoch's model.
model.load_state_dict(best_model)
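
# The loop above assumes a preamble that prepares its inputs. A minimal
# sketch, with names taken from the loop and purely illustrative model,
# optimizer, and hyperparameter choices, might be:
#
#     from copy import deepcopy
#
#     import torch
#     import torch.nn as nn
#     import torch.nn.functional as F
#     import torch.optim as optim
#
#     x = (x_train, x_valid)   # |x[i]| = (n_samples, input_dim)
#     y = (y_train, y_valid)   # |y[i]| = (n_samples, output_dim)
#
#     model = nn.Sequential(
#         nn.Linear(x[0].size(-1), 10),
#         nn.ReLU(),
#         nn.Linear(10, y[0].size(-1)),
#     )
#     optimizer = optim.Adam(model.parameters())
#
#     n_epochs, batch_size = 1000, 128
#     print_interval, early_stop = 100, 50
#
#     lowest_loss, lowest_epoch = float('inf'), 0
#     best_model = None
#
# Once training finishes, the logged histories can be plotted, e.g. with
# matplotlib (also an assumption, not part of the original listing):
#
#     import matplotlib.pyplot as plt
#
#     plt.plot(train_history, label='train')
#     plt.plot(valid_history, label='valid')
#     plt.yscale('log')
#     plt.legend()
#     plt.show()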