bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AttnInProj(d)
q1, k1, v1 = proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
q1.shape, k1.shape, v1.shape
q2, k2, v2 = proj(x, context)
assert (bs, sl, d) == q2.size()
assert k2.size() == v2.size() == context.size()
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
shared_proj = SharedQKAttnInProj(d)
q1, k1, v1 = shared_proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
assert q1 is k1
q1.shape, k1.shape, v1.shape
Scaled dot-product attention is calculated as:
$$\textbf{Attention}(Q,K,V) = \textbf{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
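For reference, a minimal single-head sketch of this formula (illustrative only; the ScaledDotProdAttention module tested below additionally splits the inputs into heads and supports shared queries and keys):

import torch
import torch.nn.functional as F

def sdp_attention(q, k, v):
    # q, k, v: (bs, sl, d_k); plain single-head scaled dot-product attention
    scores = q @ k.transpose(-2, -1) / (k.size(-1) ** 0.5)  # (bs, sl, sl)
    weights = F.softmax(scores, dim=-1)                     # attention weights over keys
    return weights @ v                                      # (bs, sl, d_k)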
q = torch.randn(bs, sl, d)
k = torch.randn(bs, sl, d)
v = torch.randn(bs, sl, d)
attn_func = ScaledDotProdAttention(d, 4)
out = attn_func(q, k, v)
assert out.size() == (bs,sl,d)
out.shape
attn_func = ScaledDotProdAttention(d, 4, shared_qk=True)
out = attn_func(q, k, v)
assert out.size() == (bs,sl,d)
out.shape
q = torch.randn(bs, sl, d).cuda()
k = torch.randn(bs, sl, d).cuda()
v = torch.randn(bs, sl, d).cuda()
attn_func = ScaledDotProdAttention(d, 4, shared_qk=True)
out = attn_func(q, k, v)
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = Attention(d)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
out = attn(x, context)
assert (bs, sl, d) == out.size()
out.shape
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = Attention(d, shared_qk=True)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
e_msg = "Causal masking error"
attn = Attention(d, causal=True, dropout=0)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2)
# all elements in the first half are equal even though the second half is different
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
assert not (out1[:, sl//2:] == out2[:, sl//2:]).any(), e_msg
e_msg = "Masking error"
attn = Attention(d, causal=False, dropout=0)
x1 = torch.randn(bs, sl, d)
mask = torch.ones(bs, sl)
# mask out second half of input
mask[:, sl//2:] = 0
mask = mask.bool()
out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2, mask=mask)
# all elements are equal, masked values do not affect the result
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
out1 = attn(x1)
out2 = attn(x2)
assert not (out1[:, :sl//2] == out2[:, :sl//2]).any()
e_msg = "Context masking error"
attn = Attention(d, causal=False, dropout=0)
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
context_mask = torch.ones(bs, sl)
# mask out second half of context
context_mask[:, sl//2:] = 0
context_mask = context_mask.bool()
out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
context2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x, context2, context_mask=context_mask)
# all elements are equal, masked values do not affect the result
assert all_equal(out1, out2), e_msg
# all output values are different for different context
out1 = attn(x, context)
out2 = attn(x, context2)
assert not (out1 == out2).any()
torch.manual_seed(842)
bs = 4
sl = 16
csl = sl + 16
d = 64
x = torch.rand(bs, sl, d)
context = torch.rand(bs, csl, d)
mask = torch.ones(bs, sl)
mask[:, -5:] = 0
context_mask = torch.ones(bs, csl)
context_mask[:, -10:] = 0
mask, context_mask = mask.bool(), context_mask.bool()
attn = Attention(d, store_attention=True)
out = attn(x, context, mask=mask, context_mask=context_mask)
attention = attn.attn.attention
assert (bs, sl, d) == out.size()
assert attention.size() == (bs, attn.attn.n_heads, sl, csl)
# zeros for masked keys and "don't cares" for masked queries
plt.matshow(attention[0,0]);
Customized _checkpoint and _ChunkedAttnCptFunction to handle non-tensor args. See https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html for the source implementation.
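A simplified sketch of the idea, not the actual _checkpoint / _ChunkedAttnCptFunction code: tensor arguments are saved for backward, non-tensor arguments are stashed on ctx as-is, and the wrapped function is re-run under enable_grad during the backward pass:

import torch

class _SimpleCheckpoint(torch.autograd.Function):
    # hypothetical checkpointing function that tolerates non-tensor args

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        # remember which positions held tensors; save_for_backward only accepts tensors
        ctx.tensor_flags = [torch.is_tensor(a) for a in args]
        ctx.non_tensor_args = [a for a in args if not torch.is_tensor(a)]
        ctx.save_for_backward(*[a for a in args if torch.is_tensor(a)])
        with torch.no_grad():
            return run_function(*args)  # intermediate activations are not kept

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensors, others = iter(ctx.saved_tensors), iter(ctx.non_tensor_args)
        args = []
        for is_tensor in ctx.tensor_flags:
            if is_tensor:
                t = next(tensors)
                args.append(t.detach().requires_grad_(t.requires_grad))
            else:
                args.append(next(others))
        with torch.enable_grad():
            out = ctx.run_function(*args)  # recompute the forward pass
        torch.autograd.backward(out, grad_outputs)
        grads = tuple(a.grad if torch.is_tensor(a) else None for a in args)
        return (None,) + grads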
q = torch.randn(bs, sl, d)
k = torch.randn(bs, sl, d)
v = torch.randn(bs, sl, d)
attn_func = ChunkedDotProdAttention(d, 4, n_chunks=8, causal=True, shared_qk=True)
out = attn_func(q, k, v)
assert out.size() == (bs,sl,d)
out.shape
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
cattn = ChunkedAttention(d, n_chunks=10)
out = cattn(x)
assert (bs, sl, d) == out.size()
out.shape
def time_fwd_bwd(f, x):
    loss = f(x).sum()
    loss.backward()
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AdditiveInProj(d)
q1, k1, v1 = proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
q1.shape, k1.shape, v1.shape
q2, k2, v2 = proj(x, context)
assert (bs, sl, d) == q2.size()
assert k2.size() == v2.size() == (bs, x.size(1)+context.size(1), d)
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = AdditiveAttention(d)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
out = attn(x, context)
assert (bs, sl, d) == out.size()
out.shape
e_msg = "Causal masking error"
attn = AdditiveAttention(d, causal=True, dropout=0)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2)
# all elements in the first half are equal even though the second half is different
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
assert not (out1[:, sl//2:] == out2[:, sl//2:]).any(), e_msg
e_msg = "Masking error"
attn = AdditiveAttention(d, causal=False, dropout=0)
x1 = torch.randn(bs, sl, d)
mask = torch.ones(bs, sl)
# mask out second half of input
mask[:, sl//2:] = 0
mask = mask.bool()
out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2, mask=mask)
# all elements are equal, masked values do not affect the result
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
out1 = attn(x1)
out2 = attn(x2)
assert not (out1[:, :sl//2] == out2[:, :sl//2]).any()
e_msg = "Context masking error"
attn = AdditiveAttention(d, causal=False, dropout=0)
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
context_mask = torch.ones(bs, sl)
# mask out second half of context
context_mask[:, sl//2:] = 0
context_mask = context_mask.bool()
out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
context2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x, context2, context_mask=context_mask)
# all elements are equal, masked values do not affect the result
assert all_equal(out1, out2), e_msg
# all output values are different for different context
out1 = attn(x, context)
out2 = attn(x, context2)
assert not (out1 == out2).any()
LSH attention from Reformer: The Efficient Transformer. Based on lucidrains/reformer-pytorch, but simplified and refactored. Uses shared keys and queries, but requires both to be passed as input (even though they are identical).
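At its core the bucketing is angular LSH via random rotations, as described in the Reformer paper: project each vector onto a handful of random directions and take the argmax over the concatenation [xR; -xR] as its bucket, so vectors separated by a small angle tend to share a bucket. A minimal sketch of just this hashing step (the function name is illustrative; the actual LSHAttention uses multiple hashing rounds plus sorting, chunking and attention within buckets):

import torch

def lsh_buckets(x, n_buckets, seed=None):
    # x: (bs, sl, d) -> bucket ids of shape (bs, sl), angular LSH via random rotations
    if seed is not None: torch.manual_seed(seed)
    rotations = torch.randn(x.size(-1), n_buckets // 2, device=x.device)
    rotated = x @ rotations                                    # (bs, sl, n_buckets // 2)
    return torch.cat([rotated, -rotated], dim=-1).argmax(-1)   # nearest of the n_buckets directions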
Test the LSH-attention layer. Note: d_model is inferred from the input. Assumes shared key and query, but accepts both as input.
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
shared_proj = SharedQKAttnInProj(d)
q, k, v = shared_proj(x)
lsh_attn = LSHAttention()
out, _, _ = lsh_attn(q, k, v)
assert (bs, sl, d) == out.size()
out.shape
lsh_attn = LSHAttention(seed=123)
lsh_attn1 = LSHAttention(seed=123)
assert all_equal(lsh_attn(q, k, v), lsh_attn1(q,k,v))
Performs multihead LSHAttention
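The multihead part follows the usual split/merge pattern: project the input, fold the feature dimension into heads before the bucketed attention, and unfold it again afterwards. A schematic sketch of just that reshaping (helper names are illustrative, not the module's internals):

def split_heads(x, n_heads):
    # (bs, sl, d) -> (bs * n_heads, sl, d // n_heads)
    bs, sl, d = x.shape
    x = x.reshape(bs, sl, n_heads, d // n_heads).transpose(1, 2)
    return x.reshape(bs * n_heads, sl, d // n_heads)

def merge_heads(x, n_heads):
    # (bs * n_heads, sl, d_head) -> (bs, sl, n_heads * d_head)
    bs_h, sl, d_head = x.shape
    x = x.reshape(bs_h // n_heads, n_heads, sl, d_head).transpose(1, 2)
    return x.reshape(bs_h // n_heads, sl, n_heads * d_head)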
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
attn = LSHSelfAttention(d, seed=123, out_dropout=0.1, dropout=0.1, dropout_hash=0.1)
assert all_equal(attn(x), attn(x))
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
Note that, unlike in the tests for standard attention, we can't draw new random vectors for the changed part of the input, since that would change how the vectors cluster in the LSH algorithm. If we instead scale by a constant factor, the angle-based clustering is unaffected even though the values have changed.
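A quick check of this claim with the hypothetical lsh_buckets sketch from above: scaling the input by a positive constant leaves the bucket assignment unchanged.

x_demo = torch.randn(bs, sl, d)
# same rotations (same seed), scaled input -> identical buckets
assert torch.equal(lsh_buckets(x_demo, 16, seed=0), lsh_buckets(x_demo * 2, 16, seed=0))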
e_msg = "Causal masking error"
attn = LSHSelfAttention(d, causal=True, dropout=0, seed=123)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)
x2 = x1.clone()
x2[:, sl//2:, :] = x2[:, sl//2:, :]*2
out2 = attn(x2)
assert torch.allclose(out1[:, :sl//2], out2[:, :sl//2]), e_msg
assert not (out1[:, sl//2:] == out2[:, sl//2:]).any(), e_msg
e_msg = "Masking error"
attn = LSHSelfAttention(d, causal=False, dropout=0, seed=123)
x1 = torch.randn(bs, sl, d)
mask = torch.ones(bs, sl)
# mask out second half of input
mask[:, sl//2:] = 0
mask = mask.bool()
out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, sl//2:, :] = x2[:, sl//2:, :]*2
out2 = attn(x2, mask=mask)
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
out1 = attn(x1)
out2 = attn(x2)
assert not (out1[:, :sl//2] == out2[:, :sl//2]).any()
e_msg = "Context masking error"
attn = LSHSelfAttention(d, causal=False, dropout=0, seed=123)
x = torch.randn(bs, sl, d)
Passing in context=x should not alter the result compared to passing no context:
out0 = attn(x,)
out1 = attn(x, context=x)
assert all_equal(out0, out1)
Mask second half of context
context = x.clone() # cloning x for context
context_mask = torch.ones(bs, sl).bool()
context_mask[:, sl//2:] = False
out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
context2[:, -1:, :] = context2[:, -1:, :]*2 # scaling to not affect clustering, relevant here?
#context2[:, sl//2:, :] = torch.randn(bs, sl//2, d) # new random data
out2 = attn(x, context2, context_mask=context_mask)
out1[0], out2[0]
(out1==out2).sum()
#assert all_equal(out1, out2), e_msg
out1 = attn(x, context)
out2 = attn(x, context2)
#assert not (out1 == out2).any()
Reformer attention calculates multihead attention with shared keys and queries, and allows switching between full Attention and LSHAttention at creation time, but not during inference or training.
bs = 4
sl = 128
d = 512
x = torch.randn(bs, sl, d)
attn_lsh = ReformerAttention(d, lsh_attention=True)
out = attn_lsh(x)
assert (bs, sl, d) == out.size()
out.shape
attn_full = ReformerAttention(d, lsh_attention=False)
out = attn_full(x)
assert (bs, sl, d) == out.size()
out.shape
The state dicts of the full and LSH attention variants are identical:
[(k, v.shape) for k, v in attn_lsh.state_dict().items()]
[(k, v.shape) for k, v in attn_full.state_dict().items()]
ReformerAttentionV2 contains both LSHAttention and ScaledDotProdAttention; which one is used is determined by the self.use_lsh flag.
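A minimal sketch of the dispatch this implies (illustrative, not the actual ReformerAttentionV2 code): both attention cores sit side by side behind one shared projection and forward() branches on the flag, so it can be flipped between calls without touching the parameters:

import torch.nn as nn

class FlagSwitchedAttention(nn.Module):
    # hypothetical illustration of switching on self.use_lsh at call time
    def __init__(self, d_model, n_heads=4, use_lsh=True):
        super().__init__()
        self.use_lsh = use_lsh
        self.in_proj = SharedQKAttnInProj(d_model)
        self.lsh_attn = LSHAttention()
        self.full_attn = ScaledDotProdAttention(d_model, n_heads, shared_qk=True)

    def forward(self, x):
        q, k, v = self.in_proj(x)
        if self.use_lsh:
            out, *_ = self.lsh_attn(q, k, v)
            return out
        return self.full_attn(q, k, v)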
Proposed TODOs:
- [x] rename self.lsh_attention to self.use_lsh to avoid confusion with self.lsh_attn, which is a module
- [x] synchronize mask naming across all Attention modules: input_mask -> attn_mask; minor renaming in LSH modules to make them consistent with Attention
- [x] add masking support to ReformerAttentionV2
- [ ] add masking tests
- [ ] synchronize store_attention functionality
- [x] test switchable attention module with synthetic task
bs = 4
sl = 128
d = 256
x = torch.randn(bs, sl, d)
attn = ReformerAttentionV2(d, use_lsh=True)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
attn.use_lsh = False
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
The state dict remains unchanged:
[(k, v.shape) for k, v in attn.state_dict().items()]