# Contains basic layers used both in the baseline and the Reformer models.
# Sanity check: a Residual(PreNorm(FeedForward)) block must preserve the
# input tensor's (batch, seq_len, dim) shape.
bs = 4    # batch size
sl = 128  # sequence length
d = 64    # model / feature dimension
x = torch.randn(bs, sl, d)
ff = Residual(PreNorm(d, FeedForward(d)))
out = ff(x)
# Use .shape consistently (the original mixed .size() with a dead bare
# `out.shape` expression — a no-op outside a notebook cell) and give the
# assert a diagnostic message.
assert out.shape == (bs, sl, d), f'expected {(bs, sl, d)}, got {tuple(out.shape)}'
# Compare parameter counts for the three positional-encoding variants of
# TransformerEmbedding (vocab=256, dim=64). Axial encoding factorizes the
# position table into a (32, 16) grid.
embeddings = {
    'Fixed': TransformerEmbedding(256, 64, pos_enc='fixed'),
    'Absolute': TransformerEmbedding(256, 64, pos_enc='absolute'),
    'Axial': TransformerEmbedding(256, 64, pos_enc='axial', axial_shape=(32, 16)),
}
print('Total number of parameters in embedding layer')
for label, emb in embeddings.items():
    # total_params(...) returns a tuple; index 0 is the parameter count.
    print(f'{label}: {total_params(emb)[0]}')