"""Utils for a basic training loop."""
def train_step(batch, model, optimizer, loss_func, scheduler):
    """Run one optimization step on a single batch.

    Args:
        batch: input batch; moved to the module-level ``device``.
        model: network called as ``model(xb)``; expected to return a tuple
            whose first element is the prediction/reconstruction — the
            remaining elements are forwarded to ``loss_func``.
        optimizer: optimizer to step; its gradients are cleared here.
        loss_func: callable ``loss_func(out[0], xb, *out[1:])`` returning
            ``(loss, extra)`` where ``extra`` holds per-component loss values.
        scheduler: optional per-batch LR scheduler; stepped when not None.

    Returns:
        Tuple of the scalar loss (``loss.item()``) and ``extra``.
    """
    xb = batch.to(device)
    # Zero *before* backward so gradients left over from before the first
    # call (or an interrupted step) can never leak into this update; the
    # original cleared only after stepping, which misses that first call.
    optimizer.zero_grad()
    out = model(xb)
    loss, extra = loss_func(out[0], xb, *out[1:])
    loss.backward()
    # Clip the global gradient norm to 1.0 to guard against exploding grads.
    nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    return loss.item(), extra
def eval_step(batch, model, loss_func):
    """Compute the loss on one batch without tracking gradients.

    Returns the scalar loss and the extra loss components produced by
    ``loss_func``.
    """
    inputs = batch.to(device)
    with torch.no_grad():
        prediction = model(inputs)
        loss, extra = loss_func(prediction[0], inputs, *prediction[1:])
    return loss.item(), extra
def fit(n_epoch, model, train_dl, valid_dl, optimizer, loss_func, scheduler=None):
    """Train ``model`` for ``n_epoch`` epochs, validating after each epoch.

    Args:
        n_epoch: number of epochs to train.
        model: network; passed through to ``train_step`` / ``eval_step``.
        train_dl: training dataloader (must support ``len``).
        valid_dl: validation dataloader.
        optimizer: optimizer used by ``train_step``.
        loss_func: loss callable returning ``(loss, extra)``; ``extra`` is
            assumed to hold exactly 3 components — TODO confirm against the
            actual loss_func used by callers.
        scheduler: optional per-batch LR scheduler, forwarded to ``train_step``.

    Returns:
        ``(train_losses, valid_losses)``: arrays of shape
        ``(n_epoch * len(train_dl), 3)`` with per-step loss components, and
        ``(n_epoch, 3)`` with per-epoch validation averages.
    """
    steps_per_epoch = len(train_dl)
    total_steps = n_epoch * steps_per_epoch
    # np.zeros, not np.ones: a slot that is never written (e.g. a run cut
    # short) must not masquerade as a plausible loss value of 1.0.
    train_losses = np.zeros((total_steps, 3))
    valid_losses = np.zeros((n_epoch, 3))
    for e in trange(n_epoch):
        model.train()
        train_pbar = tqdm(train_dl, leave=False)
        for step, batch in enumerate(train_pbar):
            total_step = e * steps_per_epoch + step
            loss, h = train_step(batch, model, optimizer, loss_func, scheduler)
            train_losses[total_step, :] = np.array(h)
            train_pbar.set_description(f"{loss:.2f}")
        model.eval()
        # Streaming (running) mean of the per-component validation losses,
        # so no per-batch history needs to be kept.
        avg_valid_loss = np.zeros(3)
        for step, batch in enumerate(valid_dl):
            loss, h = eval_step(batch, model, loss_func)
            avg_valid_loss += (np.array(h) - avg_valid_loss) / (step + 1)
        valid_losses[e, :] = avg_valid_loss
    return train_losses, valid_losses