Attention Projection

class AttnInProj[source]

AttnInProj(d_model:int, bias:bool=False) :: Module

Computes q, k, v from input x and [optional] context
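
The projection is conceptually three linear maps: queries always come from x, while keys and values come from context when one is passed and from x otherwise. A minimal sketch of that idea (hypothetical naming, not the module's actual implementation):

import torch
import torch.nn as nn

class QKVProjSketch(nn.Module):
    "Hypothetical sketch (not the library code): q from x; k, v from context if given, else from x"
    def __init__(self, d_model:int, bias:bool=False):
        super().__init__()
        self.to_q = nn.Linear(d_model, d_model, bias=bias)
        self.to_k = nn.Linear(d_model, d_model, bias=bias)
        self.to_v = nn.Linear(d_model, d_model, bias=bias)
    def forward(self, x, context=None):
        context = x if context is None else context
        return self.to_q(x), self.to_k(context), self.to_v(context)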

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AttnInProj(d)
q1, k1, v1 = proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
q1.shape, k1.shape, v1.shape
(torch.Size([4, 128, 64]), torch.Size([4, 128, 64]), torch.Size([4, 128, 64]))
q2, k2, v2 = proj(x, context)
assert (bs, sl, d) == q2.size()
assert k2.size() == v2.size() == context.size()
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
(torch.Size([4, 128, 64]), torch.Size([4, 112, 64]), torch.Size([4, 112, 64]))

class AttnInProjV2[source]

AttnInProjV2(d_model:int, bias:bool=False) :: Module

Computes q, k, v from input x and [optional] context

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AttnInProjV2(d)
q1, k1, v1 = proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
q1.shape, k1.shape, v1.shape
(torch.Size([4, 128, 64]), torch.Size([4, 128, 64]), torch.Size([4, 128, 64]))
q2, k2, v2 = proj(x, context)
assert (bs, sl, d) == q2.size()
assert k2.size() == v2.size() == context.size()
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
(torch.Size([4, 128, 64]), torch.Size([4, 112, 64]), torch.Size([4, 112, 64]))

Shared Query-Key Attention Projection

class SharedQKAttnInProj[source]

SharedQKAttnInProj(d_model:int, bias:bool=False) :: Module

Computes q, k, v from input x and [optional] context
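
Here a single linear layer produces the tensor used for both queries and keys (the q1 is k1 assertion below relies on the very same tensor being returned twice), with a separate projection for values. A rough sketch, with names of my own choosing:

import torch
import torch.nn as nn

class SharedQKProjSketch(nn.Module):
    "Hypothetical sketch: one projection shared by q and k, a separate one for v"
    def __init__(self, d_model:int, bias:bool=False):
        super().__init__()
        self.to_qk = nn.Linear(d_model, d_model, bias=bias)
        self.to_v  = nn.Linear(d_model, d_model, bias=bias)
    def forward(self, x, context=None):
        context = x if context is None else context
        qk = self.to_qk(x)
        return qk, qk, self.to_v(context)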

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
shared_proj = SharedQKAttnInProj(d)
q1, k1, v1 = shared_proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
assert q1 is k1
q1.shape, k1.shape, v1.shape
(torch.Size([4, 128, 64]), torch.Size([4, 128, 64]), torch.Size([4, 128, 64]))

Scaled Dot Product Attention

class ScaledDotProdAttention[source]

ScaledDotProdAttention(d_model, n_heads, causal=False, dropout=0.0, shared_qk=False, store_attention:bool=False) :: Module

Computes scaled dot-product attention given q, k, v

Scaled dot-product attention is calculated as:

$$\textbf{Attention}(Q,K,V) = \textbf{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
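
In plain PyTorch the formula corresponds to something like the sketch below (single head, no masking or dropout); the actual module additionally splits the inputs into n_heads heads:

import math

def sdp_attention_sketch(q, k, v):
    "Hypothetical sketch of the formula above, without heads, masking or dropout"
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (bs, sl_q, sl_k)
    return scores.softmax(dim=-1) @ v                         # (bs, sl_q, d)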

q = torch.randn(bs, sl, d)
k = torch.randn(bs, sl, d)
v = torch.randn(bs, sl, d)
attn_func = ScaledDotProdAttention(d, 4)
out = attn_func(q, k, v)
assert out.size() == (bs,sl,d)
out.shape
torch.Size([4, 128, 64])
attn_func = ScaledDotProdAttention(d, 4, shared_qk=True)
out = attn_func(q, k, v)
assert out.size() == (bs,sl,d)
out.shape
torch.Size([4, 128, 64])
# this cell requires a CUDA device
q = torch.randn(bs, sl, d).cuda()
k = torch.randn(bs, sl, d).cuda()
v = torch.randn(bs, sl, d).cuda()
attn_func = ScaledDotProdAttention(d, 4, shared_qk=True)
out = attn_func(q, k, v)

Attention container

class Attention[source]

Attention(d_model:int, n_heads:int=8, causal:bool=False, mask:Tensor=None, dropout:float=0.1, out_dropout:float=None, bias:bool=False, shared_qk:bool=False, store_attention:bool=False) :: Module

Standard attention module using scaled dot-product attention
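
Conceptually the container chains the pieces defined above: project the input to q, k, v, run scaled dot-product attention, then apply an output projection. Roughly (a simplified sketch; the real forward also builds the attention mask from mask and context_mask):

import torch.nn as nn

class AttentionSketch(nn.Module):
    "Hypothetical sketch of the container: in-projection -> scaled dot-product attention -> out-projection"
    def __init__(self, d_model:int, n_heads:int=8, causal:bool=False, dropout:float=0.1, bias:bool=False):
        super().__init__()
        self.in_proj  = AttnInProj(d_model, bias=bias)
        self.attn     = ScaledDotProdAttention(d_model, n_heads, causal=causal, dropout=dropout)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
    def forward(self, x, context=None):
        q, k, v = self.in_proj(x, context)   # mask handling omitted in this sketch
        return self.out_proj(self.attn(q, k, v))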

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = Attention(d)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])
out = attn(x, context)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = Attention(d, shared_qk=True)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])
e_msg = "Causal masking error"
attn = Attention(d, causal=True, dropout=0)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2)
# all elements in the first half are equal even though the second half is different
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
assert not (out1[:, sl//2:] == out2[:, sl//2:]).any(), e_msg
e_msg = "Masking error"
attn = Attention(d, causal=False, dropout=0)
x1 = torch.randn(bs, sl, d)
mask = torch.ones(bs, sl)
# mask out second half of input
mask[:, sl//2:] = 0
mask = mask.bool()
out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2, mask=mask)
# all elements are equal: masked values do not affect the result
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
out1 = attn(x1)
out2 = attn(x2)
assert not (out1[:, :sl//2] == out2[:, :sl//2]).any()
e_msg = "Context masking error"
attn = Attention(d, causal=False, dropout=0)
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
context_mask = torch.ones(bs, sl)
# mask out second half of context
context_mask[:, sl//2:] = 0
context_mask = context_mask.bool()
out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
context2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x, context2, context_mask=context_mask)
# all elements are equal: masked values do not affect the result
assert all_equal(out1, out2), e_msg
# all output values are different for different context
out1 = attn(x, context)
out2 = attn(x, context2)
assert not (out1 == out2).any()
torch.manual_seed(842)
bs = 4
sl = 16
csl = sl + 16
d = 64
x = torch.rand(bs, sl, d)
context = torch.rand(bs, csl, d)
mask = torch.ones(bs, sl)
mask[:, -5:] = 0
context_mask = torch.ones(bs, csl)
context_mask[:, -10:] = 0
mask, context_mask = mask.bool(), context_mask.bool()
attn = Attention(d, store_attention=True)
out = attn(x, context, mask=mask, context_mask=context_mask)
attention = attn.attn.attention
assert (bs, sl, d) == out.size()
assert attention.size() == (bs, attn.attn.n_heads, sl, csl)
# zeros for masked keys and "don't cares" for masked queries
plt.matshow(attention[0,0]);

Memory efficient attention

Customized _checkpoint and _ChunkedAttnCptFunction to handle non-tensor args. See https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html for the original implementation.
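
torch.utils.checkpoint.checkpoint recomputes the wrapped function during the backward pass, trading compute for memory, and is easiest to use when the wrapped callable only takes tensors. One common workaround for extra non-tensor arguments is to close over them; the snippet below is just an illustration of that general pattern (with made-up names), not the customized _checkpoint itself:

import math, torch
import torch.utils.checkpoint as cp

def chunk_step(q_chunk, k, v, causal):
    "Toy attention step for one chunk of queries; `causal` is a non-tensor argument (masking omitted)"
    scores = q_chunk @ k.transpose(-2, -1) / math.sqrt(q_chunk.size(-1))
    return scores.softmax(dim=-1) @ v

q_chunk = torch.randn(2, 16, 32, requires_grad=True)
k = torch.randn(2, 64, 32, requires_grad=True)
v = torch.randn(2, 64, 32, requires_grad=True)
# close over the non-tensor argument so checkpoint only ever sees tensor inputs
out = cp.checkpoint(lambda qc, kk, vv: chunk_step(qc, kk, vv, causal=True), q_chunk, k, v)
out.sum().backward()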

class MemEfficientAttention[source]

MemEfficientAttention(d_model, n_heads, causal=False, dropout=0.0, shared_qk=False, store_attention:bool=False) :: Module

Memory efficient and very time inefficient attention for long sequences

O(L) memory complexity, but uses a Python loop to compute attention for one query at a time
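
Written out naively, that loop looks like the sketch below (single head, no masking, no causal handling); ChunkedDotProdAttention below applies the same idea to blocks of queries rather than single ones:

import math, torch

def per_query_attention_sketch(q, k, v):
    "Hypothetical sketch: attend one query at a time so the full (sl x sl) score matrix is never materialised"
    outs = []
    for i in range(q.size(1)):
        scores = q[:, i:i+1] @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (bs, 1, sl)
        outs.append(scores.softmax(dim=-1) @ v)                             # (bs, 1, d)
    return torch.cat(outs, dim=1)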

class ChunkedDotProdAttention[source]

ChunkedDotProdAttention(d_model, n_heads, causal=False, dropout=0.0, shared_qk=False, n_chunks=1, store_attention:bool=False) :: Module

Memory efficient and time inefficient attention for long sequences

O(L) memory complexity if n_chunks == seq_len, but uses a Python loop to compute attention for one chunk of queries at a time

q = torch.randn(bs, sl, d)
k = torch.randn(bs, sl, d)
v = torch.randn(bs, sl, d)
attn_func = ChunkedDotProdAttention(d, 4, n_chunks=8, causal=True, shared_qk=True)
out = attn_func(q, k, v)
assert out.size() == (bs,sl,d)
out.shape
torch.Size([4, 128, 64])

class ChunkedAttention[source]

ChunkedAttention(d_model:int, n_heads:int=8, causal:bool=False, mask:Tensor=None, dropout:float=0.1, out_dropout:float=None, bias:bool=False, shared_qk:bool=False, n_chunks:int=1, store_attention:bool=False) :: Module

Standard attention module using scaled dot-product attention

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
cattn = ChunkedAttention(d, n_chunks=10)
out = cattn(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])
# helper to time a full forward + backward pass
def time_fwd_bwd(f, x):
    loss = f(x).sum()
    loss.backward()

Additive Attention

class AdditiveInProj[source]

AdditiveInProj(d_model:int, bias:bool=False) :: Module

Computes q, k, v from input x and [optional] context
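
Judging from the shape assertions below (k2 and v2 have length x.size(1) + context.size(1)), the additive projection concatenates x and the context along the sequence dimension before projecting keys and values, so the input attends to itself and to the context at once. A sketch of that behaviour, with hypothetical naming:

import torch
import torch.nn as nn

class AdditiveProjSketch(nn.Module):
    "Hypothetical sketch: k and v are projected from x and context concatenated over the sequence dimension"
    def __init__(self, d_model:int, bias:bool=False):
        super().__init__()
        self.to_q = nn.Linear(d_model, d_model, bias=bias)
        self.to_k = nn.Linear(d_model, d_model, bias=bias)
        self.to_v = nn.Linear(d_model, d_model, bias=bias)
    def forward(self, x, context=None):
        src = x if context is None else torch.cat([x, context], dim=1)
        return self.to_q(x), self.to_k(src), self.to_v(src)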

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AdditiveInProj(d)
q1, k1, v1 = proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
q1.shape, k1.shape, v1.shape
(torch.Size([4, 128, 64]), torch.Size([4, 128, 64]), torch.Size([4, 128, 64]))
q2, k2, v2 = proj(x, context)
assert (bs, sl, d) == q2.size()
assert k2.size() == v2.size() == (bs, x.size(1)+context.size(1), d)
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape
(torch.Size([4, 128, 64]), torch.Size([4, 240, 64]), torch.Size([4, 240, 64]))

class AdditiveAttention[source]

AdditiveAttention(d_model:int, n_heads:int=8, causal:bool=True, dropout:float=0.1, out_dropout:float=None, bias:bool=False, shared_qk:bool=False, store_attention:bool=False) :: Attention

Additive attention module

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = AdditiveAttention(d)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])
out = attn(x, context)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])
e_msg = "Causal masking error"
attn = AdditiveAttention(d, causal=True, dropout=0)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2)
# all elements in the first half are equal even though the second half is different
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
assert not (out1[:, sl//2:] == out2[:, sl//2:]).any(), e_msg
e_msg = "Masking error"
attn = AdditiveAttention(d, causal=False, dropout=0)
x1 = torch.randn(bs, sl, d)
mask = torch.ones(bs, sl)
# mask out second half of input
mask[:, sl//2:] = 0
mask = mask.bool()
out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x2, mask=mask)
# all elements are equal: masked values do not affect the result
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
out1 = attn(x1)
out2 = attn(x2)
assert not (out1[:, :sl//2] == out2[:, :sl//2]).any()
e_msg = "Context masking error"
attn = AdditiveAttention(d, causal=False, dropout=0)
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
context_mask = torch.ones(bs, sl)
# mask out second half of context
context_mask[:, sl//2:] = 0
context_mask = context_mask.bool()
out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
context2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
out2 = attn(x, context2, context_mask=context_mask)
# all elements are equal: masked values do not affect the result
assert all_equal(out1, out2), e_msg
# all output values are different for different context
out1 = attn(x, context)
out2 = attn(x, context2)
assert not (out1 == out2).any()

LSH Attention

LSH attention from Reformer: The Efficient Transformer. Based on lucidrains/reformer-pytorch, but simplified and refactored. Uses shared keys and queries, but requires both to be passed as input (even though they are identical).
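
The bucketing itself is the angular LSH scheme from the Reformer paper: each vector is multiplied by a set of random rotations, and the index of the largest projection (among the projections and their negations) selects its bucket, so only vectors landing in the same bucket attend to each other. A toy sketch of that hashing step, with my own names and a single hash round:

import torch

def lsh_bucket_sketch(x, n_buckets=16):
    "Hypothetical sketch of one round of angular LSH bucketing"
    rotations = torch.randn(x.size(-1), n_buckets // 2)
    projected = x @ rotations                                         # (bs, sl, n_buckets//2)
    return torch.cat([projected, -projected], dim=-1).argmax(dim=-1)  # bucket id per position, (bs, sl)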

class LSHAttention[source]

LSHAttention(dropout=0.0, bucket_size=64, n_hashes=8, causal=False, allow_duplicate_attention=False, attend_across_buckets=False, drop_for_hash_rate=0.0, return_attn=False, seed=None, **kwargs) :: Module

LSH attention module:

Test LSH-attention layer. Note: d_model is inferred from the input. Assumes shared keys and queries, but accepts both as input.

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
shared_proj = SharedQKAttnInProj(d)
q, k, v = shared_proj(x)
lsh_attn = LSHAttention()
out, _, _ = lsh_attn(q, k, v)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])
lsh_attn = LSHAttention(seed=123)
lsh_attn1 = LSHAttention(seed=123)
assert all_equal(lsh_attn(q, k, v), lsh_attn1(q,k,v))

LSH-self-attention

Performs multihead LSHAttention

class LSHSelfAttention[source]

LSHSelfAttention(d_model, n_heads=8, bucket_size=64, n_hashes=8, causal=False, bias:bool=False, attend_across_buckets=False, allow_duplicate_attention=False, return_attn=False, seed=None, dropout=0.0, dropout_hash=0.0, out_dropout=0.0) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
attn = LSHSelfAttention(d, seed=123, out_dropout=0.1, dropout=0.1, dropout_hash=0.1)
assert all_equal(attn(x), attn(x))
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 64])

Testing causal masking

Note that, unlike the tests for the standard transformer, we can't draw new random vectors for the changed input, since that would alter the clustering of the vectors in the LSH algorithm. If we instead scale by a constant factor, the angle-based clustering is unaffected even though the values have changed.
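
A quick sanity check that scaling by a positive constant leaves the angular bucket assignment unchanged, which is why the test below scales the second half instead of redrawing it:

import torch

v = torch.randn(64)
rotations = torch.randn(64, 8)
p1, p2 = v @ rotations, (2 * v) @ rotations
# same argmax over projections and their negations => same LSH bucket
assert torch.cat([p1, -p1]).argmax() == torch.cat([p2, -p2]).argmax()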

e_msg = "Causal masking error"
attn = LSHSelfAttention(d, causal=True, dropout=0, seed=123)
x1 = torch.randn(bs, sl, d)
out1 = attn(x1)
x2 = x1.clone()
x2[:, sl//2:, :] = x2[:, sl//2:, :]*2
out2 = attn(x2)
assert torch.allclose(out1[:, :sl//2], out2[:, :sl//2]), e_msg
assert not (out1[:, sl//2:] == out2[:, sl//2:]).any(), e_msg

Testing masking

e_msg = "Masking error"
attn = LSHSelfAttention(d, causal=False, dropout=0, seed=123)
x1 = torch.randn(bs, sl, d)
mask = torch.ones(bs, sl)
# mask out second half of input
mask[:, sl//2:] = 0
mask = mask.bool()
out1 = attn(x1, mask=mask)
x2 = x1.clone()
x2[:, sl//2:, :] = x2[:, sl//2:, :]*2
out2 = attn(x2, mask=mask)
assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
out1 = attn(x1)
out2 = attn(x2)
assert not (out1[:, :sl//2] == out2[:, :sl//2]).any()

Testing context masking

e_msg = "Context masking error"
attn = LSHSelfAttention(d, causal=False, dropout=0, seed=123)
x = torch.randn(bs, sl, d)

Passing in context=x should not alter the result, as compared to no context:

out0 = attn(x,)
out1 = attn(x, context=x)
assert all_equal(out0, out1)

Mask second half of context

context = x.clone()                        # cloning x for context
context_mask = torch.ones(bs, sl).bool()
context_mask[:, sl//2:] = False

out1 = attn(x, context, context_mask=context_mask)
context2 = context.clone()
context2[:, -1:, :] = context2[:, -1:, :]*2   # scaling to not affect clustering, relevant here?
#context2[:, sl//2:, :] = torch.randn(bs, sl//2, d)  # new random data
out2 = attn(x, context2, context_mask=context_mask)
out1[0], out2[0]
(tensor([[ 0.0452, -0.0576, -0.0456,  ..., -0.1130, -0.0580, -0.1802],
         [ 0.1083, -0.0677,  0.0850,  ..., -0.0548, -0.0367, -0.1796],
         [-0.0294, -0.1177,  0.0340,  ..., -0.0477, -0.0325, -0.0656],
         ...,
         [-0.0314, -0.0388, -0.0134,  ..., -0.0803, -0.1505, -0.0360],
         [ 0.0812, -0.0492,  0.0248,  ..., -0.0232, -0.0728, -0.1378],
         [-0.0238, -0.1127, -0.0312,  ...,  0.0057, -0.0981, -0.0497]],
        grad_fn=<SelectBackward>),
 tensor([[ 0.0408, -0.0534, -0.0392,  ..., -0.1195, -0.0529, -0.1826],
         [ 0.1115, -0.0670,  0.0916,  ..., -0.0603, -0.0332, -0.1739],
         [-0.0294, -0.1200,  0.0364,  ..., -0.0530, -0.0293, -0.0682],
         ...,
         [-0.0325, -0.0378, -0.0093,  ..., -0.0855, -0.1461, -0.0382],
         [ 0.0833, -0.0506,  0.0322,  ..., -0.0283, -0.0702, -0.1345],
         [-0.0238, -0.1127, -0.0312,  ...,  0.0057, -0.0981, -0.0497]],
        grad_fn=<SelectBackward>))
(out1==out2).sum()
tensor(256)
# only a small number of positions match, so the context-masking assert is left disabled here
#assert all_equal(out1, out2), e_msg
out1 = attn(x, context)
out2 = attn(x, context2)
#assert not (out1 == out2).any()

Reformer Attention

Reformer attention calculates multihead attention with shared keys and queries, and allows choosing between full Attention and LSHAttention at creation time, but not switching during training or inference.

class ReformerAttention[source]

ReformerAttention(d_model:int, n_heads:int=8, causal:bool=False, mask:Tensor=None, dropout:float=0.1, out_dropout:float=None, bias:bool=False, store_attention:bool=False, lsh_attention:bool=True, n_hashes:int=8, bucket_size:int=64) :: Module

Reformer attention container.

Switch between FullSharedQKAttention and LSHAttention.

bs = 4
sl = 128
d = 512
x = torch.randn(bs, sl, d)
attn_lsh = ReformerAttention(d, lsh_attention=True)
out = attn_lsh(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 512])
attn_full = ReformerAttention(d, lsh_attention=False)
out = attn_full(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 512])

The state dicts of the full and LSH attention variants are identical:

[(k, v.shape) for k, v in attn_lsh.state_dict().items()]
[('attn.in_proj.to_qk.weight', torch.Size([512, 512])),
 ('attn.in_proj.to_v.weight', torch.Size([512, 512])),
 ('attn.out_proj.weight', torch.Size([512, 512]))]
[(k, v.shape) for k, v in attn_full.state_dict().items()]
[('attn.in_proj.to_qk.weight', torch.Size([512, 512])),
 ('attn.in_proj.to_v.weight', torch.Size([512, 512])),
 ('attn.out_proj.weight', torch.Size([512, 512]))]

class ReformerAttentionV2[source]

ReformerAttentionV2(d_model:int, n_heads:int=8, causal:bool=False, attn_mask:Tensor=None, dropout:float=0.1, out_dropout:float=None, bias:bool=False, store_attention:bool=False, use_lsh:bool=True, n_hashes:int=8, bucket_size:int=64, seed:int=None) :: Module

Reformer attention container; an attempt at making it switchable on the fly.

Switch between FullSharedQKAttention and LSHAttention.

ReformerAttentionV2 contains both LSHAttention and ScaledDotProdAttention; which one is used is determined by the self.use_lsh flag.
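
The switch itself can be pictured as a simple branch in forward. The sketch below only illustrates the idea, reusing the modules documented above; the real container also handles masks, heads and seeding:

import torch.nn as nn

class SwitchableAttnSketch(nn.Module):
    "Hypothetical sketch: both attention variants are built up front, a flag picks one at call time"
    def __init__(self, d_model:int, n_heads:int=8, use_lsh:bool=True):
        super().__init__()
        self.use_lsh   = use_lsh
        self.in_proj   = SharedQKAttnInProj(d_model)
        self.lsh_attn  = LSHAttention()
        self.full_attn = ScaledDotProdAttention(d_model, n_heads, shared_qk=True)
        self.out_proj  = nn.Linear(d_model, d_model, bias=False)
    def forward(self, x):
        q, k, v = self.in_proj(x)
        out = self.lsh_attn(q, k, v)[0] if self.use_lsh else self.full_attn(q, k, v)
        return self.out_proj(out)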

Proposed TODOs:

  • [x] rename self.lsh_attention to self.use_lsh to avoid confusion with self.lsh_attn which is a module
  • [x] synchronize mask naming across all Attention modules: input_mask->attn_mask; minor renaming in LSH modules to make it consistent with Attention
  • [x] add masking support to ReformerAttentionV2
  • [ ] add masking tests
  • [ ] synchronize store_attention functionality
  • [x] test switchable attention module with synthetic task
bs = 4
sl = 128
d = 256
x = torch.randn(bs, sl, d)
attn = ReformerAttentionV2(d, use_lsh=True)
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 256])
attn.use_lsh = False
out = attn(x)
assert (bs, sl, d) == out.size()
out.shape
torch.Size([4, 128, 256])

The state dict remains unchanged:

[(k, v.shape) for k, v in attn.state_dict().items()]
[('in_proj.to_qk.weight', torch.Size([256, 256])),
 ('in_proj.to_v.weight', torch.Size([256, 256])),
 ('out_proj.weight', torch.Size([256, 256]))]