Basic helper functions
   
    
    
    
    
   
def cache_method_decorator(cache_attr, cache_namespace, reexecute=False):
    """Build a decorator that caches a method's result in a dict stored on the instance.

    Args:
        cache_attr: name of the instance attribute holding the cache dict.
        cache_namespace: prefix of every cache key produced by the decorated method.
        reexecute: when True, a ``fetch`` call still runs the method (its result
            is discarded) after the cached value has been looked up.

    The decorated method gains three keyword-only arguments:
        key_namespace: suffix of the cache key (passed through ``default`` and ``str``).
        fetch: when True, return the cached value instead of computing a new one;
            a missing key raises ``KeyError``.
        set_cache: when True (and not fetching), store the freshly computed value.
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs):
            # NOTE(review): `default` is defined elsewhere; presumably it falls back
            # to '' when key_namespace is None -- confirm.
            suffix = str(default(key_namespace, ''))
            # Snapshot the cache before the method runs.
            store = getattr(self, cache_attr)
            cache_key = f'{cache_namespace}:{suffix}'

            if not fetch:
                result = fn(self, *args, **kwargs)
                if set_cache:
                    # Rebind the attribute to a new dict instead of mutating in place.
                    setattr(self, cache_attr, {**store, cache_key: result})
                return result

            result = store[cache_key]
            if reexecute:
                # Run the method again; its return value is ignored.
                fn(self, *args, **kwargs)
            return result
        return wrapper
    return inner_fn
def look_one_back(x):
    """Concatenate, along dim 2, each dim-1 slice with its predecessor.

    The predecessor of the first slice wraps around to the last slice, so the
    result has the same size as ``x`` except that dim 2 is doubled.
    """
    last_slice = x[:, -1:, ...]
    all_but_last = x[:, :-1, ...]
    # Rotate by one position along dim 1 so index i holds the original slice i - 1.
    previous = torch.cat((last_slice, all_but_last), dim=1)
    return torch.cat((x, previous), dim=2)
def chunked_sum(tensor, chunks=1):
    """Sum ``tensor`` over its last dimension, processing the rows in chunks.

    All leading dimensions are flattened into rows, the rows are split into
    ``chunks`` pieces, each piece is summed over the last dimension, and the
    partial results are stitched back into the original leading shape.
    """
    *leading_dims, width = tensor.shape
    flat = tensor.reshape(-1, width)
    pieces = flat.chunk(chunks, dim=0)
    partial_sums = tuple(piece.sum(dim=-1) for piece in pieces)
    return torch.cat(partial_sums, dim=0).reshape(leading_dims)
def sort_key_val(t1, t2, dim=-1):
    """Sort ``t1`` along ``dim`` and reorder ``t2`` with the same permutation.

    ``t2`` is first expanded to the shape of ``t1``. Returns the pair
    ``(sorted t1, reordered t2)``.
    """
    ordering = t1.sort(dim=dim)
    expanded_vals = t2.expand_as(t1)
    reordered = expanded_vals.gather(dim, ordering.indices)
    return ordering.values, reordered
def batched_index_select(values, indices):
    """Gather entries of ``values`` along dim 1 using per-batch ``indices``.

    Each index is broadcast across the last dimension of ``values``, so a whole
    trailing vector is selected for every index.
    """
    feature_dim = values.shape[-1]
    # Add a trailing axis and stretch it so every feature shares the same index.
    gather_index = indices.unsqueeze(2).expand(-1, -1, feature_dim)
    return values.gather(1, gather_index)
Utility functions to assess model performance. The functions are tested with the model `mod` and the input `x`.
# Test model: an AWD_LSTM text classifier built with a 10,000-token vocabulary and 10 classes.
mod = get_text_classifier(AWD_LSTM, vocab_sz=10_000, n_class=10)
# Test input: random integer ids in [0, 100) with shape (3, 72).
# NOTE(review): presumably (batch size, sequence length) -- confirm.
x = torch.randint(0, 100, (3, 72))
Number of params for our test model:
# Call total_params (defined elsewhere) on the test model; the result is shown as the cell output.
total_params(mod)
Translation Callbacks
Callbacks used to ensure that training a translation model works. All 3 are needed.
See notebook here for explanation of EOS shifting
class CombineInputOutputCallback(Callback):
    """
    Callback that makes the model input a (source, target) pair by placing the
    first element of self.xb and the first element of self.yb into self.learn.xb
    """

    def __init__(self):
        pass

    def before_batch(self):
        # Replace the learner's input tuple with (source, target).
        source = self.xb[0]
        target = self.yb[0]
        self.learn.xb = (source, target)
class AssertAndCancelFit(Callback):
    "Checks in `before_batch` that `xb` is a pair whose second item is `yb[0]`, then cancels the epoch"

    def before_batch(self):
        model_inputs = self.learn.xb
        assert len(model_inputs) == 2
        # Identity check: the second input must be the very same object as the target.
        assert model_inputs[1] is self.learn.yb[0]
        raise CancelEpochException()
# Smoke test: fit for one epoch with both callbacks attached. AssertAndCancelFit
# checks the combined input and raises CancelEpochException in before_batch.
learn = synth_learner(cbs=[CombineInputOutputCallback(), AssertAndCancelFit()])
learn.fit(1)
class RemoveEOSCallback(Callback):
    """
        Remove the "eos" token from the target that is passed to the model (self.learn.xb[1]),
        as we don't want the model to learn to translate EOS when it sees EOS.

        The token is removed with a mask rather than by dropping the last position, because
        with batching the last token will often be a <pad> token, not EOS.

        The masked target is reshaped to (rows, columns - 1), which only succeeds when the
        number of removed tokens equals the number of rows, i.e. one "eos" per row is expected.
    """

    def __init__(self, eos_idx):
        self.eos_idx = eos_idx

    def before_batch(self):
        target = self.learn.xb[1]
        keep = target != self.eos_idx
        n_rows = target.size(0)
        n_cols = target.size(1) - 1
        # Boolean indexing flattens the tensor, so restore the 2-D layout afterwards.
        trimmed = target[keep].view((n_rows, n_cols))
        self.learn.xb = (self.learn.xb[0], trimmed)
class LossTargetShiftCallback(Callback):
    """
        Drop the first position of the target seen by the loss. That position is the "bos"
        token, which the model has already been given, so the first token it should predict
        is an actual word.
    """

    def __init__(self):
        pass

    def after_pred(self):
        full_target = self.learn.yb[0]
        # Keep every row, skip column 0.
        self.learn.yb = (full_target[:, 1:],)
class TestLossShiftAndCancelFit(Callback):
    "Checks in `after_pred` that the target is one column shorter than a fresh batch's target, then cancels the epoch"

    def after_pred(self):
        reference_batch = self.learn.dls.one_batch()
        expected_len = reference_batch[1].size()[1] - 1
        assert self.learn.yb[0].size()[1] == expected_len
        raise CancelEpochException()
# Smoke test: fit for one epoch with the target-shift callback followed by the
# checking callback, which raises CancelEpochException in after_pred.
learn = synth_learner(cbs=[LossTargetShiftCallback(), TestLossShiftAndCancelFit()])
learn.fit(1)
# Sample ids used to exercise AddEOSID (defined elsewhere).
input_ids = torch.tensor([1,2,3,4])
# NOTE(review): the argument 0 is presumably the EOS token id -- confirm against AddEOSID.
add_eos = AddEOSID(0)
# Apply the transform; the result is shown as the cell output.
add_eos(input_ids)
# Check that LabelSmoothingCrossEntropyFlat(ignore_index=-1) gives zero gradient
# exactly at the positions whose target is -1.
bs=4
sl=10
v=32
# Random predictions of shape (bs, sl, v) that track gradients.
pred = torch.randn(bs, sl, v, requires_grad=True)
# Random integer targets in [0, v) of shape (bs, sl).
targ = torch.randint(v, (bs,sl))
# Upper-triangular indices of a (bs, sl) grid starting at diagonal sl-bs+1 (= 7 here):
# this selects the last 3, 2, 1 and 0 positions of rows 0..3 respectively.
i, j = torch.triu_indices(bs, sl, offset=(sl-bs+1))
# Mark the selected positions with the ignore value.
targ[i,j] = -1
loss_func = LabelSmoothingCrossEntropyFlat(ignore_index=-1)
loss = loss_func(pred, targ)
loss.backward()
# The gradient must be all-zero over the last dim where the target is -1, and only there.
assert (torch.all(pred.grad == 0, dim=-1) == (targ==-1)).all()