Memory-efficient transformer
import torch
from torch import nn
import torch.nn.functional as F
# ChunkedFeedForward, Attention, ReversibleBlock, etc. are assumed to come
# from the accompanying library
bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
ff = ChunkedFeedForward(d, n_chunks=8, dim=1)
out = ff(x)
assert out.size() == (bs, sl, d)
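The chunking idea itself is simple: split the input along the chosen dimension, run the feed-forward on each chunk, and concatenate the results, trading a little compute for lower peak memory. A minimal sketch of the idea (ChunkedFFSketch is a hypothetical stand-in, not the library class):
class ChunkedFFSketch(nn.Module):
    def __init__(self, d, n_chunks=1, dim=-2):
        super().__init__()
        self.n_chunks, self.dim = n_chunks, dim
        # plain position-wise feed-forward
        self.net = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
    def forward(self, x):
        # the 4*d hidden activations now peak at chunk size, not full sequence length
        chunks = [self.net(c) for c in x.chunk(self.n_chunks, dim=self.dim)]
        return torch.cat(chunks, dim=self.dim)
out = ChunkedFFSketch(d, n_chunks=8)(x)
assert out.size() == (bs, sl, d)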
bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
# a reversible block runs on two residual streams, so the input is duplicated
# along the feature dimension
x2 = torch.cat([x, x], dim=-1)
attn = Attention(d)
ff = ChunkedFeedForward(d, n_chunks=8, dim=-2)
revblock = ReversibleBlock(attn, ff)
out = revblock(x2)
assert out.size() == (bs, sl, d*2)
# merge the twin streams back to width d by averaging
out = torch.stack(out.chunk(2, dim=-1)).mean(dim=0)
# no activations were stored for the backward pass, so calling backward on the
# output directly raises a RuntimeError
try: out.mean().backward()
except RuntimeError as e: print(e)
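The error above is the expected behaviour: the block frees its activations because the inputs can be recomputed from the outputs during the backward pass. A minimal sketch of the additive coupling this relies on, with torch.tanh standing in for the attention and feed-forward sub-layers:
a, b = torch.randn(bs, sl, d), torch.randn(bs, sl, d)
f = g = torch.tanh
y1 = a + f(b)       # forward coupling
y2 = b + g(y1)
b_rec = y2 - g(y1)  # invert: recover the inputs from the outputs alone
a_rec = y1 - f(b_rec)
assert torch.allclose(a_rec, a, atol=1e-6) and torch.allclose(b_rec, b, atol=1e-6)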
# IrreversibleBlock has the same interface but stores activations as usual
attn = Attention(d)
ff = ChunkedFeedForward(d, n_chunks=8, dim=-2)
irrevblock = IrreversibleBlock(attn, ff)
out = irrevblock(x2)
assert out.size() == (bs, sl, d*2)
bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
x2 = torch.cat([x, x], dim=-1)
blocks = []
for i in range(2):
    f = PreNorm(d, Attention(d))
    g = PreNorm(d, FeedForward(d))
    blocks.append(nn.ModuleList([f, g]))
layers = ReversibleSequence(nn.ModuleList(blocks))
out = layers(x2)
assert out.size() == (bs, sl, 2*d)
bs = 4
sl = 64
d = 128
x = torch.randn(bs, sl, d)
x2 = torch.cat([x, x], dim=-1)
blocks = []
for i in range(2):
    f = PreNorm(d, LSHSelfAttention(d, bucket_size=16))
    g = PreNorm(d, FeedForward(d))
    blocks.append(nn.ModuleList([f, g]))
layers = ReversibleSequence(nn.ModuleList(blocks))
# arg_route controls which of f (attention) and g (feed-forward) receive the
# extra kwargs; _reverse=False runs the stack without the reversible backward
out = layers(x2, arg_route=(True, False), _reverse=False, _depth=1)
assert out.size() == (bs, sl, 2*d)
try: out.mean().backward()
except RuntimeError as e: print(e)
x = torch.randn(bs, sl, d)
# ReversibleEncoder (and ReversibleDecoder below) duplicate and re-merge the
# two residual streams internally, so input and output widths are both d
m = ReversibleEncoder(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
x = torch.randn(bs, sl, d)
m = ReversibleDecoder(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = ReversibleLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
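A quick sanity check that gradients flow through the reversible stack; plain cross-entropy against the input tokens, purely for illustration (assumes the model is in its default training mode):
loss = F.cross_entropy(out.reshape(-1, vocab_sz), x.reshape(-1))
loss.backward()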
bs = 4
src_sl = 70
tgt_sl = 80
d = 64
src_vocab_sz = 256
tgt_vocab_sz = 256
src = torch.randint(src_vocab_sz, (bs, src_sl))
tgt = torch.randint(tgt_vocab_sz, (bs, tgt_sl))
model = ReversibleTransformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2, n_dec_layers=2)
out = model(src, tgt)
assert (out.size() == (bs, tgt_sl, tgt_vocab_sz))
out.shape
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = LSHEncoderBlock(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
# use_lsh=False swaps in standard full attention
m = LSHEncoderBlock(d, use_lsh=False)
out = m(x)
assert (out.size() == (bs, sl, d))
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = LSHEncoder(d, n_layers=2)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
m = LSHEncoder(d, n_layers=2, n_heads=4, use_lsh=False)
out = m(x)
assert (out.size() == (bs, sl, d))
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = LSHLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
model.use_lsh = True
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
%timeit model(x)
model.use_lsh = False
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
%timeit model(x)
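Outside a notebook the same comparison can be scripted with timeit (a sketch; at this short sequence length full attention may well be faster, since the LSH hashing overhead only pays off as sl grows):
from timeit import timeit
with torch.no_grad():
    model.use_lsh = True
    t_lsh = timeit(lambda: model(x), number=10)
    model.use_lsh = False
    t_full = timeit(lambda: model(x), number=10)
print(f'LSH: {t_lsh:.3f}s, full attention: {t_full:.3f}s')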
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = ReformerEncoder(d, n_layers=2)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = ReformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
Check cached buckets: LSHAttention execution time depends on the number of hashing rounds.
print(f'Number of hashing rounds: {model.n_hashes}')
%timeit model(x)
model.n_hashes = 1
print(f'Number of hashing rounds: {model.n_hashes}')
%timeit model(x)
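The trade-off can also be swept in one go (a sketch; more hashing rounds approximate full attention more closely at roughly linear extra cost in time):
from timeit import timeit
for n in (1, 2, 4, 8):
    model.n_hashes = n
    print(f'n_hashes={n}: {timeit(lambda: model(x), number=5):.3f}s')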