Baseline transformer blocks and models
import torch

# Single encoder block: output shape matches input shape
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = TransformerEncoderBlock(d)
out = m(x)
assert out.size() == (bs, sl, d)
out.shape
# Full encoder stack: shape is preserved through all layers
x = torch.randn(bs, sl, d)
m = TransformerEncoder(d)
out = m(x)
assert out.size() == (bs, sl, d)
out.shape
# Decoder block stack: attends to its own input and to the encoder context
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
m = TransformerDecoder(d)
out = m(x, context)
assert out.size() == (bs, sl, d)
out.shape
# Language model: maps token ids to per-token vocabulary logits
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = TransformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert out.size() == (bs, sl, vocab_sz)
out.shape
# Encoder-decoder model: source and target sequences may have different lengths;
# output logits follow the target length and target vocabulary
bs = 4
src_sl = 70
tgt_sl = 80
d = 64
src_vocab_sz = 256
tgt_vocab_sz = 256
src = torch.randint(src_vocab_sz, (bs, src_sl))
tgt = torch.randint(tgt_vocab_sz, (bs, tgt_sl))
model = Transformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2, n_dec_layers=2)
out = model(src, tgt)
assert out.size() == (bs, tgt_sl, tgt_vocab_sz)
out.shape
In the memory-efficient Transformer, attention is computed over chunks of queries. Setting n_chunks = sl/c, for input sequence length sl and some constant c, keeps memory complexity at O(sl); however, the more chunks are used, the slower the computation becomes. In practice it is therefore advisable to choose n_chunks based on the available memory budget.
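To make the idea concrete, here is a minimal sketch of query-chunked attention, assuming a standalone chunked_attention helper (illustrative only, not the library's implementation): the queries are split into n_chunks blocks and each block attends to the full set of keys and values, so only a (sl/n_chunks, sl) score matrix is materialised at a time.

import math
import torch

def chunked_attention(q, k, v, n_chunks):
    # q, k, v: (bs, sl, d). Keys/values stay whole; only queries are chunked.
    outs = []
    for q_chunk in q.chunk(n_chunks, dim=1):
        # (bs, sl/n_chunks, sl) scores for this block of queries
        scores = q_chunk @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        outs.append(scores.softmax(dim=-1) @ v)
    return torch.cat(outs, dim=1)  # (bs, sl, d)

# Result matches full attention; only peak memory differs
bs, sl, d = 4, 128, 64
q, k, v = torch.randn(3, bs, sl, d)
full = (q @ k.transpose(-2, -1) / math.sqrt(d)).softmax(-1) @ v
assert torch.allclose(chunked_attention(q, k, v, n_chunks=8), full, atol=1e-5)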
# Chunked language model: same interface and output shape as TransformerLM
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = ChunkedTransformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert out.size() == (bs, sl, vocab_sz)
out.shape