Memory efficient transformer

Helper classes

class Chunk[source]

Chunk(n_chunks:int, fn:Module, dim:int=-1) :: Module

Applies fn to input chunked along dim
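
A minimal usage sketch, assuming (as the signature suggests) that Chunk splits the input into n_chunks along dim, applies fn to each chunk and concatenates the results:

import torch
import torch.nn as nn
lin = nn.Linear(128, 128)
chunked = Chunk(4, lin, dim=1)        # split along the sequence dimension
x = torch.randn(4, 64, 128)
out = chunked(x)
assert out.size() == (4, 64, 128)
# a positionwise fn gives the same result whether or not the input is chunked
assert torch.allclose(out, lin(x), atol=1e-5)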

class ChunkedFeedForward[source]

ChunkedFeedForward(d:int, d_ff:int=None, n_chunks:int=1, dropout:float=0.0, dim:int=-1) :: Module

Applies a positionwise feed-forward layer to input chunked along dim

bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
ff  = ChunkedFeedForward(d, n_chunks=8, dim=1)
out = ff(x)
assert out.size() == (bs, sl, d)

class Deterministic[source]

Deterministic(net:Module) :: Module

Wrapper module to ensure determinism in the backward pass, following the example for saving and restoring the RNG state here: https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
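
A minimal sketch, assuming a plain call simply forwards to the wrapped net (the RNG saving and restoring only applies when requested by the reversible blocks):

drop = Deterministic(nn.Dropout(0.5))
x = torch.randn(4, 64, 128)
out = drop(x)                          # plain call behaves like the wrapped module
assert out.size() == (4, 64, 128)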

 

class ReversibleBlock[source]

ReversibleBlock(f:Module, g:Module, depth=None, send_signal=False) :: Module

Applies f and g in a reversible manner. Avoids storing outputs for backpropagation

bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
# revblock is called on a twinned input (x duplicated along the last dim)
x2 = torch.cat([x, x], dim=-1)
attn = Attention(d)
ff = ChunkedFeedForward(d, n_chunks=8, dim=-2)
revblock = ReversibleBlock(attn, ff)
out = revblock(x2)
assert out.size() == (bs, sl, d*2)
# no grads are stored
out = torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
try: out.mean().backward()
except RuntimeError as e: print(e)
element 0 of tensors does not require grad and does not have a grad_fn

class IrreversibleBlock[source]

IrreversibleBlock(f, g) :: Module

Mimics ReversibleBlock computation, but gradients are computed as usual

attn = Attention(d)
ff = ChunkedFeedForward(d, n_chunks=8, dim=-2)
irrevblock = IrreversibleBlock(attn, ff)
out = irrevblock(x2)
assert out.size() == (bs, sl, d*2)

class ReversibleSequence[source]

ReversibleSequence(blocks, rev_thres=0, send_signal=False) :: Module

Stack of ReversibleBlocks constructed from blocks. Applies ReversibleBlocks if the sequence length is > rev_thres, otherwise IrreversibleBlocks.

bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
x2 = torch.cat([x, x], dim=-1)
blocks = []
for i in range(2):
    f = PreNorm(d, Attention(d))
    g = PreNorm(d, FeedForward(d))
    blocks.append(nn.ModuleList([f, g]))
layers = ReversibleSequence(nn.ModuleList(blocks))
out = layers(x2)
assert out.size() == (bs, sl, 2*d)
bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
x2 = torch.cat([x, x], dim=-1)
blocks = []
for i in range(2):
    f = PreNorm(d, LSHSelfAttention(d, bucket_size=16))
    g = PreNorm(d, FeedForward(d))
    blocks.append(nn.ModuleList([f, g]))
layers = ReversibleSequence(nn.ModuleList(blocks))
out = layers(x2, arg_route=(True, False), _reverse=False, _depth=1)
assert out.size() == (bs, sl, 2*d)
try: out.mean().backward()
except RuntimeError as e: print(e)
element 0 of tensors does not require grad and does not have a grad_fn

ReversibleTransformer

class ReversibleEncoder[source]

ReversibleEncoder(d_model:int, n_layers:int=6, n_heads:int=8, max_seq_len:int=512, ff_chunks:int=1, causal:bool=False, attn_dropout:float=0.0, post_attn_dropout:float=None, attn_bias:bool=False, ff_dropout:float=0.0, d_ff:int=None, prenorm:bool=True, final_norm:Module=None, rev_thres:int=0) :: Module

Stack of ReversibleBlocks

x = torch.randn(bs, sl, d)
m = ReversibleEncoder(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 64, 128])
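
A sketch of the rev_thres fallback: for sequences not longer than the threshold the stack switches to IrreversibleBlocks, so activations are stored and backward works as usual (rev_thres=128 is a hypothetical choice above sl=64):

m = ReversibleEncoder(d, rev_thres=128)
out = m(x)                             # sl=64 is below rev_thres, so irreversible blocks are used
assert out.size() == (bs, sl, d)
out.mean().backward()                  # gradients flow as in a regular stack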

class ReversibleDecoder[source]

ReversibleDecoder(d_model, n_layers=6, heads=8, max_seq_len=512, d_head=None, bucket_size=64, n_hashes=8, ff_chunks=1, attn_chunks=None, attn_dropout=0.0, post_attn_dropout=None, attn_bias:bool=False, ff_dropout=0.0, d_ff=None, prenorm=True, final_norm:Module=None, rev_thres=0) :: Module

Stack of ReversibleBlocks. Uses AdditiveAttention.

x = torch.randn(bs, sl, d)
m = ReversibleDecoder(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 64, 128])

class ReversibleLM[source]

ReversibleLM(vocab_sz:int, d_model:int, n_layers:int=6, n_heads:int=8, d_ff:int=None, ff_chunks:int=1, attn_dropout:float=0.1, ff_dropout:float=0.1, emb_dropout:float=0.1, tie_weights:bool=True, causal:bool=True, pos_enc:str='absolute', max_seq_len:int=512, axial_shape=None, axial_emb_dims=None, pad_idx:int=None, prenorm:bool=True, attn_bias:bool=False, rev_thres:int=0) :: Module

Reversible Transformer for language modelling

Parameters:

* vocab_sz: int
* d_model: int - inner dimension of the model
* n_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* ff_chunks: int - number of chunks for FeedForward layer computation
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* causal: bool (default: True) - if True does causal masking automatically
* max_seq_len: int (default: 512)
* tie_weights: bool - if True, the target embedding weights are used for the output projection
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - if True, the projection layers of attention modules will have bias
* pad_idx: int - padding token id, required for autogeneration of padding mask
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - required if 'axial' positional encoding is used; elements should be factors of
        max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model
* rev_thres: int - if (seq_len < rev_thres) applies irreversible blocks

Inputs:

* x - input ids, shape [bs, sl]
* mask - optional boolean mask, shape [bs, sl]

Returns:

* logits - target token logits, shape [bs, sl, vocab_sz]
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = ReversibleLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
torch.Size([4, 128, 256])
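
A sketch of axial positional encoding, following the parameter notes above; (8, 16) is a hypothetical axial_shape whose elements are factors of max_seq_len=128:

model = ReversibleLM(vocab_sz, d, n_layers=2, causal=False,
                     max_seq_len=sl, pos_enc='axial', axial_shape=(8, 16))
out = model(x)
assert out.size() == (bs, sl, vocab_sz)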

class ReversibleTransformer[source]

ReversibleTransformer(enc_vocab_sz, dec_vocab_sz, d_model, n_layers:int=6, n_enc_layers=None, n_dec_layers=None, n_heads=8, d_ff=None, ff_chunks:int=1, pad_idx=None, tie_weights=True, shared_emb=False, attn_dropout=0.1, ff_dropout=0.1, emb_dropout=0.1, prenorm=True, attn_bias=False, comb_attn=False, pos_enc='absolute', max_seq_len=512, axial_shape=None, axial_emb_dims=None) :: Module

Basic Transformer Encoder-Decoder model.

Parameters:

* enc_vocab_sz: int - source vocab size
* dec_vocab_sz: int - target vocab size
* d_model: int - inner dimension of the model
* n_enc_layers: int (default: 6)
* n_dec_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* ff_chunks: int - number of chunks for FeedForward layer computation
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* max_seq_len: int (default: 512)
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id; if provided and no mask/context_mask is passed to the
        forward method, it will be used to generate padding masks
* tie_weights: bool - if True, the target embedding weights are used for the output projection
* shared_emb: bool - if True encoder and decoder will use shared embedding layer
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - required if 'axial' positional encoding is used; elements should be factors of
        max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model

Inputs:

* src - source input ids, shape [bs, src_sl]
* tgt - target input ids, shape [bs, tgt_sl]
* src_mask - optional boolean source mask, shape [bs, src_sl]
* tgt_mask - optional boolean target mask, shape [bs, tgt_sl]

Returns:

* logits - target token logits, shape [bs, tgt_sl, tgt_vocab_sz]
bs = 4
src_sl = 70
tgt_sl = 80
d = 64
src_vocab_sz = 256
tgt_vocab_sz = 256
src = torch.randint(src_vocab_sz, (bs, src_sl))
tgt = torch.randint(tgt_vocab_sz, (bs, tgt_sl))
model = ReversibleTransformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2, n_dec_layers=2)
out = model(src, tgt)
assert (out.size() == (bs, tgt_sl, tgt_vocab_sz))
out.shape
torch.Size([4, 80, 256])
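
A sketch of shared_emb: with matching source and target vocab sizes (256 each here), a single embedding table can be shared by encoder and decoder:

model = ReversibleTransformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2,
                              n_dec_layers=2, shared_emb=True)
out = model(src, tgt)
assert out.size() == (bs, tgt_sl, tgt_vocab_sz)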

Transformer with LSH attention

class LSHEncoderBlock[source]

LSHEncoderBlock(d_model:int, n_heads:int=8, d_ff:int=None, attn_dropout:float=0.1, ff_dropout:float=0.1, causal:bool=False, attn_bias:bool=False, prenorm:bool=False, use_lsh:bool=True, n_hashes:int=8, bucket_size:int=64, seed:int=None) :: Module

Encoder block using ReformerAttention

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = LSHEncoderBlock(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])
m = LSHEncoderBlock(d, use_lsh=False)
out = m(x)
assert (out.size() == (bs, sl, d))

class LSHEncoder[source]

LSHEncoder(d_model, n_layers=6, n_heads=8, d_ff=None, ff_dropout=0.1, attn_dropout=0.1, attn_bias=False, causal=False, prenorm=False, use_lsh:bool=True, final_norm=None, n_hashes:int=8, bucket_size:int=64, seed:int=None) :: Module

Stack of LSHEncoderBlocks

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = LSHEncoder(d, n_layers=2)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])
m = LSHEncoder(d, n_layers=2, n_heads=4, use_lsh=False)
out = m(x)
assert (out.size() == (bs, sl, d))

class LSHLM[source]

LSHLM(vocab_sz:int, d_model:int, n_layers:int=6, n_heads:int=8, d_ff:int=None, attn_dropout:float=0.1, ff_dropout:float=0.1, emb_dropout:float=0.1, tie_weights:bool=True, causal:bool=True, pos_enc:str='absolute', max_seq_len:int=512, axial_shape:tuple=None, axial_emb_dims:tuple=None, pad_idx:int=None, prenorm:bool=False, attn_bias:bool=False, use_lsh:bool=True, n_hashes:int=8, bucket_size:int=64, seed:int=None) :: Module

Transformer for language modelling with LSH attention

Parameters:

* vocab_sz: int
* d_model: int - inner dimension of the model
* n_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* causal: bool (default: True) - if True does causal masking automatically
* max_seq_len: int (default: 512)
* tie_weights: bool - if True, the target embedding weights are used for the output projection
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id, required for autogeneration of padding mask
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - required if 'axial' positional encoding is used; elements should be factors of
        max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model
* use_lsh: bool - parameter to switch between LSH and full attention
* n_hashes: int - number of hashing rounds for LSH
* bucket_size: int - input sequence length should be divisible by 2*bucket_size
* seed: int - for LSHAttention module

Inputs:

* x - input ids, shape [bs, sl]
* mask - optional boolean mask, shape [bs, sl]

Returns:

* logits - target token logits, shape [bs, sl, vocab_sz]
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = LSHLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
torch.Size([4, 128, 256])
model.use_lsh = True
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
%timeit model(x)
304 ms ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
model.use_lsh = False
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
%timeit model(x)
8.6 ms ± 325 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
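
A sketch of the bucket_size constraint noted above: the sequence length must be divisible by 2*bucket_size, so for sl=128 a bucket_size of 32 (a hypothetical choice) is also valid:

model = LSHLM(vocab_sz, d, n_layers=2, causal=False, bucket_size=32, n_hashes=2)
out = model(x)
assert out.size() == (bs, sl, vocab_sz)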

Reformer

class ReformerEncoder[source]

ReformerEncoder(d_model:int, n_layers:int=6, n_heads:int=8, max_seq_len:int=512, ff_chunks:int=1, causal:bool=False, attn_dropout:float=0.0, post_attn_dropout:float=None, attn_bias:bool=False, ff_dropout:float=0.0, d_ff:int=None, prenorm:bool=True, final_norm:Module=None, rev_thres:int=0, use_lsh:bool=True, n_hashes:int=8, bucket_size:int=64, seed:int=None) :: Module

Stack of ReversibleBlocks

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = ReformerEncoder(d, n_layers=2)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])

class ReformerLM[source]

ReformerLM(vocab_sz:int, d_model:int, n_layers:int=6, n_heads:int=8, d_ff:int=None, ff_chunks:int=1, attn_dropout:float=0.1, ff_dropout:float=0.1, emb_dropout:float=0.1, tie_weights:bool=True, causal:bool=True, pos_enc:str='axial', max_seq_len:int=512, axial_shape:tuple=None, axial_emb_dims:tuple=None, pad_idx:int=None, prenorm:bool=True, attn_bias:bool=False, use_lsh:bool=True, n_hashes:int=8, bucket_size:int=64, rev_thres:int=0, seed:int=None) :: Module

Reformer for language modelling. Uses LSH or full shared-QK attention

Parameters:

* vocab_sz: int
* d_model: int - inner dimension of the model
* n_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* ff_chunks: int - number of chunks for FeedForward layer computation
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* causal: bool (default: True) - if True does causal masking automatically
* max_seq_len: int (default: 512)
* tie_weights: bool - if True, the target embedding weights are used for the output projection
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id, required for autogeneration of padding mask
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - required if 'axial' positional encoding is used; elements should be factors of
        max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model
* rev_thres: int - if (seq_len < rev_thres) applies irreversible blocks
* use_lsh: bool - parameter to switch between LSH and full attention
* n_hashes: int - number of hashing rounds for LSH
* bucket_size: int - input sequence length should be divisible by 2*bucket_size
* seed: int - for LSHAttention module

Inputs:

* x - input ids, shape [bs, sl]
* mask - optional boolean mask, shape [bs, sl]

Returns:

* logits - target token logits, shape [bs, sl, vocab_sz]
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = ReformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
torch.Size([4, 128, 256])

Check cached buckets:

{'buckets:0': tensor([[ 0,  0,  1,  ..., 15, 15, 14],
        [ 0,  0,  0,  ..., 14, 15, 15],
        [ 0,  0,  0,  ..., 14, 15, 14],
        [ 0,  0,  1,  ..., 14, 15, 14]])}
torch.Size([4, 1024])
{'buckets:1': tensor([[ 0,  0,  0,  ..., 15, 15, 14],
        [ 0,  0,  0,  ..., 15, 14, 15],
        [ 1,  1,  1,  ..., 14, 14, 14],
        [ 0,  0,  0,  ..., 14, 15, 15]])}
torch.Size([4, 1024])
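
A minimal sketch for inspecting the cached buckets shown above, assuming each LSH attention module keeps them in a cache dict; the _cache attribute name is an assumption, not the library's documented API:

# hypothetical inspection; `_cache` is an assumed attribute name
for mod in model.modules():
    cache = getattr(mod, '_cache', None)
    if cache:
        for name, buckets in cache.items():
            if 'buckets' in name:
                print({name: buckets})
                print(buckets.shape)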

LSHAttention execution time depends on the number of hashing rounds

print(f'Number of hashing rounds {model._n_hashes}')
%timeit model(x)
Number of hashing rounds 8
304 ms ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
model.n_hashes = 1
print(f'Number of hashing rounds {model.n_hashes}')
%timeit model(x)
Number of hashing rounds 1
74.7 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

reformer_lm_splits[source]

reformer_lm_splits(model)

Splits a ReformerLM model into parameter groups for differential learning rates.
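
A usage sketch, assuming the function returns a list of parameter groups that can be used as a splitter when building a fastai Learner:

model = ReformerLM(vocab_sz, d, n_layers=2)
groups = reformer_lm_splits(model)
print(len(groups))                     # number of parameter groups
# assumed usage: Learner(dls, model, splitter=reformer_lm_splits)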
