Baseline transformer blocks and models

Mixins

class LMMixin[source]

LMMixin()

Mixin for language models

class EncDecMixin[source]

EncDecMixin()

Mixin for encoder-decoder models

Transformer blocks

Encoder

class TransformerEncoderBlock[source]

TransformerEncoderBlock(d_model:int, n_heads:int=8, d_ff:int=None, attn_dropout:float=0.1, ff_dropout:float=0.1, causal:bool=False, attn_bias:bool=False, prenorm:bool=False, shared_qk:bool=False) :: Module

Basic transformer encoder block. Consists of multi-head attention and positional feedforward layers

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = TransformerEncoderBlock(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])

class TransformerEncoder[source]

TransformerEncoder(d_model, n_layers=6, n_heads=8, d_ff=None, ff_dropout=0.1, attn_dropout=0.1, attn_bias=False, causal=False, prenorm=False, shared_qk:bool=False, final_norm=None) :: Module

Stack of TransformerEncoderBlocks

x = torch.randn(bs, sl, d)
m = TransformerEncoder(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])

Decoder

class TransformerDecoderBlock[source]

TransformerDecoderBlock(d_model, n_heads=8, d_ff=None, attn_dropout=0.1, ff_dropout=0.1, mask=None, attn_bias=False, prenorm=False) :: Module

Standard transformer decoder block. Consists of self-attention, encoder-decoder attention and positional feed-forward layers

class TransformerDecoderBlockV2[source]

TransformerDecoderBlockV2(d_model, n_heads=8, mask=None, d_ff=None, attn_dropout=0.1, ff_dropout=0.1, attn_bias=False, prenorm=False) :: Module

Transformer decoder block that uses a single additive attention layer in place of separate self-attention and cross-attention layers
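A minimal usage sketch for both decoder blocks. It assumes their forward methods accept the decoder input and the encoder context in the same way as TransformerDecoder below; that signature is an assumption, not stated in the docstrings above.

# hedged sketch: forward(x, context) is assumed, mirroring the TransformerDecoder example below
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)          # decoder input
context = torch.randn(bs, sl, d)    # encoder output
block = TransformerDecoderBlock(d)
block_v2 = TransformerDecoderBlockV2(d)
assert block(x, context).size() == (bs, sl, d)
assert block_v2(x, context).size() == (bs, sl, d)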

class TransformerDecoder[source]

TransformerDecoder(d_model, n_layers=6, n_heads=8, d_ff=None, attn_dropout=0.1, ff_dropout=0.1, prenorm=False, comb_attn=False, attn_bias=False, final_norm=None) :: Module

Stack of TransformerDecoderBlocks

x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
m = TransformerDecoder(d)
out = m(x, context)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])

Language model

class TransformerLM[source]

TransformerLM(vocab_sz:int, d_model:int, n_layers:int=6, n_heads:int=8, d_ff:int=None, attn_dropout:float=0.1, ff_dropout:float=0.1, emb_dropout:float=0.1, tie_weights:bool=True, causal:bool=True, pos_enc:str='absolute', max_seq_len:int=512, axial_shape:tuple=None, axial_emb_dims:tuple=None, pad_idx:int=None, prenorm:bool=False, attn_bias:bool=False, shared_qk:bool=False) :: Module

Basic Transformer for language modelling

Parameters:

* vocab_sz: int - vocabulary size
* d_model: int - inner dimension of the model
* n_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* causal: bool (default: True) - if True, causal masking is applied automatically
* max_seq_len: int (default: 512)
* tie_weights: bool - if True, the target embedding weights are used to compute the output projection
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id, required for automatic generation of the padding mask
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - [optional] should be factors of max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model

Inputs:

* x - input ids, shape [bs, sl]
* mask - optional boolean mask, shape [bs, sl]

Returns:

* logits - target token logits, shape [bs, sl, vocab_sz]

bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = TransformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
torch.Size([4, 128, 256])
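As a further illustration of the positional encoding parameters above, a hedged sketch using the axial option: axial_shape should multiply to max_seq_len and axial_emb_dims should sum to d_model. The particular factorisations below are just one valid choice.

# hedged sketch: axial positional encoding; (16, 8) factors max_seq_len=128
# and (32, 32) sums to d_model=64
model = TransformerLM(vocab_sz, d, n_layers=2, pos_enc='axial', max_seq_len=sl,
                      axial_shape=(16, 8), axial_emb_dims=(32, 32))
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))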

transformer_lm_splits[source]

transformer_lm_splits(model)

Splits TransformerLM model into groups for differential learning rates.
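A hedged usage sketch, assuming the returned groups are intended to be used as a fastai splitter; the Learner call and dls below are hypothetical and not part of this page.

# model from the TransformerLM example above
param_groups = transformer_lm_splits(model)
# hypothetical fastai usage: learn = Learner(dls, model, splitter=transformer_lm_splits)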

Encoder-Decoder model

class Transformer[source]

Transformer(enc_vocab_sz, dec_vocab_sz, d_model, n_enc_layers=6, n_dec_layers=6, n_heads=8, d_ff=None, pad_idx=None, tie_weights=True, shared_emb=False, attn_dropout=0.1, ff_dropout=0.1, emb_dropout=0.1, prenorm=False, attn_bias=False, comb_attn=False, pos_enc='absolute', max_seq_len=512, axial_shape=None, axial_emb_dims=None) :: Module

Basic Transformer encoder-decoder model

Parameters:

* enc_vocab_sz: int - source vocab size
* dec_vocab_sz: int - target vocab size
* d_model: int - inner dimension of the model
* n_enc_layers: int (default: 6)
* n_dec_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* max_seq_len: int (default: 512)
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id; if pad_idx is provided and no mask/context_mask is passed to the forward method, it will be used to generate padding masks
* tie_weights: bool - if True, the target embedding weights are used to compute the output projection
* shared_emb: bool - if True, the encoder and decoder share an embedding layer
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - [optional] should be factors of max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model

Inputs:

* src - source input ids, shape [bs, src_sl]
* tgt - target input ids, shape [bs, tgt_sl]
* src_mask - optional boolean source mask, shape [bs, src_sl]
* tgt_mask - optional boolean target mask, shape [bs, tgt_sl]

Returns:

* logits - target token logits, shape [bs, tgt_sl, tgt_vocab_sz]

bs = 4
src_sl = 70
tgt_sl = 80
d = 64
src_vocab_sz = 256
tgt_vocab_sz = 256
src = torch.randint(src_vocab_sz, (bs, src_sl))
tgt = torch.randint(tgt_vocab_sz, (bs, tgt_sl))
model = Transformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2, n_dec_layers=2)
out = model(src, tgt)
assert (out.size() == (bs, tgt_sl, tgt_vocab_sz))
out.shape
torch.Size([4, 80, 256])
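A hedged sketch of the pad_idx behaviour described above: with pad_idx set and no explicit masks passed, padding masks should be generated from the token ids automatically. Using 0 as the padding id is an assumption for illustration.

# hedged sketch: pad_idx=0 is assumed to be the padding token id
src_padded = src.clone()
src_padded[:, -10:] = 0   # pretend the last 10 source positions are padding
model = Transformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2, n_dec_layers=2, pad_idx=0)
out = model(src_padded, tgt)
assert (out.size() == (bs, tgt_sl, tgt_vocab_sz))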

transformer_splits[source]

transformer_splits(model)

Splits Transformer model into groups for differential learning rates.

Low Memory Transformer

In the memory-efficient Transformer, attention is computed on chunks of queries. Setting n_chunks = sl/c, for input sequence length sl and some constant c, ensures a memory complexity of O(sl), but the more chunks are used, the slower the computation. In practice it's advised to set n_chunks based on the available memory budget.
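A back-of-the-envelope sketch of that trade-off, assuming the chunk count is passed via the attn_chunks argument of the classes documented below.

# hedged sketch: choosing the number of query chunks from a memory budget
sl = 4096               # input sequence length
c = 512                 # largest chunk the memory budget allows
attn_chunks = sl // c   # -> 8 chunks: memory roughly O(sl) instead of O(sl^2),
                        #    at the cost of 8 sequential attention passes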

class LowMemEncoderBlock[source]

LowMemEncoderBlock(d_model:int, n_heads:int=8, d_ff:int=None, attn_dropout:float=0.1, ff_dropout:float=0.1, causal:bool=False, attn_bias:bool=False, prenorm:bool=False, shared_qk:bool=False, attn_chunks:int=1) :: Module

Low memory transformer encoder block. Consists of chunked multi-head attention and positional feedforward layers

class LowMemEncoder[source]

LowMemEncoder(d_model, n_layers=6, n_heads=8, d_ff=None, ff_dropout=0.1, attn_dropout=0.1, attn_bias=False, causal=False, prenorm=False, shared_qk:bool=False, final_norm=None, attn_chunks:int=1) :: Module

Stack of LowMemEncoderBlocks
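A minimal sketch mirroring the TransformerEncoder example above, with the extra attn_chunks argument; the forward signature is assumed to match TransformerEncoder.

# hedged sketch: same interface as TransformerEncoder is assumed, plus attn_chunks
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
m = LowMemEncoder(d, n_layers=2, attn_chunks=4)
out = m(x)
assert (out.size() == (bs, sl, d))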

class ChunkedTransformerLM[source]

ChunkedTransformerLM(vocab_sz:int, d_model:int, n_layers:int=6, n_heads:int=8, d_ff:int=None, attn_chunks:int=1, attn_dropout:float=0.1, ff_dropout:float=0.1, emb_dropout:float=0.1, tie_weights:bool=True, causal:bool=True, pos_enc:str='absolute', max_seq_len:int=512, axial_shape:tuple=None, axial_emb_dims:tuple=None, pad_idx:int=None, prenorm:bool=False, attn_bias:bool=False, shared_qk:bool=False) :: Module

Basic Transformer for language modelling

Parameters:

* vocab_sz: int - vocabulary size
* d_model: int - inner dimension of the model
* n_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* attn_chunks: int - number of queries chunks for memory-efficient attention
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* causal: bool (default: True) - if True, causal masking is applied automatically
* max_seq_len: int (default: 512)
* tie_weights: bool - if True, the target embedding weights are used to compute the output projection
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id, required for automatic generation of the padding mask
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - [optional] should be factors of max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model

Inputs:

* x - input ids, shape [bs, sl]
* mask - optional boolean mask, shape [bs, sl]

Returns:

* logits - target token logits, shape [bs, sl, vocab_sz]

bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = ChunkedTransformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
torch.Size([4, 128, 256])
