Basic helper functions

Helper functions

General purpose utils

exists[source]

exists(val)

default[source]

default(val, d)

expand_dim1[source]

expand_dim1(x)

max_neg_value[source]

max_neg_value(tensor)

setattr_on[source]

setattr_on(model, attr, val, module_class)
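
These are small, general-purpose one-liners. As a rough guide (the exact source may differ), implementations in this style typically look like:

def exists(val):
    # True when val is not None
    return val is not None

def default(val, d):
    # fall back to d when val is None
    return val if exists(val) else d

def max_neg_value(tensor):
    # most negative value representable in the tensor's dtype, used for masking
    return -torch.finfo(tensor.dtype).max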

Generative utils

top_p_filter[source]

top_p_filter(logits, top_p=0.9)

top_k_filter[source]

top_k_filter(logits, top_k=20)
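
Both filters are used when sampling from a language model: they push the logits of tokens outside the top-p / top-k set towards -inf before the softmax. A minimal sketch of the top-k case, assuming logits of shape (batch, vocab); the library's top_k_filter may differ in detail:

def top_k_filter_sketch(logits, top_k=20):
    # keep the top_k largest logits per row, mask the rest with -inf
    values, _ = torch.topk(logits, top_k)
    kth_value = values[..., -1, None]
    return torch.where(logits < kth_value, torch.full_like(logits, float('-inf')), logits)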

LSH specific helpers

cache_method_decorator[source]

cache_method_decorator(cache_attr, cache_namespace, reexecute=False)

def cache_method_decorator(cache_attr, cache_namespace, reexecute = False):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs):
            namespace_str = str(default(key_namespace, ''))
            _cache = getattr(self, cache_attr)
            _keyname = f'{cache_namespace}:{namespace_str}'

            if fetch:
                val = _cache[_keyname]
                if reexecute:
                    fn(self, *args, **kwargs)
            else:
                val = fn(self, *args, **kwargs)
                if set_cache:
                    setattr(self, cache_attr, {**_cache, **{_keyname: val}})
            return val
        return wrapper
    return inner_fn
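
Illustrative usage (the Buckets class and its compute method are hypothetical): the decorated method stores its result in the dict named by cache_attr, and fetch=True reads it back instead of recomputing.

class Buckets:
    def __init__(self):
        self.cache = {}                      # the dict that cache_attr ('cache') points to

    @cache_method_decorator('cache', 'buckets')
    def compute(self, x):
        return x * 2                         # stands in for an expensive computation

b = Buckets()
b.compute(3)                 # computes 6 and stores it under the key 'buckets:'
b.compute(3, fetch=True)     # -> 6, read back from the cache without recomputing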

look_one_back[source]

look_one_back(x)

def look_one_back(x):
    x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
    return torch.cat([x, x_extra], dim=2)
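
In chunked (LSH) attention each chunk is concatenated with the preceding chunk along the chunk-length dimension, so attention can also look one chunk back. For example:

x = torch.randn(1, 4, 16, 8)      # (batch, n_chunks, chunk_len, dim)
look_one_back(x).shape            # -> torch.Size([1, 4, 32, 8])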

chunked_sum[source]

chunked_sum(tensor, chunks=1)

def chunked_sum(tensor, chunks=1):
    *orig_size, last_dim = tensor.shape
    tensor = tensor.reshape(-1, last_dim)
    summed_tensors = [c.sum(dim=-1) for c in tensor.chunk(chunks, dim=0)]
    return torch.cat(summed_tensors, dim=0).reshape(orig_size)
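
chunked_sum gives the same result as tensor.sum(dim=-1) but processes the flattened rows in chunks to limit peak memory. For example:

t = torch.randn(2, 3, 5)
assert torch.allclose(chunked_sum(t, chunks=3), t.sum(dim=-1))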

sort_key_val[source]

sort_key_val(t1, t2, dim=-1)

def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)
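
sort_key_val sorts t1 along dim and reorders t2 with the same permutation. For example:

keys = torch.tensor([[3., 1., 2.]])
vals = torch.tensor([[10., 20., 30.]])
sort_key_val(keys, vals)
# -> (tensor([[1., 2., 3.]]), tensor([[20., 30., 10.]]))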

batched_index_select[source]

batched_index_select(values, indices)

def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))
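
batched_index_select gathers rows of values per batch element according to indices. For example:

values = torch.arange(12.).reshape(1, 4, 3)   # (batch, seq_len, dim)
indices = torch.tensor([[2, 0]])              # (batch, n_indices)
batched_index_select(values, indices)
# -> tensor([[[6., 7., 8.],
#             [0., 1., 2.]]])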

Profiling functions

Utility functions to assess model performance. These are tested below with the model mod and input x:

mod = get_text_classifier(AWD_LSTM, vocab_sz=10_000, n_class=10)
x = torch.randint(0, 100, (3, 72))

do_cuda_timing[source]

do_cuda_timing(f, inp, context=None, n_loops=100)

Get timings of CUDA modules. Note that self_cpu_time_total is returned; from experiments this appears to be similar or identical to the total CUDA time.

f : function to profile, typically an nn.Module
inp : required input to f
context : optional additional input to f, used for Decoder-style modules
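
Typical call with the test model and input above (requires a CUDA device; depending on the implementation you may need to move mod and x to the GPU first):

timings = do_cuda_timing(mod, x, n_loops=10)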

model_performance[source]

model_performance(n_loops=5, model='arto', dls=None, n_epochs=1, lr=0.0005)

DEMO CODE ONLY! Runs a training loop to measure timings. Note that the model created internally should be changed depending on the model you would like to test, and you should also adjust the metrics you are monitoring.

total_params[source]

total_params(m)

Gives the number of parameters of a module and whether it is trainable or not

Number of params for our test model:

total_params(mod)
(24336280, True)

Translation Callbacks

Callbacks used to ensure training a translation model works. All three are needed.

See the notebook here for an explanation of EOS shifting

class CombineInputOutputCallback[source]

CombineInputOutputCallback() :: Callback

Callback to combine the source (self.xb) and target (self.yb) into self.xb

class CombineInputOutputCallback(Callback):
    """
    Callback to combine the source (self.xb) and target (self.yb) into self.xb
    """
    def __init__(self): pass
    def before_batch(self): 
        self.learn.xb = (self.xb[0], self.yb[0])
class AssertAndCancelFit(Callback):
    "Cancels batch after backward to avoid opt.step()"
    def before_batch(self):
        assert len(self.learn.xb) == 2
        assert self.learn.xb[1] is self.learn.yb[0]
        raise CancelEpochException()

learn = synth_learner(cbs=[CombineInputOutputCallback(), AssertAndCancelFit()])
learn.fit(1)
epoch train_loss valid_loss time
0 00:00

class RemoveEOSCallback[source]

RemoveEOSCallback(eos_idx) :: Callback

Shift the target presented to the model during training to remove the "eos" token as we don't want the model to learn to translate EOS when it sees EOS.

In practice we actually mask the EOS token, as due to batching the last token will often be a <pad> token, not EOS

class RemoveEOSCallback(Callback):
    """
        Shift the target presented to the model during training to remove the "eos" token as 
        we don't want the model to learn to translate EOS when it sees EOS.
        
        In practice we actually mask the EOS token as due to batching the last token will often be a <pad> token,
        not EOS
    """
    def __init__(self, eos_idx): self.eos_idx=eos_idx
    def before_batch(self):        
        eos_mask=(self.learn.xb[1]!=self.eos_idx)
        sz=torch.tensor(self.learn.xb[1].size())
        sz[1]=sz[1]-1
        self.learn.xb = (self.learn.xb[0], self.learn.xb[1][eos_mask].view((sz[0],sz[1])))

class LossTargetShiftCallback[source]

LossTargetShiftCallback() :: Callback

Shift the target shown to the loss to exclude the "bos" token, as the first token we want predicted should be an actual word, not the "bos" token (we have already given the model "bos")

class LossTargetShiftCallback(Callback):
    """
        Shift the target shown to the loss to exclude the "bos" token as the first token we want predicted
        should be an actual word, not the "bos" token (as we have already given the model "bos" )
    """
    def __init__(self): pass
    def after_pred(self): 
        self.learn.yb = (self.learn.yb[0][:,1:],)
class TestLossShiftAndCancelFit(Callback):
    "Cancels batch after backward to avoid opt.step()"
    def after_pred(self):      
        o = self.learn.dls.one_batch()
        assert self.learn.yb[0].size()[1] == o[1].size()[1] - 1
        raise CancelEpochException()

learn = synth_learner(cbs=[LossTargetShiftCallback(), TestLossShiftAndCancelFit()])
learn.fit(1)
epoch train_loss valid_loss time
0 00:00

class PadBatchCallback[source]

PadBatchCallback(bucket_size:int=64, val:int=0, y_val:int=-100) :: Callback

Pads input and target sequences to a multiple of 2*bucket_size
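
With the default bucket_size=64, a batch of 100-token sequences would therefore be padded up to 128 tokens, presumably filling inputs with val and targets with y_val. A sketch of the length calculation only, not the callback internals:

import math

def padded_len(seq_len, bucket_size=64):
    # round seq_len up to the nearest multiple of 2*bucket_size
    multiple = 2 * bucket_size
    return math.ceil(seq_len / multiple) * multiple

padded_len(100)    # -> 128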

Translation Transform

class AddEOSID[source]

AddEOSID(eos_id, keep_size=True)

input_ids = torch.tensor([1,2,3,4])
add_eos = AddEOSID(0)
add_eos(input_ids)
LMTensorText([1, 2, 3, 0])

Loss functions

class LabelSmoothingCrossEntropy[source]

LabelSmoothingCrossEntropy(eps:float=0.1, reduction='mean') :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

class LabelSmoothingCrossEntropyFlat[source]

LabelSmoothingCrossEntropyFlat(*args, axis=-1, eps=0.1, reduction='mean', flatten=True, floatify=False, is_2d=True) :: BaseLoss

Same as LabelSmoothingCrossEntropy, but flattens input and target.

bs=4
sl=10
v=32
pred = torch.randn(bs, sl, v, requires_grad=True)
targ = torch.randint(v, (bs,sl))
i, j = torch.triu_indices(bs, sl, offset=(sl-bs+1))
targ[i,j] = -1
loss_func = LabelSmoothingCrossEntropyFlat(ignore_index=-1)
loss = loss_func(pred, targ)
loss.backward()
assert (torch.all(pred.grad == 0, dim=-1) == (targ==-1)).all()

Distributed

Learner.distrib_ctx[source]

Learner.distrib_ctx(cuda_id=None, sync_bn=True)

A context manager to adapt a learner to train in distributed data parallel mode.
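
Typical usage, wrapping the training call so the learner is set up for distributed data parallel training and cleaned up afterwards; run the script under a distributed launcher:

with learn.distrib_ctx():
    learn.fit(1)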