Utilities for a basic training loop.

class AverageMeter [source]

    AverageMeter(store_vals=False, store_avgs=False)

def train_step(batch, model, optimizer, loss_func, scheduler):
    """Run a single optimization step on one batch.

    Moves the batch to the module-level ``device`` (assumed defined
    elsewhere in this file — TODO confirm), runs the forward pass,
    backpropagates, clips gradient norm to 1.0, steps the optimizer
    (and scheduler, if given), then clears gradients for the next step.

    Returns:
        (loss_value, extra): the scalar loss (via ``.item()``) and the
        second value returned by ``loss_func``.
    """
    inputs = batch.to(device)
    outputs = model(inputs)

    # loss_func receives the primary output, the inputs, and any
    # remaining model outputs (e.g. auxiliary terms) as extra args.
    loss, extra = loss_func(outputs[0], inputs, *outputs[1:])

    loss.backward()
    # Clip global grad norm to 1.0 to stabilize training.
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    # Zero after stepping so gradients are fresh for the next call.
    optimizer.zero_grad()
    return loss.item(), extra
def eval_step(batch, model, loss_func):
    """Compute the loss on one batch without tracking gradients.

    Moves the batch to the module-level ``device`` (assumed defined
    elsewhere in this file — TODO confirm) and evaluates ``loss_func``
    on the model output. No parameters are updated.

    Returns:
        (loss_value, extra): the scalar loss (via ``.item()``) and the
        second value returned by ``loss_func``.
    """
    inputs = batch.to(device)
    with torch.no_grad():
        model_out = model(inputs)
        # Same loss_func calling convention as train_step.
        loss, extra = loss_func(model_out[0], inputs, *model_out[1:])
    return loss.item(), extra
def fit(n_epoch, model, train_dl, valid_dl, optimizer, loss_func, scheduler=None):
    """Train ``model`` for ``n_epoch`` epochs, validating after each epoch.

    Args:
        n_epoch: number of epochs to run.
        train_dl, valid_dl: training / validation dataloaders.
        optimizer, loss_func, scheduler: passed through to
            ``train_step`` / ``eval_step``.

    Returns:
        (train_losses, valid_losses): arrays of shape
        (n_epoch * len(train_dl), 3) and (n_epoch, 3); the per-step /
        per-epoch-averaged values of the 3-element ``extra`` returned by
        ``loss_func`` (assumed to have exactly 3 components — TODO confirm).
    """
    steps_per_epoch = len(train_dl)
    train_losses = np.ones((n_epoch * steps_per_epoch, 3))
    valid_losses = np.ones((n_epoch, 3))

    for epoch in trange(n_epoch):
        # ---- training pass ----
        model.train()
        pbar = tqdm(train_dl, leave=False)
        for i, batch in enumerate(pbar):
            loss, h = train_step(batch, model, optimizer, loss_func, scheduler)
            train_losses[epoch * steps_per_epoch + i, :] = np.array(h)
            pbar.set_description(f"{loss:.2f}")

        # ---- validation pass ----
        model.eval()
        running = np.zeros(3)
        for i, batch in enumerate(valid_dl):
            _, h = eval_step(batch, model, loss_func)
            # Incremental (Welford-style) running mean over batches.
            running += (np.array(h) - running) / (i + 1)
        valid_losses[epoch, :] = running

    return train_losses, valid_losses