Contains basic layers used in both the baseline and the Reformer models.

Layer Wrappers

class Residual

Residual(sublayer:Module) :: Module

Adds a skip connection: out = x + sublayer(x)
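A minimal PyTorch sketch of the skip connection follows. It assumes the wrapper simply forwards any extra arguments to the sublayer; the class name ResidualSketch is hypothetical and this is not the library's own code.

```python
import torch
from torch import nn

class ResidualSketch(nn.Module):
    "Hypothetical sketch of a skip connection: out = x + sublayer(x)"
    def __init__(self, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x, *args, **kwargs):
        # The sublayer must return a tensor with the same shape as x
        return x + self.sublayer(x, *args, **kwargs)

x = torch.randn(2, 5, 8)
res = ResidualSketch(nn.Linear(8, 8))
out = res(x)
print(out.shape)  # torch.Size([2, 5, 8])
```

Because the input is added back unchanged, the wrapped sublayer has to preserve the input shape.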

class PostNorm

PostNorm(d_model:int, sublayer:Module) :: Module

Adds LayerNorm after the sublayer: out = LayerNorm(sublayer(x))

class PreNorm

PreNorm(d_model:int, sublayer:Module) :: Module

Adds LayerNorm before the sublayer: out = sublayer(LayerNorm(x))
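The difference between the two normalization wrappers can be sketched in PyTorch as below. The class names PostNormSketch and PreNormSketch are hypothetical, and the sketch assumes each wrapper owns a single nn.LayerNorm(d_model) with default settings; it is an illustration, not the library's implementation.

```python
import torch
from torch import nn

class PostNormSketch(nn.Module):
    "Hypothetical: LayerNorm applied to the sublayer's output"
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, *args, **kwargs):
        return self.norm(self.sublayer(x, *args, **kwargs))

class PreNormSketch(nn.Module):
    "Hypothetical: LayerNorm applied to the input before the sublayer"
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, *args, **kwargs):
        return self.sublayer(self.norm(x), *args, **kwargs)

torch.manual_seed(0)
d = 8
x = torch.randn(2, 5, d) * 10 + 3  # deliberately un-normalized input
sub = nn.Linear(d, d)              # shared sublayer, so only the norm placement differs
post, pre = PostNormSketch(d, sub), PreNormSketch(d, sub)
# Post-norm: the block's output is normalized (zero mean over the last dim)
print(torch.allclose(post(x).mean(-1), torch.zeros(2, 5), atol=1e-5))  # True
# Pre-norm: only the sublayer's input is normalized; its output is left as-is
print(pre(x).shape)  # torch.Size([2, 5, 8])
```

Combined with the residual wrapper, the pre-norm variant computes x + sublayer(LayerNorm(x)), which is the composition used by Residual(PreNorm(d, FeedForward(d))) in the next section.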

Positional FeedForward

class FeedForward

FeedForward(d_model:int, d_ff:int=None, dropout:float=0.0) :: Module

Simple position-wise feed-forward module with a GELU activation. If d_ff is None, it defaults to 4*d_model.
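A minimal sketch of such a block, assuming the common layout Linear, GELU, Dropout, Linear, Dropout (the exact dropout placement in the library is an assumption, and FeedForwardSketch is a hypothetical name):

```python
import torch
from torch import nn

class FeedForwardSketch(nn.Module):
    "Hypothetical position-wise FFN: Linear -> GELU -> Dropout -> Linear -> Dropout"
    def __init__(self, d_model: int, d_ff: int = None, dropout: float = 0.0):
        super().__init__()
        d_ff = d_ff if d_ff is not None else 4 * d_model  # default hidden width
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # nn.Linear acts on the last dim only, so every position is processed independently
        return self.net(x)

ff = FeedForwardSketch(64)
print(ff.net[0].out_features)  # 256
x = torch.randn(4, 128, 64)
print(ff(x).shape)  # torch.Size([4, 128, 64])
```

"Position-wise" means no information moves between sequence positions inside this block; mixing across positions is left to the attention layers.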

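import torch

# Residual, PreNorm and FeedForward are the classes documented above
bs, sl, d = 4, 128, 64                     # batch size, sequence length, model dim
x = torch.randn(bs, sl, d)
ff = Residual(PreNorm(d, FeedForward(d)))  # x + FeedForward(LayerNorm(x))
out = ff(x)
assert (bs, sl, d) == out.size()           # the block preserves the input shape
print(out.size())
torch.Size([4, 128, 64])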

Embedding Layer



Simple heuristic to suggest an axial_shape for a given max_seq_len (as a product of 2 factors).
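The heuristic itself is not shown here. One possible version, offered only as a sketch, picks the pair of factors closest to the square root of max_seq_len; the function name suggest_axial_shape is hypothetical and the library may use a different rule.

```python
import math

def suggest_axial_shape(max_seq_len: int) -> tuple:
    """Hypothetical heuristic: split max_seq_len into two factors that are as
    close to each other as possible (largest divisor <= sqrt(max_seq_len))."""
    if max_seq_len < 1:
        raise ValueError("max_seq_len must be a positive integer")
    f = math.isqrt(max_seq_len)
    while max_seq_len % f != 0:  # walk down to the nearest divisor
        f -= 1
    return (max_seq_len // f, f)

print(suggest_axial_shape(512))   # (32, 16)
print(suggest_axial_shape(1024))  # (32, 32)
```

Whatever the rule, the two factors must multiply to max_seq_len: for example, 32 * 16 = 512, the default max_seq_len of TransformerEmbedding.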


get_axial_dims(d_emb, n)
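get_axial_dims is listed without a description. Assuming it divides the embedding width d_emb among n axial components whose widths sum to d_emb (as concatenated axial embeddings would require), a hypothetical equivalent could look like this; split_emb_dims is an illustrative name, not the library's function.

```python
def split_emb_dims(d_emb: int, n: int) -> tuple:
    """Hypothetical: split d_emb into n nearly equal widths that sum to d_emb,
    one width per axis of a concatenated axial embedding."""
    base, rem = divmod(d_emb, n)
    # Give the remainder to the first `rem` parts so the widths sum to d_emb
    return tuple(base + (1 if i < rem else 0) for i in range(n))

print(split_emb_dims(64, 2))  # (32, 32)
print(split_emb_dims(10, 3))  # (4, 3, 3)
```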

class AbsolutePositionalEmbedding

AbsolutePositionalEmbedding(d_emb:int, max_seq_len:int) :: Module

Learnable absolute positional encodings
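A learnable absolute encoding can be sketched as an embedding table indexed by position. This assumes one trainable d_emb-vector per position up to max_seq_len; LearnedPosEmbSketch is a hypothetical name.

```python
import torch
from torch import nn

class LearnedPosEmbSketch(nn.Module):
    "Hypothetical: one learnable d_emb-vector per position 0..max_seq_len-1"
    def __init__(self, d_emb: int, max_seq_len: int):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, d_emb)

    def forward(self, x):
        # x: (batch, seq_len, ...); only its sequence length is used
        pos = torch.arange(x.shape[1], device=x.device)
        return self.emb(pos)  # (seq_len, d_emb), broadcasts over the batch

pe = LearnedPosEmbSketch(d_emb=64, max_seq_len=512)
print(sum(p.numel() for p in pe.parameters()))  # 32768
print(pe(torch.zeros(4, 128, 64)).shape)  # torch.Size([128, 64])
```

In this sketch the table has max_seq_len * d_emb weights, and positions at or beyond max_seq_len have no entry, so longer sequences raise an index error.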

class FixedPositionalEmbedding

FixedPositionalEmbedding(d_emb:int) :: Module

Fixed positional encodings
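The section does not say which fixed scheme is used. Assuming the sinusoidal encodings of the original Transformer paper, a sketch could look like this. It concatenates the sine and cosine halves rather than interleaving them (a common variant), requires an even d_emb, and uses the hypothetical name SinusoidalPosEmbSketch.

```python
import math
import torch
from torch import nn

class SinusoidalPosEmbSketch(nn.Module):
    "Hypothetical fixed encodings, assuming the sinusoidal Transformer scheme"
    def __init__(self, d_emb: int):
        super().__init__()
        # Inverse frequencies 1 / 10000^(2i/d_emb) for i = 0 .. d_emb/2 - 1
        inv_freq = torch.exp(torch.arange(0, d_emb, 2).float() * (-math.log(10000.0) / d_emb))
        self.register_buffer("inv_freq", inv_freq)  # stored as a buffer, not a parameter

    def forward(self, x):
        pos = torch.arange(x.shape[1], device=x.device).float()
        angles = pos[:, None] * self.inv_freq[None, :]          # (seq_len, d_emb/2)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (seq_len, d_emb)

pe = SinusoidalPosEmbSketch(64)
print(len(list(pe.parameters())))  # 0
enc = pe(torch.zeros(4, 128, 64))
print(enc.shape)  # torch.Size([128, 64])
```

Because the table is computed rather than learned, it adds no trainable parameters and is not tied to a maximum sequence length.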

class TransformerEmbedding

TransformerEmbedding(emb_sz:int, d_emb:int, max_seq_len:int=512, dropout:float=0.0, pos_enc:str='absolute', axial_shape:Tuple=None, axial_emb_dims:Tuple=None) :: Module

Combines token embeddings with positional encodings. pos_enc: one of {'absolute', 'fixed', 'axial'}.

from fastai.callback.hook import total_params  # returns (n_params, trainable)

# Same vocabulary size (256) and embedding width (64); only the positional encoding differs
fix_emb = TransformerEmbedding(256, 64, pos_enc='fixed')
abs_emb = TransformerEmbedding(256, 64, pos_enc='absolute')
axl_emb = TransformerEmbedding(256, 64, pos_enc='axial', axial_shape=(32,16))
print('Total number of parameters in embedding layer')
print(f'Fixed:    {total_params(fix_emb)[0]}')
print(f'Absolute: {total_params(abs_emb)[0]}')
print(f'Axial:    {total_params(axl_emb)[0]}')
Total number of parameters in embedding layer
Fixed:    16384
Absolute: 49152
Axial:    17920
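The printed counts can be reproduced by hand. The sketch below assumes the default max_seq_len=512 and, for the axial case, that d_emb is split evenly into two 32-wide tables that are concatenated; that split is inferred from the numbers, not stated in this section.

```python
vocab, d_emb, max_seq_len = 256, 64, 512
token = vocab * d_emb                   # token embedding table, present in all three
fixed = token                           # fixed encodings add no learnable weights
absolute = token + max_seq_len * d_emb  # one learned vector per position
# Axial: axial_shape=(32, 16), each axis assumed to get 32 of the 64 dims
axial = token + 32 * 32 + 16 * 32
print(fixed, absolute, axial)  # 16384 49152 17920
```

Under these assumptions, the axial encoding needs 1,536 positional weights instead of 32,768, because it stores 32 + 16 vectors rather than 512.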