# AttnInProj: output shapes for self-attention and for attention over a context.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AttnInProj(d)

# Without a context all three outputs have the input's shape.
q1, k1, v1 = proj(x)
assert all(t.size() == (bs, sl, d) for t in (q1, k1, v1))
q1.shape, k1.shape, v1.shape

# With a context, q is unchanged while k and v take the context's shape.
q2, k2, v2 = proj(x, context)
assert q2.size() == (bs, sl, d)
assert k2.size() == context.size()
assert v2.size() == context.size()
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
# Same AttnInProj checks as the cell before, run on newly drawn inputs and a
# newly constructed projection.
bs, sl, d = 4, 128, 64
x, context = torch.randn(bs, sl, d), torch.randn(bs, sl-16, d)
proj = AttnInProj(d)

q1, k1, v1 = proj(x)
assert q1.shape == (bs, sl, d)
assert k1.shape == (bs, sl, d)
assert v1.shape == (bs, sl, d)
q1.shape, k1.shape, v1.shape

q2, k2, v2 = proj(x, context)
assert q2.shape == (bs, sl, d)
# keys and values are shaped like the context, not like x
assert k2.shape == v2.shape == context.shape
# the query must not depend on the context; the key must
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
# SharedQKAttnInProj: q and k must be the very same tensor object.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
shared_proj = SharedQKAttnInProj(d)
q1, k1, v1 = shared_proj(x)
assert all(t.size() == (bs, sl, d) for t in (q1, k1, v1))
# identity, not just equality: the projection hands back one tensor for both
assert q1 is k1
q1.shape, k1.shape, v1.shape
Scaled dot-product attention is calculated as:
$$\textbf{Attention}(Q,K,V) = \textbf{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
# ScaledDotProdAttention returns a (bs, sl, d) tensor, with and without shared_qk.
q, k, v = (torch.randn(bs, sl, d) for _ in range(3))

attn_func = ScaledDotProdAttention(d, 4)
out = attn_func(q, k, v)
assert out.shape == (bs, sl, d)

attn_func = ScaledDotProdAttention(d, 4, shared_qk=True)
out = attn_func(q, k, v)
assert out.shape == (bs, sl, d)
# GPU smoke test for ScaledDotProdAttention with shared_qk=True.
# Guarded so the script keeps running on machines without CUDA, where an
# unconditional `.cuda()` call raises.
if torch.cuda.is_available():
    q = torch.randn(bs, sl, d).cuda()
    k = torch.randn(bs, sl, d).cuda()
    v = torch.randn(bs, sl, d).cuda()
    attn_func = ScaledDotProdAttention(d, 4, shared_qk=True)
    out = attn_func(q, k, v)
    # same shape check as the CPU cell
    assert out.size() == (bs, sl, d)
# Attention: output has the shape of x for self-attention and with a shorter context.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = Attention(d)

out = attn(x)
assert out.shape == (bs, sl, d)

out = attn(x, context)
assert out.shape == (bs, sl, d)
# Attention with shared_qk=True: output keeps the input shape.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = Attention(d, shared_qk=True)
out = attn(x)
assert out.shape == (bs, sl, d)
e_msg = "Causal masking error"
attn = Attention(d, causal=True, dropout=0)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)

# Second input: identical first half, re-drawn second half.
half = sl // 2
x2 = x1.clone()
x2[:, half:] = torch.randn(bs, sl//2, d)
out2 = attn(x2)

# First-half outputs must match even though the second half differs ...
assert all_equal(out1[:, :half], out2[:, :half]), e_msg
# ... and every second-half output element must differ.
assert (out1[:, half:] != out2[:, half:]).all(), e_msg
e_msg = "Masking error"
attn = Attention(d, causal=False, dropout=0)
x1 = torch.randn(bs, sl, d)

# Boolean mask: True for the first half of the sequence, False for the second.
half = sl // 2
mask = torch.ones(bs, sl, dtype=torch.bool)
mask[:, half:] = False

out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, half:] = torch.randn(bs, sl//2, d)
out2 = attn(x2, mask=mask)
# Masked positions must not affect the unmasked outputs.
assert all_equal(out1[:, :half], out2[:, :half]), e_msg

# Without the mask, every first-half output element changes.
out1 = attn(x1)
out2 = attn(x2)
assert (out1[:, :half] != out2[:, :half]).all()
e_msg = "Context masking error"
attn = Attention(d, causal=False, dropout=0)
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)

# Boolean context mask: True for the first half, False for the second.
half = sl // 2
context_mask = torch.ones(bs, sl, dtype=torch.bool)
context_mask[:, half:] = False

out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
context2[:, half:] = torch.randn(bs, sl//2, d)
out2 = attn(x, context2, context_mask=context_mask)
# Masked context positions must not affect the result.
assert all_equal(out1, out2), e_msg

# Without the mask, a different context changes every output element.
out1 = attn(x, context)
out2 = attn(x, context2)
assert (out1 != out2).all()
# store_attention=True: the attention weights are kept on attn.attn.attention.
bs, sl, d = 4, 16, 64
csl = sl + 16
x = torch.rand(bs, sl, d)
context = torch.rand(bs, csl, d)

# Hide the last 5 query positions and the last 10 context positions.
mask = torch.ones(bs, sl, dtype=torch.bool)
mask[:, -5:] = False
context_mask = torch.ones(bs, csl, dtype=torch.bool)
context_mask[:, -10:] = False

attn = Attention(d, store_attention=True)
out = attn(x, context, mask=mask, context_mask=context_mask)
attention = attn.attn.attention
assert out.shape == (bs, sl, d)
assert attention.shape == (bs, attn.attn.n_heads, sl, csl)
# zeros for masked keys and "don't cares" for masked queries
Customized `_checkpoint`
and `_ChunkedAttnCptFunction`
to handle non-tensor args. See the PyTorch `torch.utils.checkpoint` module for the source implementation.
# ChunkedDotProdAttention (causal, shared q/k, 8 chunks) keeps the (bs, sl, d) shape.
q, k, v = (torch.randn(bs, sl, d) for _ in range(3))
attn_func = ChunkedDotProdAttention(d, 4, n_chunks=8, causal=True, shared_qk=True)
out = attn_func(q, k, v)
assert out.shape == (bs, sl, d)
# ChunkedAttention with n_chunks=10 keeps the input shape.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
cattn = ChunkedAttention(d, n_chunks=10)
out = cattn(x)
assert out.shape == (bs, sl, d)
def time_fwd_bwd(f, x):
    """Run one forward and one backward pass of `f` on `x`.

    The output of `f(x)` is summed into a scalar loss and back-propagated.
    Nothing is returned; the function exists for its side effect, presumably
    as the body of a timing measurement.

    NOTE(review): the backward call was added to match the function's name;
    the visible original stopped after computing the loss. Confirm.

    Args:
        f: callable taking `x` and returning a tensor.
        x: input passed to `f`.

    Raises:
        RuntimeError: if the loss does not require grad (neither `x` nor
            anything inside `f` is tracked by autograd).
    """
    loss = f(x).sum()
    loss.backward()
# AdditiveInProj: output shapes without and with a context.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AdditiveInProj(d)

q1, k1, v1 = proj(x)
assert all(t.size() == (bs, sl, d) for t in (q1, k1, v1))
q1.shape, k1.shape, v1.shape

q2, k2, v2 = proj(x, context)
assert q2.size() == (bs, sl, d)
# k and v span the input and the context together along the sequence axis
joint_shape = (bs, x.size(1)+context.size(1), d)
assert k2.size() == joint_shape
assert v2.size() == joint_shape
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
# AdditiveAttention: output has the shape of x, with or without a context.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = AdditiveAttention(d)

out = attn(x)
assert out.shape == (bs, sl, d)

out = attn(x, context)
assert out.shape == (bs, sl, d)
e_msg = "Causal masking error"
attn = AdditiveAttention(d, causal=True, dropout=0)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)

# Second input: identical first half, re-drawn second half.
half = sl // 2
x2 = x1.clone()
x2[:, half:] = torch.randn(bs, sl//2, d)
out2 = attn(x2)

# First-half outputs must match even though the second half differs ...
assert all_equal(out1[:, :half], out2[:, :half]), e_msg
# ... and every second-half output element must differ.
assert (out1[:, half:] != out2[:, half:]).all(), e_msg
e_msg = "Masking error"
attn = AdditiveAttention(d, causal=False, dropout=0)
x1 = torch.randn(bs, sl, d)

# Boolean mask: True for the first half of the sequence, False for the second.
half = sl // 2
mask = torch.ones(bs, sl, dtype=torch.bool)
mask[:, half:] = False

out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, half:] = torch.randn(bs, sl//2, d)
out2 = attn(x2, mask=mask)
# Masked positions must not affect the unmasked outputs.
assert all_equal(out1[:, :half], out2[:, :half]), e_msg

# Without the mask, every first-half output element changes.
out1 = attn(x1)
out2 = attn(x2)
assert (out1[:, :half] != out2[:, :half]).all()
# Context-masking check for AdditiveAttention.
# This cell previously built a plain `Attention`, repeating the earlier
# Attention test and leaving AdditiveAttention's context masking unexercised.
e_msg = "Context masking error"
attn = AdditiveAttention(d, causal=False, dropout=0)
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
context_mask = torch.ones(bs, sl)
# mask out second half of context
context_mask[:, sl//2:] = 0
context_mask = context_mask.bool()
out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
# re-draw only the masked half of the context
context2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x, context2, context_mask=context_mask)
# all elements are equal: masked context values do not affect the result
assert all_equal(out1, out2), e_msg
# all output values are different for different context
out1 = attn(x, context)
out2 = attn(x, context2)
assert not (out1 == out2).any()
LSH attention from "Reformer: The Efficient Transformer". Based on lucidrains/reformer-pytorch, but simplified and refactored. Uses shared keys and queries, but requires both to be passed as input (even though they are identical).
Test the LSH-attention layer. Note: `d_model`
is inferred from the input. Assumes shared keys and queries, but accepts both as input.
# LSHAttention: output shape, and identical results for identical seeds.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
shared_proj = SharedQKAttnInProj(d)
q, k, v = shared_proj(x)

# The layer returns three values; only the first is checked here.
lsh_attn = LSHAttention()
out, _, _ = lsh_attn(q, k, v)
assert out.shape == (bs, sl, d)

# Two layers built with the same seed must return equal results.
lsh_attn = LSHAttention(seed=123)
lsh_attn1 = LSHAttention(seed=123)
res, res1 = lsh_attn(q, k, v), lsh_attn1(q, k, v)
assert all_equal(res, res1)
Performs multihead LSHAttention
# LSHSelfAttention with a fixed seed: repeated calls agree, shape is preserved.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
attn = LSHSelfAttention(d, seed=123, out_dropout=0.1, dropout=0.1, dropout_hash=0.1)
first, second = attn(x), attn(x)
assert all_equal(first, second)
out = attn(x)
assert out.shape == (bs, sl, d)
Note that, unlike the tests for the standard transformer, we can't draw new random vectors for the changed input, since this would affect the clustering of the vectors in the LSH algorithm. If we instead scale by a constant factor, the angle-based clustering is not affected, even though the values have changed.
e_msg = "Causal masking error"
attn = LSHSelfAttention(d, causal=True, dropout=0, seed=123)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)

# Second input: same first half, second half scaled by 2 (not re-drawn).
half = sl // 2
x2 = x1.clone()
x2[:, half:] *= 2
out2 = attn(x2)

# First-half outputs stay close; every second-half output element differs.
assert torch.allclose(out1[:, :half], out2[:, :half]), e_msg
assert (out1[:, half:] != out2[:, half:]).all(), e_msg
e_msg = "Masking error"
attn = LSHSelfAttention(d, causal=False, dropout=0, seed=123)
x1 = torch.randn(bs, sl, d)

# Boolean mask: True for the first half of the sequence, False for the second.
half = sl // 2
mask = torch.ones(bs, sl, dtype=torch.bool)
mask[:, half:] = False

out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, half:] *= 2
out2 = attn(x2, mask=mask)
# Masked positions must not affect the unmasked outputs.
assert all_equal(out1[:, :half], out2[:, :half]), e_msg

# Without the mask, every first-half output element changes.
out1 = attn(x1)
out2 = attn(x2)
assert (out1[:, :half] != out2[:, :half]).all()
# Setup for the context checks that follow: non-causal, dropout-free,
# seeded LSH self-attention and a new random input.
e_msg = "Context masking error"
attn = LSHSelfAttention(d, seed=123, dropout=0, causal=False)
x = torch.randn(bs, sl, d)
Passing in context=x should not alter the result, as compared to no context:
# Passing x as its own context must give the same result as passing no context.
out0 = attn(x)
out1 = attn(x, context=x)
assert all_equal(out0, out1)
Mask second half of context
# Exploratory cell: the assertions below are still disabled.
context = x.clone()  # the context starts out as a copy of x
context_mask = torch.ones(bs, sl, dtype=torch.bool)
context_mask[:, sl//2:] = False
out1 = attn(x, context, context_mask=context_mask)

context2 = context.clone()
# Only the last context position is scaled; scaling is meant to leave the
# clustering untouched (open question: is that relevant here?).
context2[:, -1:, :] *= 2
#context2[:, sl//2:, :] = torch.randn(bs, sl//2, d) # new random data
out2 = attn(x, context2, context_mask=context_mask)
out1[0], out2[0]
#assert all_equal(out1, out2), e_msg

# Same comparison without the context mask.
out1 = attn(x, context)
out2 = attn(x, context2)
#assert not (out1 == out2).any()
Reformer attention calculates multi-head attention with shared keys and queries, and allows switching between full `Attention`
and `LSHAttention`
at creation, but not during inference or training.
# ReformerAttention keeps the input shape in both LSH and full-attention mode.
bs, sl, d = 4, 128, 512
x = torch.randn(bs, sl, d)

attn_lsh = ReformerAttention(d, lsh_attention=True)
out = attn_lsh(x)
assert out.shape == (bs, sl, d)

attn_full = ReformerAttention(d, lsh_attention=False)
out = attn_full(x)
assert out.shape == (bs, sl, d)
The state dicts of full and lsh attention are identical:
# Parameter names and shapes of the LSH variant, then of the full variant.
[(name, tensor.shape) for name, tensor in attn_lsh.state_dict().items()]
[(name, tensor.shape) for name, tensor in attn_full.state_dict().items()]
`ReformerAttentionV2` contains both `LSHAttention`
and `ScaledDotProdAttention`,
and which one to use is determined by `self.use_lsh`.
Proposed TODOs:
- [x] rename `self.lsh_attention`
to avoid confusion with `self.lsh_attn`,
which is a module
- [x] synchronize mask naming across all Attention modules: input_mask->attn_mask; minor renaming in LSH modules to make it consistent with the other attention modules
- [x] add masking support to `ReformerAttentionV2`
- [ ] add masking tests
- [ ] synchronize
functionality
- [x] test switchable attention module with synthetic task
# ReformerAttentionV2: use_lsh can be flipped on an existing module.
bs, sl, d = 4, 128, 256
x = torch.randn(bs, sl, d)
attn = ReformerAttentionV2(d, use_lsh=True)

out = attn(x)
assert out.shape == (bs, sl, d)

# switch the same module over to the non-LSH path
attn.use_lsh = False
out = attn(x)
assert out.shape == (bs, sl, d)
The state dict remains unchanged:
# Parameter names and shapes of the switchable module.
[(name, tensor.shape) for name, tensor in attn.state_dict().items()]