from fastai.vision.all import *
import pdb


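LSH (locality-sensitive hashing) is an algorithm for clustering high-dimensional data: it hashes vectors so that similar vectors end up in the same bucket with high probability. There are several ways of implementing it; we'll look at random projections and random rotations.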

## LSH clustering

### Random projections

Yannick explains LSH with random projections, and so does this blog post. In the 2D case we can picture lines drawn at random through the origin, with points grouped according to which side of each line they fall on, i.e. whether or not they point in a similar direction. The method is probabilistic: points that are close can end up in different buckets by chance, but they have a high probability of being grouped together.

(Illustration from the blog.)

We'll demonstrate random projections for the 2-dimensional case. First we fix some points in the 2D plane. We'll do everything deterministically to begin with:

points = np.array([[0.9, 1],
                   [-0.9, -1],
                   [0.5, -.5]])
fig, ax = plt.subplots()
for (x, y), c in zip(points, ['r', 'b', 'g']):
    ax.scatter(x, y, c=c)   # red, blue and green point
ax.grid();

Next we make a vector u through the origin. We use u = [1, 1], which points up and to the right. It isn't normalized to unit length, but that doesn't matter here, since only the sign of the dot product will be used. We can draw the line spanned by u for x in [-1, 1]:

u = np.array([1, 1])
ax.plot([-1, 1], [-1, 1])
fig

The dot product of vectors a and b is the length of a projected onto b multiplied by the length of b. So if two vectors point in a similar direction their dot product is positive, and if they point in opposite directions it is negative. If the vectors are perpendicular, the dot product is 0. The dot product is computed by summing the pairwise products of the two vectors' components. Note that we can think of our points as vectors from the origin. Taking the dot product of the red point and u, we get, as expected, a positive dot product:

(points[0]*u).sum()   # red point only; summing over all points would give 0

1.9

We can calculate the dot product for all our points at once with a matrix multiplication. The red vector points along u, the blue one points the opposite way, and the green one is perpendicular:

points@u.T

array([ 1.9, -1.9,  0. ])

We can generalize this into a function by drawing u at random, and keeping only the sign of the dot product.

def rand_proj(points):
    u = np.random.randn(1, 2)  # 1 random projection in 2 dimensions
    dots = points @ u.T        # [n_points, 1]
    return np.sign(dots)       # bucket -1 or 1

rand_proj(points)

array([[-1.],
[ 1.],
[ 1.]])

In this case we have two buckets, -1 and 1.

We can also repeat the bucketing process n times to get a more stable estimate. Each run produces a different result; e.g. the first entry of each array below is the bucket of the red point in that run.

[rand_proj(points).squeeze() for _ in range(5)]

[array([ 1., -1.,  1.]),
array([ 1., -1.,  1.]),
array([-1.,  1., -1.]),
array([-1.,  1.,  1.]),
array([ 1., -1.,  1.])]
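To put a number on "high probability": for a single random line through the origin, two vectors at angle θ end up on the same side with probability 1 - θ/π, a standard property of random-hyperplane hashing. A quick Monte-Carlo sketch in NumPy (`same_bucket_rate` is a helper of our own, not from the blog or the paper):

```python
import numpy as np

def same_bucket_rate(a, b, n_trials=20000, seed=0):
    """Fraction of random lines through the origin that put a and b on the same side."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal((n_trials, a.shape[0]))   # one random direction per trial
    return np.mean(np.sign(u @ a) == np.sign(u @ b))

a = np.array([0.9, 1.0])                              # the red point
for deg in (10, 45, 90, 180):
    theta = np.deg2rad(deg)
    rot = np.array([[np.cos(theta), -np.sin(theta)],   # rotate a by theta so that
                    [np.sin(theta),  np.cos(theta)]])  # b sits at a known angle from a
    b = rot @ a
    # empirical same-bucket rate next to the theoretical 1 - theta/pi
    print(deg, round(float(same_bucket_rate(a, b)), 3), round(1 - theta / np.pi, 3))
```

The smaller the angle, the more often the two points share a bucket, which is why repeating the hashing gives a more stable picture.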

We can also generalize this to produce more than two buckets. With n random projections there are 2^n possible sign patterns, i.e. up to 2^n buckets:

def rand_proj(points, n_projections=2, n_dim=2):
    u = np.random.randn(n_projections, n_dim)  # one random direction per row
    dots = points @ u.T                        # [n_points, n_projections]
    return np.sign(dots)

rand_proj(points, n_projections=2)

array([[-1.,  1.],
[ 1., -1.],
[-1.,  1.]])

In this case we have 4 buckets: [-1,-1], [-1,1], [1,-1] and [1,1], which could be further combined into a single id in [0, 1, 2, 3].
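One way to combine a sign pattern into a single id is to read the signs as bits. A minimal NumPy sketch (`signs_to_bucket_id` is a hypothetical helper; a point lying exactly on a line gets sign 0 and is treated as a 0 bit here):

```python
import numpy as np

def signs_to_bucket_id(signs):
    """Map rows of +1/-1 signs, shape [n_points, n_projections], to ids in [0, 2**n_projections)."""
    bits = (signs > 0).astype(int)                  # -1 (or 0) -> 0, +1 -> 1
    weights = 2 ** np.arange(bits.shape[1])[::-1]   # most significant bit first
    return bits @ weights

signs = np.array([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]])
print(signs_to_bucket_id(signs))  # [0 1 2 3]
```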

### random rotations

The paper instead opts for an angular interpretation of LSH: in the 2D case each point is projected onto a unit sphere, and then rotated randomly. Its bucket depends on which sector it ends up in. The algorithm is described as:

To get b hashes, we first fix a random matrix R of size [dk, b/2]. We then define h(x) = arg max([xR; −xR]) where [u; v] denotes the concatenation of two vectors. This method is a known LSH scheme (Andoni et al., 2015) and is easy to implement and apply to batches of vectors.

The blog has an implementation that computes a single hash following these steps:

def rand_rotations(x, hidden_dim, n_buckets):
    random_rotations = np.random.randn(hidden_dim, n_buckets // 2)     # R: [d_k, b/2]
    rotated_vectors = np.dot(x, random_rotations)                      # xR
    rotated_vectors = np.hstack([rotated_vectors, -rotated_vectors])   # [xR; -xR]
    return np.argmax(rotated_vectors, axis=-1)                         # bucket id in [0, n_buckets)

rand_rotations(points, hidden_dim=2, n_buckets=4)

array([3, 1, 2])

This has the nice property of directly giving us the bucket id, instead of a vector of signs as above. The next step is to scale the algorithm to several hashing rounds. One could simply loop, but it's more efficient to add an extra dimension to our matrices. We also need to take care of the batch and attention head dimensions.
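For reference, the naive loop could look like this in NumPy: one fresh random rotation per round, with bucket ids offset by round * n_buckets so that rounds never share an id. This is a sketch of our own (`hash_rounds_loop` and its `seed` argument are hypothetical names):

```python
import numpy as np

def hash_rounds_loop(x, n_buckets, n_rounds, seed=0):
    """Naive multi-round angular LSH: x is [n_points, dim], returns [n_points, n_rounds]."""
    assert n_buckets % 2 == 0
    rng = np.random.default_rng(seed)
    out = []
    for r in range(n_rounds):
        R = rng.standard_normal((x.shape[-1], n_buckets // 2))   # fresh rotation per round
        rotated = x @ R
        ids = np.argmax(np.concatenate([rotated, -rotated], axis=-1), axis=-1)
        out.append(ids + r * n_buckets)                          # round r uses ids [r*nb, (r+1)*nb)
    return np.stack(out, axis=-1)

points = np.array([[0.9, 1], [-0.9, -1], [0.5, -.5]])
print(hash_rounds_loop(points, n_buckets=4, n_rounds=3))
```

Because [-xR; xR] is [xR; -xR] shifted by n_buckets/2, the red and blue points (which are exact opposites) always land in opposite buckets within every round.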

### Incorporation of batches and multiple hashing rounds

from einops import rearrange, repeat, reduce


The code for the LSH algorithm used in the paper can be found in the trax library. Lucidrains also has a stripped-down version. We'll base our algorithm on lucidrains', but simplify it even further to make the algorithm as clear as possible:

• we'll assume rehashing each round, as in the trax library
• no dropout on the vectors to be hashed
• we won't pay attention to the device for now
• we assume the correct dtypes are passed in
• the same number of rotations per head

That means we have to:

1. Keep track of the various dimensions
2. Perform the random-rotations part of the algorithm (as above)
3. Structure the output depending on the number of rounds and buckets

def hash_vectors(vecs, n_buckets=2, n_rounds=1):
    # 1. account for the input shapes. vecs = [bs, sl, hidden_dim]
    assert n_buckets % 2 == 0
    batch_size, _, hidden_dim = vecs.shape
    rotations_shape = (hidden_dim, n_rounds, n_buckets // 2)

    # 2. get the dot product, cat and argmax like in the section above
    random_rotations = repeat(torch.randn(rotations_shape),  # repeat rotations across the batch dimension
                              'h nr nb -> bs h nr nb', bs=batch_size)

    rotated_vecs = torch.einsum('bsh,bhrn->brsn',
                                vecs,               # [bs, sl, hidden_dim]
                                random_rotations)   # [bs, hidden_dim, n_rounds, n_buckets//2]
    # rotated_vecs: [bs, n_rounds, sl, n_buckets//2]

    rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)  # [bs, n_rounds, sl, n_buckets]
    buckets = torch.argmax(rotated_vecs, dim=-1)                     # [bs, n_rounds, sl]

    # 3. add offsets so that bucket numbers from different hashing rounds don't overlap
    offsets = torch.arange(n_rounds)                                 # [0, 1, ..., n_rounds-1]
    offsets = rearrange(offsets * n_buckets, 'r -> () r ()')         # [1, n_rounds, 1]
    buckets = rearrange(buckets + offsets, 'bs r sl -> bs (r sl)')   # [bs, n_rounds*sl]
    return buckets


Let's pass our trusty old points in, but first convert them to a tensor and insert a sequence dimension of length 1, so that each point becomes its own batch item:

t = torch.tensor(points, dtype=torch.float32)
t = rearrange(t, 'b d -> b () d')
hash_vectors(t, n_buckets=2), hash_vectors(t, n_buckets=2, n_rounds=5)

(tensor([,
,
]),
tensor([[1, 2, 4, 7, 8],
[0, 3, 5, 6, 9],
[1, 2, 4, 6, 9]]))

In the multi-round case, the results of the hashing rounds are concatenated along dimension 1, and an offset is added so that each round has its own range of ids, from 0 to n_rounds * n_buckets - 1, i.e. [0-9] in this case.
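Since round r occupies the id range [r * n_buckets, (r + 1) * n_buckets), integer division and modulo split an offset id back into its round and its bucket within that round. A small NumPy check on the first row of the 5-round output above (one point's ids across the 5 rounds):

```python
import numpy as np

n_buckets, n_rounds = 2, 5
ids = np.array([1, 2, 4, 7, 8])    # offset bucket ids of one point, one per round
round_id = ids // n_buckets        # which hashing round produced each id
local_bucket = ids % n_buckets     # bucket within that round
print(round_id, local_bucket)      # [0 1 2 3 4] [1 0 0 1 0]
```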

In the transformer setting the q and k matrix shapes will be: [bs, sl, hidden_dim]

t = torch.randn(64, 512, 128)
out = hash_vectors(t, n_buckets=4, n_rounds=1)
out.shape, out.min(), out.max()

(torch.Size([64, 512]), tensor(0), tensor(3))
out = hash_vectors(t, n_buckets=4, n_rounds=3)
out.shape, out.min(), out.max()

(torch.Size([64, 1536]), tensor(0), tensor(11))

## Main steps of LSH attention

The next parts we need to add are:

1. sort first by bucket id, then by position in the original sequence
2. split the sorted sequence into chunks of a given size
3. concatenate each chunk with the previous chunk to allow attention across chunk boundaries
4. calculate attention and output per (concatenated) chunk
5. unsort everything and return
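The five steps can be sketched end to end for a single sequence and a single hashing round. This is a simplified NumPy sketch of our own (`toy_lsh_attention` is a hypothetical name): it skips self-masking, key normalization, multiple rounds and the duplicate-count correction, which are dealt with later in this post.

```python
import numpy as np

def toy_lsh_attention(qk, v, n_buckets=4, chunk_size=4, seed=0):
    """Single-round LSH attention sketch. qk, v: [sl, d]; sl must be divisible by chunk_size."""
    sl, d = qk.shape
    rng = np.random.default_rng(seed)
    R = rng.standard_normal((d, n_buckets // 2))
    rot = qk @ R
    buckets = np.argmax(np.concatenate([rot, -rot], axis=-1), axis=-1)   # [sl]

    # 1. sort by bucket id first, then by position in the original sequence
    order = np.argsort(buckets * sl + np.arange(sl))
    sqk, sv = qk[order], v[order]

    # 2. split the sorted sequence into equal-size chunks
    n_chunks = sl // chunk_size
    cq = sqk.reshape(n_chunks, chunk_size, d)
    cv = sv.reshape(n_chunks, chunk_size, d)

    # 3. each chunk also sees the previous chunk (chunk 0 wraps around to the last one)
    ck = np.concatenate([cq, np.roll(cq, 1, axis=0)], axis=1)    # [n_chunks, 2*chunk_size, d]
    cv2 = np.concatenate([cv, np.roll(cv, 1, axis=0)], axis=1)

    # 4. scaled dot-product attention inside each concatenated chunk
    dots = cq @ ck.transpose(0, 2, 1) / np.sqrt(d)               # [n_chunks, chunk_size, 2*chunk_size]
    dots = dots - dots.max(axis=-1, keepdims=True)               # numerically stable softmax
    e = np.exp(dots)
    attn = e / e.sum(axis=-1, keepdims=True)
    out_sorted = (attn @ cv2).reshape(sl, d)

    # 5. unsort: the inverse permutation of `order` is its argsort
    return out_sorted[np.argsort(order)]

rng = np.random.default_rng(1)
qk, v = rng.standard_normal((16, 8)), rng.standard_normal((16, 8))
print(toy_lsh_attention(qk, v).shape)   # (16, 8)
```

A handy sanity check: with chunk_size equal to the sequence length, the single chunk attends to itself twice, which leaves the softmax weights per distinct key unchanged, so the sketch reduces to ordinary full attention.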

Let's set up some data to test with, using a batch size of 1 for simplicity:

k = torch.randn(1, 512, 128)
buckets = hash_vectors(k, n_buckets=4, n_rounds=1)
buckets[0,:10]

tensor([1, 1, 1, 2, 1, 3, 2, 1, 3, 0])

The number of vectors in each bucket:

torch.unique(buckets, return_counts=True)

(tensor([0, 1, 2, 3]), tensor([139, 122, 127, 124]))

We have 512 vectors of dim 128 to sort according to bucket:

k[0,:,:].shape

torch.Size([512, 128])

### Sorting tensors

tmp = torch.arange(5*3).reshape((5,3))
tmp

tensor([[ 0,  1,  2],
[ 3,  4,  5],
[ 6,  7,  8],
[ 9, 10, 11],
[12, 13, 14]])

We can sort along a specified axis, and get the index of the sorted item in the original tensor

v, i = tmp.sort(dim=1, descending=True)
v

tensor([[ 2,  1,  0],
[ 5,  4,  3],
[ 8,  7,  6],
[11, 10,  9],
[14, 13, 12]])
i

tensor([[2, 1, 0],
[2, 1, 0],
[2, 1, 0],
[2, 1, 0],
[2, 1, 0]])

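Since i simply reverses each row, applying it twice gives back the original order, i.e. the identity permutation. (In general, the inverse of a sort permutation is its argsort, i.argsort(dim=1).)

i.gather(1, i)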

tensor([[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]])
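In general the sort indices are not their own inverse; the inverse permutation is their argsort. A NumPy sketch with a random, non-reversing permutation, where np.take_along_axis plays the role of torch's gather:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.permutation(10).reshape(2, 5) * 10        # arbitrary distinct values to sort
i = np.argsort(x, axis=1)                          # like the indices returned by torch.sort
v = np.take_along_axis(x, i, axis=1)               # sorted values
undo = np.argsort(i, axis=1)                       # inverse permutation of each row
restored = np.take_along_axis(v, undo, axis=1)     # back to the original order
print(np.array_equal(restored, x))                 # True
```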

But how do we reorder by some other index? We can index the columns with a list of positions:

reorder = [2,0,1]
tmp[:,reorder]

tensor([[ 2,  0,  1],
[ 5,  3,  4],
[ 8,  6,  7],
[11,  9, 10],
[14, 12, 13]])

Let's get indexes of the sorted buckets tensor and use that to reorder k:

v, i = buckets.sort(dim=-1)
v[:,:5], v[:,-5:], i[:,:5]

(tensor([[0, 0, 0, 0, 0]]),
tensor([[3, 3, 3, 3, 3]]),
tensor([[351, 366, 364, 114, 363]]))

And finally we reorder k. Since dim 0 is the batch dimension, we reorder along dim 1, the sequence dimension. With a batch size of 1, we index with the single row of i:

k[:,i[0],:]

tensor([[[-0.5602, -0.2878, -1.1196,  ..., -1.2602,  1.1406, -0.4469],
[ 0.5765,  0.1284,  0.0911,  ..., -0.2538, -0.7697, -1.6130],
[ 0.2006, -0.7220, -1.3837,  ..., -1.2253,  0.5768,  0.6116],
...,
[-0.0983,  0.5041, -1.4759,  ..., -0.6982, -0.0050,  1.1906],
[-1.5745,  0.4858, -1.5506,  ..., -0.8055, -0.6155, -2.4515],
[-0.4779, -1.0148,  1.4694,  ..., -1.2689,  0.2453, -0.0928]]])

We can verify that the first item above is the same as the index i[:,0] of our k:

k[:,i[:,0],:10]

tensor([[[-0.5602, -0.2878, -1.1196, -0.1774, -0.3821, -0.5390, -0.8756,
-0.2887, -1.0551, -0.4596]]])

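We also have to maintain the original sequence order within each bucket: every k in bucket 0 has to be ordered according to its position in the original sequence. A plain sort on the bucket ids doesn't guarantee this, since the order of equal elements is not necessarily preserved.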
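Sorting on the composite key bucket * sl + position makes every key unique, so the result no longer depends on whether the sort is stable. A NumPy sketch comparing it with an explicitly stable sort on the bucket ids alone (toy random buckets):

```python
import numpy as np

rng = np.random.default_rng(0)
sl, n_buckets = 20, 4
buckets = rng.integers(0, n_buckets, size=sl)      # toy bucket id per position
pos = np.arange(sl)

by_key = np.argsort(buckets * sl + pos)            # composite key: bucket first, then position
by_stable = np.argsort(buckets, kind="stable")     # a stable sort keeps original order on ties
print(np.array_equal(by_key, by_stable))           # True
```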

### Sort chunks

Chunking into equal sizes is straightforward. Remove the batch dimension and chunk along the sequence dimension:

chunks = k.squeeze().chunk(4, dim=0)
k.shape, len(chunks), chunks[0].shape

(torch.Size([1, 512, 128]), 4, torch.Size([128, 128]))

### calculate attention

We can get the attention for a chunk (assuming k=v) the normal way, ignoring attention heads and the batch dimension for the moment. We must also make sure that a chunk can attend to the previous chunk. For chunk 1 this means concatenating it with chunk 0 before attention is calculated:

k0, k1, *_ = chunks
k0.shape, k1.shape

(torch.Size([128, 128]), torch.Size([128, 128]))
k_01 = torch.cat((k0,k1),0)
k_01.shape

torch.Size([256, 128])
attn = (k_01@k_01.T).softmax(-1)
attn.shape

torch.Size([256, 256])
attn[0,:].sum()

tensor(1.)

### caveats

There are some important steps described in the paper that we still need to account for. That means we still have to:

• set h(kj) = h(qj)
• vectors in each chunk have to be sorted according to order in the original sequence
• mask out attention so that a position doesn't attend to itself (except when no other targets exist)

## Implement LSH-attention layer

We have a rough idea of the implementation steps, so let's see if we can find a minimal solution based on the implementation by lucidrains. Once again we'll try to strip away as much as possible, to leave a minimal solution for clarity:

• no dropout
• don't return attn matrix
• don't detach tensors where we don't need gradients

### helpers

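# Chunk boundaries might occur in the middle of a sequence of items from the
# same bucket, so letting each chunk also look at the previous one increases
# the chances of attending to relevant items.
def look_one_back(x):
    # x: [bs, n_chunks, chunk_size, ...]; chunk i gets chunk i-1 appended (chunk 0 gets the last chunk)
    x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
    return torch.cat([x, x_extra], dim=2)

def sort_key_val(t1, t2, dim=-1):
    # sort t1 and apply the same permutation to t2
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

def batched_index_select(values, indices):
    # values: [bs, n, d], indices: [bs, m] -> [bs, m, d]
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))

def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

def chunked_sum(tensor, chunks=1):
    # sum over the last dim in `chunks` pieces to limit peak memory
    *orig_size, last_dim = tensor.shape
    tensor = tensor.reshape(-1, last_dim)
    summed_tensors = [c.sum(dim=-1) for c in tensor.chunk(chunks, dim=0)]
    return torch.cat(summed_tensors, dim=0).reshape(orig_size)

TOKEN_SELF_ATTN_VALUE = -5e4  # carefully set for half precision to work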
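As a quick sanity check of what look_one_back and batched_index_select do, here are NumPy mirrors of the two helpers (the `_np` names are our own; np.take_along_axis broadcasts the index over the last axis, where torch's gather needs the explicit expand):

```python
import numpy as np

def look_one_back_np(x):
    """Append chunk i-1 to chunk i along axis 2 (chunk 0 gets the last chunk)."""
    x_extra = np.concatenate([x[:, -1:], x[:, :-1]], axis=1)
    return np.concatenate([x, x_extra], axis=2)

def batched_index_select_np(values, indices):
    """values [bs, n, d], indices [bs, m] -> [bs, m, d]."""
    return np.take_along_axis(values, indices[:, :, None], axis=1)

chunks = np.repeat(np.arange(3), 2).reshape(1, 3, 2, 1)   # 3 chunks of size 2, chunk c filled with c
print(look_one_back_np(chunks)[0, :, :, 0])
# [[0 0 2 2]
#  [1 1 0 0]
#  [2 2 1 1]]
```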


### Forward, step by step

Let's step through the forward() of the LSHAttention layer (below) to see if we can make sense of it. We start with random test data. We assume q=k, and pass along v as well:

bs, sl, dim = 64, 512, 256
qk, v = torch.randn((bs, sl, dim)), torch.randn((bs, sl, dim))
qk.shape, v.shape

(torch.Size([64, 512, 256]), torch.Size([64, 512, 256]))

#### hashing

Grouping is done by hash_vectors(). We'll test 6 rounds with 16 buckets:

n_buckets, n_rounds = 16, 6
buckets = hash_vectors(qk, n_buckets, n_rounds)
buckets.shape

torch.Size([64, 3072])
buckets[0,:20]

tensor([ 5,  0,  5, 14,  2,  4,  0,  1,  5,  3, 14,  9, 10,  0,  3,  4,  0, 15,
12,  6])

Note that our bucket ids are unique to their hash group, i.e. the first and second hash groups have non-overlapping bucket ids:

buckets[0, 0:sl].unique(), buckets[0, sl:2*sl].unique()

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]),
tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]))

The ticker is a position index over the concatenated hash rounds:

ticker = torch.arange(n_rounds * sl).unsqueeze(0).expand_as(buckets)
ticker

tensor([[   0,    1,    2,  ..., 3069, 3070, 3071],
[   0,    1,    2,  ..., 3069, 3070, 3071],
[   0,    1,    2,  ..., 3069, 3070, 3071],
...,
[   0,    1,    2,  ..., 3069, 3070, 3071],
[   0,    1,    2,  ..., 3069, 3070, 3071],
[   0,    1,    2,  ..., 3069, 3070, 3071]])

Note that the id is not reset between hash groups:

ticker[0, sl-5:sl+5]

tensor([507, 508, 509, 510, 511, 512, 513, 514, 515, 516])

But by taking the mod of sl, we can get the sequence id in each hash round:

(ticker % sl)[0, sl-5:sl+5]

tensor([507, 508, 509, 510, 511,   0,   1,   2,   3,   4])

Or the id of the hash round, by integer division with sl:

(ticker//sl)[0, sl-5:sl+5]

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

#### Sorting

In the LSH layer we build a buckets_and_t index: the bucket id scaled by sl, plus ticker % sl. Let's inspect the first component:

buckets[0,:10]

tensor([ 5,  0,  5, 14,  2,  4,  0,  1,  5,  3])

We scale it by multiplying with sl:

sl*buckets[0,:10]

tensor([2560,    0, 2560, 7168, 1024, 2048,    0,  512, 2560, 1536])

Then we add ticker % sl, which is the sequence id (see above):

# we add the bucket id scaled by seqlen
# shape: [bs, seqlen*n_rounds]
# let us sort according to bucket id and index in sequence
buckets_and_t = sl * buckets + (ticker % sl)
buckets_and_t

tensor([[ 2560,     1,  2562,  ..., 41469, 48638, 45567],
[ 7680,  4097,  4098,  ..., 43517, 41982, 47103],
[ 1024,  4097,  7170,  ..., 45565, 48638, 41983],
...,
[ 4096,  6657,  3586,  ..., 43005, 47102, 44031],
[ 7168,  4609,   514,  ..., 44029, 44542, 44031],
[ 1536,  7681,  3586,  ..., 44541, 43518, 42495]])

buckets_and_t is an index that takes both the bucket id and the sequence id into account, of the form bucket_id*sl + sequence_id. The first item can take one of 16 values (n_buckets):

idxs = [idx*sl + 0 for idx in range(n_buckets)]
print(idxs)

[0, 512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680]


For the second position we add 1 to offset for the sequence id:

print([idx+1 for idx in idxs])

[1, 513, 1025, 1537, 2049, 2561, 3073, 3585, 4097, 4609, 5121, 5633, 6145, 6657, 7169, 7681]
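Since the sequence id is always smaller than sl, the key bucket * sl + position can be split back into its two parts with divmod. A toy NumPy sketch (the sizes and random buckets are made up for illustration):

```python
import numpy as np

sl, n_buckets, n_rounds = 8, 4, 2
rng = np.random.default_rng(0)
local = rng.integers(0, n_buckets, size=(n_rounds, sl))                    # bucket within each round
buckets = (local + np.arange(n_rounds)[:, None] * n_buckets).reshape(-1)   # offset ids, round-major
ticker = np.arange(n_rounds * sl)

buckets_and_t = sl * buckets + ticker % sl      # one key per (offset bucket, position)
b, pos = np.divmod(buckets_and_t, sl)           # decode the two parts again
print(np.array_equal(b, buckets), np.array_equal(pos, ticker % sl))   # True True
```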


And we can recreate the original ids from buckets_and_t:

(buckets_and_t[0,:20] - ticker[0,:20])/sl

tensor([ 5.,  0.,  5., 14.,  2.,  4.,  0.,  1.,  5.,  3., 14.,  9., 10.,  0.,
3.,  4.,  0., 15., 12.,  6.])
buckets[0,:20]

tensor([ 5,  0,  5, 14,  2,  4,  0,  1,  5,  3, 14,  9, 10,  0,  3,  4,  0, 15,
12,  6])

Buckets_and_t is a unique index:

len(buckets_and_t[0,:]), len(buckets_and_t[0,:].unique())

(3072, 3072)

It's also non-overlapping between hash groups: the max id of the first group is smaller than the min id of the second group:

buckets_and_t[0,:sl].max(), buckets_and_t[0,sl:sl*2].min()

(tensor(8189), tensor(8202))

Next we sort buckets_and_t. Since the bucket ids are offset per round and scaled by sl, the sort orders items by hash group first, then by bucket, then by sequence position.

sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)     # shapes are [bs, seqlen*n_hashes]
sbuckets_and_t[0,:20]

tensor([  1,   6,  13,  16,  20,  34,  42,  45,  51,  52,  95,  98, 100, 107,
112, 127, 158, 192, 215, 219])

sticker holds, for each sorted position, the position the item had in the unsorted buckets_and_t (i.e. its ticker value). In other words, buckets_and_t.gather(-1, sticker) gives sbuckets_and_t.

sticker[0,sl:sl+10]

tensor([522, 524, 531, 533, 540, 541, 567, 579, 596, 597])

We can use sticker to look up items in buckets_and_t to create sbuckets_and_t, but not the other way around!

buckets_and_t[0,sticker[0,-10:]], sbuckets_and_t[0,-10:]

(tensor([48964, 48969, 48987, 48999, 49011, 49028, 49034, 49068, 49089, 49107]),
tensor([48964, 48969, 48987, 48999, 49011, 49028, 49034, 49068, 49089, 49107]))

But we must also be able to undo this sorting. Since sticker is a permutation, sorting it gives its inverse, undo_sort: undo_sort[t] is the position in the sorted sequence that the item originally at position t moved to.

_, undo_sort = sticker.sort(dim=-1)                                       # indexes to undo sortings


We can use the gather function to recreate buckets_and_t from the sorted buckets. And we already know we can recreate the sequence id and hash group id from buckets_and_t:

sbuckets_and_t.gather(-1, undo_sort)

tensor([[ 2560,     1,  2562,  ..., 41469, 48638, 45567],
[ 7680,  4097,  4098,  ..., 43517, 41982, 47103],
[ 1024,  4097,  7170,  ..., 45565, 48638, 41983],
...,
[ 4096,  6657,  3586,  ..., 43005, 47102, 44031],
[ 7168,  4609,   514,  ..., 44029, 44542, 44031],
[ 1536,  7681,  3586,  ..., 44541, 43518, 42495]])
(sbuckets_and_t.gather(-1, undo_sort)[0,0:20] - ticker[0,:20])/sl

tensor([ 5.,  0.,  5., 14.,  2.,  4.,  0.,  1.,  5.,  3., 14.,  9., 10.,  0.,
3.,  4.,  0., 15., 12.,  6.])
buckets[0,:20]

tensor([ 5,  0,  5, 14,  2,  4,  0,  1,  5,  3, 14,  9, 10,  0,  3,  4,  0, 15,
12,  6])

To recap:

• sbuckets_and_t is a "double index" sorted by both hash group id and sequence id
• we can use undo_sort to recreate buckets_and_t, which is an index of hash_group and sequence id in the original order
• we can use buckets_and_t to get the original bucket ids in their respective hash groups, i.e. buckets
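The relationship between sticker and undo_sort is easiest to see on a toy 1-D example (made-up keys; NumPy fancy indexing plays the role of gather):

```python
import numpy as np

keys = np.array([50, 3, 41, 17, 8])        # toy buckets_and_t for one sequence
sticker = np.argsort(keys)                  # sorted position j came from original position sticker[j]
skeys = keys[sticker]                       # like sbuckets_and_t
undo_sort = np.argsort(sticker)             # original position t moved to sorted position undo_sort[t]
print(skeys, sticker, undo_sort)            # [ 3  8 17 41 50] [1 4 3 2 0] [4 0 3 2 1]
print(np.array_equal(skeys[undo_sort], keys))   # True
```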

#### Chunking

The next step is to extract the relevant vectors from our input. First we take sticker mod sl to get the sequence id within each hash group. This gives us the positions of qk and v to look up:

st = (sticker % sl)              # index of [0..seqlen-1] for each hash round (n_hashes)[bs, seqlen*n_hashes]
st

tensor([[  1,   6,  13,  ..., 428, 449, 467],
[ 12,  20,  27,  ..., 483, 491, 506],
[ 24,  26,  29,  ..., 409, 423, 468],
...,
[  6,  21,  31,  ..., 459, 468, 469],
[  9,  13,  80,  ..., 433, 457, 500],
[ 13,  25,  36,  ..., 448, 474, 482]])

We then lookup the vectors according to their (sorted) seqlen id:

sqk = batched_index_select(qk, st)   # get the sorted qk, [bs, seqlen*n_rounds, model_dim]
sv = batched_index_select(v, st)     # get the sorted v, [bs, seqlen*n_rounds, model_dim]
sqk.shape, sv.shape

(torch.Size([64, 3072, 256]), torch.Size([64, 3072, 256]))

sqk and sv now contain the input vectors sorted by hash group and sequence id. Next we calculate the number of chunks. We assume chunks of equal size, regardless of how many vectors actually fall into a particular bucket:

n_chunks = n_rounds * n_buckets
n_chunks

96

Each chunk has a size along the sl dimension (chunk_size) of:

sqk.shape[1]/n_chunks, sl/n_buckets

(32.0, 32.0)

Next we reshape the st index, sqk and sv to add the n_chunks dimension. We keep one index for the queries and another (duplicate) one for the keys and values:

# get the qk and v chunks and also the indexes to undo sort later
bq_t = bkv_t = torch.reshape(st, (bs, n_chunks, -1))   # [bs, n_chunks, chunk_size]
bq_t.shape

torch.Size([64, 96, 32])
bqk = torch.reshape(sqk, (bs, n_chunks, -1, dim))      # [bs, n_chunks, chunk_size, model_dim]
bv = torch.reshape(sv, (bs, n_chunks, -1, dim))        # [bs, n_chunks, chunk_size, model_dim]
bqk.shape, bv.shape

(torch.Size([64, 96, 32, 256]), torch.Size([64, 96, 32, 256]))

The next step is to split q and k. We normalize k but not q, as mentioned in the paper:

bq = bqk
bk = F.normalize(bqk, p=2, dim=-1).type_as(bq)
bq[..., :].mean(), bk[...,:].mean()

(tensor(0.0001), tensor(8.9945e-06))

Next we add the previous chunk as described in the paper, but only for k and v:

bk = look_one_back(bk)            # [bs, n_chunks, chunk_size*2, model_dim]
bv = look_one_back(bv)            # [bs, n_chunks, chunk_size*2, model_dim]
bkv_t = look_one_back(bkv_t)      # [bs, n_chunks, chunk_size*2]


Note that bq and bk now have different shapes:

bk.shape, bq.shape

(torch.Size([64, 96, 64, 256]), torch.Size([64, 96, 32, 256]))

#### dot product attention

dots = torch.einsum('bnsd,bnzd->bnsz',
bq,                  # [bs, n_chunks, chunk_size, model_dim]
bk                   # [bs, n_chunks, chunk_size*2, model_dim]
) * (dim ** -0.5)     # dots: [bs, n_chunks, chunk_size, chunk_size*2]
dots.shape

torch.Size([64, 96, 32, 64])

Next we mask out attention from each position to itself, as listed in the caveats above:

self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
self_mask.shape      # [bs, n_chunks, chunk_size, chunk_size*2]

torch.Size([64, 96, 32, 64])

We build the mask by comparing the sorted and chunked position indexes of the queries and the keys. Where they are equal, the query and the key are the same token, and we mask them.

bq_t[:, :, :, None][0,0,:10,0], bkv_t[:, :, None, :][0,0,0,:10]

(tensor([ 1,  6, 13, 16, 20, 34, 42, 45, 51, 52]),
tensor([ 1,  6, 13, 16, 20, 34, 42, 45, 51, 52]))
self_mask[0,0,:5,:5]

tensor([[ True, False, False, False, False],
[False,  True, False, False, False],
[False, False,  True, False, False],
[False, False, False,  True, False],
[False, False, False, False,  True]])
self_mask[0,0,...].shape, self_mask[0,0,...].sum()

(torch.Size([32, 64]), tensor(34))
dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE)
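Why 34 and not 32 masked entries in the first chunk? On a toy example with made-up positions (sl = 4, 2 rounds, chunk_size = 2), the mask catches the diagonal plus any token that also appears in the looked-back chunk, and chunk 0 looks back at the last chunk, which belongs to the final round. A NumPy sketch, with np.roll standing in for look_one_back:

```python
import numpy as np

# toy st: 4 chunks of size 2, entries are sequence positions
st = np.array([[1, 3], [0, 2],     # round 0
               [2, 3], [0, 1]])    # round 1
bq_t = st
bkv_t = np.concatenate([st, np.roll(st, 1, axis=0)], axis=1)   # own chunk + previous chunk
self_mask = bq_t[:, :, None] == bkv_t[:, None, :]             # [n_chunks, chunk_size, 2*chunk_size]
print(self_mask.sum(axis=(1, 2)))                             # [3 2 3 2]
```

Chunks 0 and 2 get one extra masked entry each, because a token in them also sits in the chunk they look back at.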


#### Avoid double counting of query-key pairs

According to the appendix of the paper we need to avoid double counting q/k pairs over multiple rounds of hashing. The lucidrains implementation mentions two strategies to deal with this; only the first one is implemented, and it seems to align with the paper:
"The default is to count how many times a query-key pair is repeated, and to lower its log-prob correspondingly at each repetition."

Note that this implementation does not fold this term into the mask (like the paper), but calculates it separately.

The goal is to consider each attention chunk (i.e. our input sorted by hash bucket id and sequence id, and split into n_chunks). For each chunk we have to assess: did a particular key/query pair also end up able to attend to each other in other hash rounds? If so, the model might "over-focus" on this particular q/k pair, and penalizing it probably makes sense.

First of all, we will be subtracting the counts from dots, so they should end up with the same shape. We also expect the max count to be 6 (n_rounds).

dots.shape # [bs, n_chunk, chunk_size, chunk_size*2]

torch.Size([64, 96, 32, 64])

Consider item [0,0] in the first attention chunk (e.g. in the 0th sample of the batch). It ended up there because of the sorting, i.e. it must have had bucket id 0 and a fairly low sequence id (let's call it x_id) to end up as the first item in this round. In the next hash rounds it will get a different bucket id, and thus a different location in the sorted order for that round. We have to associate item [0,0] from the first hash round with x_id in the input sequence, and track where x_id ended up in later hash rounds.

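Also note that the query and the key don't have to land in the identical attention chunk: since each chunk of queries also attends to the previous chunk of keys, a pair also counts in a round where the key lands in the chunk just before the query's. With 16 chunks per round, if a query lands in chunks 5, 20 and 33 over three rounds while the key lands in 5, 27 and 32, our count will be 2 (same chunk in the first round, key one chunk back in the third).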
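Put as a rule: for one query and one key, the count is the number of rounds in which the key lands either in the query's chunk or in the chunk just before it (wrapping around at n_chunks). A small NumPy sketch of that rule (`pair_count` is our own helper), applied to per-round chunk ids of one query and one key from our test data (6 rounds, 96 chunks):

```python
import numpy as np

def pair_count(q_locs, k_locs, n_chunks):
    """Rounds in which a query can see a key: same chunk, or key in the previous (look-one-back) chunk."""
    q_locs, k_locs = np.asarray(q_locs), np.asarray(k_locs)
    same = q_locs == k_locs
    key_one_back = q_locs == (k_locs + 1) % n_chunks
    return int(np.sum(same | key_one_back))

q = [0, 23, 38, 56, 67, 80]    # query's chunk id in each of the 6 rounds
k = [0, 17, 35, 48, 64, 80]    # key's chunk id in each of the 6 rounds
print(pair_count(q, k, n_chunks=96))   # 2
```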

First we need an index that tells us which attention chunk a particular input item ends up in for each hash round. undo_sort is the key here: undo_sort[0, t] is the position in the sorted sequence that the item at position t of the (round-concatenated) input moved to. E.g. undo_sort[0,1] is 0, since token 1 is the first token with bucket 0:

undo_sort[0,:10], undo_sort[0,sl:sl+10]

(tensor([174,   0, 175, 447,  69, 148,   1,  38, 176, 110]),
tensor([974, 775, 655, 710, 550, 776, 737, 777, 656, 657]))

Consider token 50 of the input. In each of the 6 hash rounds it lands at a different position in the sorted sequence, depending on which bucket it happened to fall into:

[undo_sort[0,s_id + 50].item() for s_id in range(0,sl*6, sl)]

[366, 940, 1368, 1631, 2483, 2635]

If we take the mod of sl, we get its position within each round's block of the sorted sequence:

[undo_sort[0,s_id + 50].item()%sl for s_id in range(0,sl*6, sl)]

[366, 428, 344, 95, 435, 75]

If we instead use integer division by chunk_size (32 in this case), we get the index of the chunk, out of the 96 (n_rounds*n_buckets), that the item ends up in. So token 50 ends up in the following attention chunks:

[undo_sort[0,s_id + 50].item()//32 for s_id in range(0,sl*6, sl)]

[11, 29, 42, 50, 77, 82]

We can do this for all items, and thus create a locs1 id that gives us an item's attention chunk id.

locs1 = undo_sort // bq_t.shape[-1]  # same as chunk size
locs1.shape, locs1.min(), locs1.max()

(torch.Size([64, 3072]), tensor(0), tensor(95))

E.g. tokens 0-4 of the input end up in various attention chunks depending on which hash round we are in:

[locs1[0,slen:slen+5] for slen in range(0,sl*n_rounds, sl)]

[tensor([ 5,  0,  5, 13,  2]),
tensor([30, 24, 20, 22, 17]),
tensor([44, 42, 32, 36, 40]),
tensor([49, 55, 50, 59, 53]),
tensor([72, 74, 77, 78, 68]),
tensor([85, 92, 88, 83, 94])]

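We divided the sorted sequence into equal-size chunks, so the bucket id and the attention chunk id won't always match, even though they often do. They match for the first five tokens of round 1 below, while token 3 has bucket 14 in round 0 but lands in chunk 13: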

buckets[0,sl:sl+5], locs1[0,sl:sl+5]

(tensor([30, 24, 20, 22, 17]), tensor([30, 24, 20, 22, 17]))

Next we create an id that is offset by one, wrapping around to 0 past our maximum id of 95 by taking the mod. We need this because each chunk of queries also attends to the previous chunk of keys, so a key in chunk c is also seen by the queries in chunk c+1.

95%n_chunks, 96%n_chunks, 97%n_chunks

(95, 0, 1)
locs2 = (locs1 + 1) % n_chunks
locs2.shape

torch.Size([64, 3072])

This means that the combination of locs1 and 2 gives us ids to neighbour chunks.

locs1[0,-15:], locs2[0,-15:]

(tensor([88, 89, 91, 89, 92, 89, 87, 82, 90, 87, 82, 94, 80, 94, 88]),
tensor([89, 90, 92, 90, 93, 90, 88, 83, 91, 88, 83, 95, 81, 95, 89]))

Next we reshape locs1 and locs2 to [bs, n_rounds, sl], concatenate them along the n_rounds dim, and swap the last two axes.

#     locs1 = buckets * chunk_size + locs1
#     locs2 = buckets * chunk_size + locs2

locs = torch.cat([
torch.reshape(locs1, (bs, n_rounds, sl)),
torch.reshape(locs2, (bs, n_rounds, sl)),
], 1).permute((0, 2, 1))
locs.shape           # [bs, sl, n_rounds*2]

torch.Size([64, 512, 12])

So the first token of the input sequence lands in a different chunk in each round, depending on the bucket it is hashed into. The second half of the list holds the neighbour chunks:

locs[0, 0,:]

tensor([ 5, 30, 44, 49, 72, 85,  6, 31, 45, 50, 73, 86])

This is just a reshaped version of locs1:

[locs1[0,slen] for slen in range(0,sl*n_rounds, sl)]

[tensor(5), tensor(30), tensor(44), tensor(49), tensor(72), tensor(85)]

But our queries and keys are in sorted order, so we need the chunk ids in that order too. We can achieve this by looking up locs according to st: st is our chunking key, and we used it to reorder qk and v into the sorted order. We can now reorder locs in the same way:

slocs = batched_index_select(locs, st)
slocs.shape    # [bs, sl*n_rounds, n_rounds*2]

torch.Size([64, 3072, 12])

slocs is now in sorted order. The first item in the sorted sequence (token st[0,0] = 1 of the input) ends up in the following chunks (+ neighbours):

slocs[0,0,:]

tensor([ 0, 24, 42, 55, 74, 92,  1, 25, 43, 56, 75, 93])

We reshape slocs to include the n_chunks dimension:

b_locs = torch.reshape(slocs, (bs, n_chunks, -1, 2 * n_rounds))
b_locs.shape          # [bs, n_chunks, chunk_size, n_round * 2]

torch.Size([64, 96, 32, 12])

We read this as: for the 0th sample in the batch, the 0th element of the 0th attention chunk, which attention chunks did that element end up in across the hash rounds (+ neighbours)?

b_locs[0,0,0,:]

tensor([ 0, 24, 42, 55, 74, 92,  1, 25, 43, 56, 75, 93])

For the first element in chunk 55 we get the following. Note that chunk 55 comes from hash round 3 (55 // n_buckets, counting from 0), which is why 55 appears at index 3:

b_locs[0,55,0,:]

tensor([ 3, 24, 32, 55, 77, 90,  4, 25, 33, 56, 78, 91])

Next we slice the first half of b_locs and add a unit axis at -2, for later comparison.

b_locs1 = b_locs[:, :, :, None, :n_rounds]
b_locs1.shape     # [bs, n_chunks, chunk_size, 1, n_rounds]

torch.Size([64, 96, 32, 1, 6])

We expand b_locs1 along that unit axis to size 2, so that the query locs get 2*n_rounds entries, matching the key locs (own chunk ids plus neighbour chunk ids). These will be our query indexes:

bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, n_rounds))
bq_locs.shape    # [bs, n_chunks, chunk_size, 2, n_rounds ]

torch.Size([64, 96, 32, 2, 6])
(bq_locs[:,:,:,0,:]==bq_locs[:,:,:,1,:]).unique()

tensor([True])

bq_locs is now "doubled" in size, so it can be compared to the chunk locations of the keys:

bq_locs[0,0,0,:,:]

tensor([[ 0, 24, 42, 55, 74, 92],
[ 0, 24, 42, 55, 74, 92]])

Next, we reshape to [bs, n_chunks, chunk_size, n_rounds*2]

bq_locs = torch.reshape(bq_locs, b_locs.shape)
bq_locs.shape # [bs, n_chunks, chunk_size, n_rounds*2]

torch.Size([64, 96, 32, 12])

This means that bq_locs is identical to b_locs except for the second half:

bq_locs[0,0,0,:], b_locs[0,0,0,:]

(tensor([ 0, 24, 42, 55, 74, 92,  0, 24, 42, 55, 74, 92]),
tensor([ 0, 24, 42, 55, 74, 92,  1, 25, 43, 56, 75, 93]))

To recap: bq_locs lets us inspect, for each attention chunk, which chunk each query ended up in during every hash round.

Next we need the chunk ids of the keys (the same as for v). We take b_locs and append the previous chunk, just like in the attention calculation.

bkv_locs = look_one_back(b_locs)
bkv_locs.shape      # [bs, n_chunks, chunk_size*2, n_rounds*2]

torch.Size([64, 96, 64, 12])

Note that we appended the chunk along the -2 dimension:

b_locs.shape

torch.Size([64, 96, 32, 12])

Our key locs are offset by one in the second half to reflect the neighbour chunk:

bkv_locs[0,0,0,:]

tensor([ 0, 24, 42, 55, 74, 92,  1, 25, 43, 56, 75, 93])

The final step is to count elements that have ended up in the same chunk, and thus can attend to each other. Let's try for the first chunk:

bq_locs.shape, bkv_locs.shape

(torch.Size([64, 96, 32, 12]), torch.Size([64, 96, 64, 12]))

In attention chunk 0, key 3 and query 24 have landed in the following chunks over the rounds, with 2 matches (rounds 0 and 5):

bkv_locs[0,0,3,:], bq_locs[0,0,24,:], (bkv_locs[0,0,3,:]==bq_locs[0,0,24,:]).sum()

(tensor([ 0, 17, 35, 48, 64, 80,  1, 18, 36, 49, 65, 81]),
tensor([ 0, 23, 38, 56, 67, 80,  0, 23, 38, 56, 67, 80]),
tensor(2))

By adding unit axes in the right places, the query and key dimensions broadcast against each other:

(bq_locs[0, 0, :, None, :].shape,
bkv_locs[0, 0, None, :, :].shape,
(bq_locs[0, 0, :, None, :] == bkv_locs[0, 0, None, :, :]).shape)

(torch.Size([32, 1, 12]), torch.Size([1, 64, 12]), torch.Size([32, 64, 12]))

We compare all elements:

dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :])
dup_counts.shape  # [bs, n_chunks, chunk_size, chunk_size*2, n_rounds]

torch.Size([64, 96, 32, 64, 12])
dup_counts[0,0,3,24,:]

tensor([ True, False, False, False, False,  True, False, False, False, False,
False, False])

Then we sum across the last (hash round) dimension:

tmp = dup_counts.sum(-1)
tmp.shape

torch.Size([64, 96, 32, 64])

To save memory, the summation is done in chunks:

dup_counts = chunked_sum(dup_counts, chunks=(n_rounds * bs))
dup_counts.shape # [bs, n_chunks, chunk_size, chunk_size * 2]

torch.Size([64, 96, 32, 64])

The result is the same as summing along dimension -1:

torch.all(tmp==dup_counts)

tensor(True)
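chunked_sum is also not defined in this section. Here is a minimal sketch, assuming it flattens all leading dimensions, splits the rows into `chunks` pieces and sums each piece over the last dimension before stitching the results back together. It is modelled on the reformer-pytorch helper of the same name:

```python
import torch

def chunked_sum(tensor: torch.Tensor, chunks: int = 1) -> torch.Tensor:
    """Sum over the last dimension, processing the flattened rows in `chunks` pieces."""
    *orig_size, last_dim = tensor.shape
    # Flatten everything except the last dimension into rows
    rows = tensor.reshape(-1, last_dim)
    # Reduce each block of rows separately, then concatenate and restore the shape
    summed = [c.sum(dim=-1) for c in rows.chunk(chunks, dim=0)]
    return torch.cat(summed, dim=0).reshape(orig_size)

# Toy boolean tensor shaped like dup_counts: [bs, n_chunks, chunk_size, chunk_size*2, n_rounds]
torch.manual_seed(0)
dup = torch.randint(0, 2, (2, 3, 4, 8, 6)).bool()
counts = chunked_sum(dup, chunks=6)
print(counts.shape)                      # torch.Size([2, 3, 4, 8])
print(torch.all(counts == dup.sum(-1)))  # tensor(True)
```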
dup_counts[0,0,:10,:10]

tensor([[6, 3, 2, 1, 1, 2, 1, 1, 1, 1],
[1, 6, 3, 3, 1, 1, 1, 2, 1, 2],
[2, 3, 6, 3, 1, 1, 1, 2, 3, 1],
[1, 2, 3, 6, 2, 1, 1, 1, 2, 1],
[1, 1, 1, 2, 6, 1, 1, 1, 2, 1],
[2, 2, 1, 1, 1, 6, 3, 2, 1, 3],
[1, 1, 1, 1, 1, 3, 6, 2, 1, 3],
[1, 2, 2, 1, 1, 2, 2, 6, 2, 4],
[1, 1, 3, 2, 1, 2, 1, 2, 6, 1],
[2, 2, 1, 2, 1, 1, 2, 3, 1, 6]])

Notice that the diagonal in the first block equals the number of hash rounds. This makes sense, since i,i pairs are identical and always land in the same chunk; we will mask them later. This is not the case for the other half of the keys (from the neighbouring chunk):

dup_counts[0,0,:10,32:42]

tensor([[1, 2, 0, 0, 0, 2, 0, 2, 0, 0],
[2, 0, 1, 0, 0, 1, 0, 1, 0, 0],
[1, 1, 0, 0, 0, 1, 1, 2, 0, 0],
[1, 0, 1, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 2, 0, 0, 0, 2, 2],
[0, 0, 0, 1, 0, 1, 0, 1, 1, 0],
[0, 0, 1, 1, 0, 0, 0, 0, 1, 0],
[1, 0, 1, 1, 0, 1, 1, 2, 0, 0],
[1, 0, 1, 0, 1, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])

dup_counts and dots should now have the same shape:

assert dup_counts.shape == dots.shape
dots.shape, dup_counts.shape

(torch.Size([64, 96, 32, 64]), torch.Size([64, 96, 32, 64]))

Finally, as described in the paper's appendix, we subtract the log of the duplicate counts from the attention scores dots, so that pairs that shared a chunk in several hash rounds are not over-counted:

dots = dots - torch.log(dup_counts + 1e-9)

del dup_counts


masked_value = max_neg_value(dots)
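Subtracting log(count) from a score is the same as dividing its unnormalized softmax weight exp(score) by the count, so a pair seen in several hash rounds is not over-weighted. The toy check below uses made-up scores and counts to confirm the equivalence. It also shows that a count of 0 becomes a boost of -log(1e-9), about 20.72. This is likely where the values around 2.07e+01 in the printed dots further down come from:

```python
import math
import torch

scores = torch.tensor([1.0, 0.5, -0.2])
counts = torch.tensor([3.0, 1.0, 2.0])

# Correction used above: subtract the log of the duplicate count before the softmax
corrected = torch.softmax(scores - torch.log(counts + 1e-9), dim=-1)

# Equivalent form: divide each exp(score) by its count, then renormalize
weights = torch.exp(scores) / counts
manual = weights / weights.sum()

print(torch.allclose(corrected, manual))  # True
# A count of 0 is not masked but boosted: -log(1e-9) is about 20.72
print(-math.log(1e-9))
```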


We can pass along pad masks if our input has variable length and is padded, since we don't want the model to attend to pad tokens. Because our original input has been sorted and chunked, we have to sort the masks the same way:

input_mask = torch.ones((bs, sl)).bool()

torch.Size([64, 512])

In case input_mask is shorter than the input, this pads it to the appropriate shape with the value True:

#if input_mask is not None:

torch.Size([64, 512])
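Only the commented-out condition of that cell is shown above. Here is a minimal sketch of the padding step, assuming a mask where True means "keep" and using torch.cat to append True columns up to the sequence length sl. The original cell may use a different call, and the sizes here are toy values:

```python
import torch

sl = 6
# Toy mask covering only the first 4 of sl positions (batch of 2)
input_mask = torch.tensor([[True, True, True, False],
                           [True, True, True, True]])

if input_mask is not None and input_mask.shape[1] < sl:
    # Append True (= keep) columns on the right up to length sl
    pad = torch.ones((input_mask.shape[0], sl - input_mask.shape[1]), dtype=torch.bool)
    input_mask = torch.cat([input_mask, pad], dim=1)

print(input_mask.shape)  # torch.Size([2, 6])
```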

We reorder the mask with st and reshape it to include the n_chunks dimension. This gives us the mask for q:

mq = input_mask.gather(1, st).reshape((bs, n_chunks, -1))
mq.shape

torch.Size([64, 96, 32])

The mask for k is similar, but includes the neighbour chunk as usual.

mkv = look_one_back(mq)
mkv.shape

torch.Size([64, 96, 64])

We add unit axes to compare across the final dimensions of q and k:

mask = mq[:, :, :, None] * mkv[:, :, None, :]

torch.Size([64, 96, 32, 64])
dots.masked_fill_(~mask, masked_value)
dots

tensor([[[[-5.0002e+04, -1.0404e+00, -6.2487e-01,  ...,  1.0480e-02,
-5.2455e-02,  2.0620e+01],
[ 5.5138e-02, -5.0002e+04, -1.2008e+00,  ...,  2.0702e+01,
2.0725e+01, -7.3569e-01],
[-6.2263e-01, -1.2101e+00, -5.0002e+04,  ...,  2.0632e+01,
2.0816e+01,  2.0912e+01],
...,
[-3.6686e-02, -6.8328e-01, -9.6327e-01,  ...,  2.0708e+01,
2.0822e+01,  2.0851e+01],
[-6.7548e-01,  4.3852e-02, -7.3175e-01,  ...,  2.0624e+01,
2.0718e+01,  2.0613e+01],
[-6.8309e-01,  4.6035e-02, -3.3323e-02,  ..., -9.9436e-01,
1.6633e-01,  2.0727e+01]],

[[-5.0002e+04,  7.7878e-02, -6.6826e-01,  ..., -5.6195e-01,
-2.5396e-02, -3.7712e-02],
[ 7.5854e-02, -5.0002e+04,  5.5370e-02,  ...,  7.3007e-02,
5.5364e-02, -1.1282e+00],
[ 2.6058e-02,  5.9516e-02, -5.0002e+04,  ...,  7.8917e-02,
-1.1041e-01, -2.5485e-02],
...,
[-7.1051e-01,  2.7179e-02, -7.4234e-01,  ...,  1.7444e-02,
-7.4681e-01,  8.4078e-02],
[ 3.5303e-02,  3.0089e-02,  1.8635e-02,  ...,  3.3148e-02,
-1.0739e+00,  4.6189e-02],
[-7.5039e-01, -6.1051e-01,  6.3676e-02,  ..., -6.9512e-02,
2.4852e-02,  3.3326e-02]],

[[-5.0002e+04,  1.1893e-01, -3.1758e-02,  ..., -6.8121e-02,
-6.4720e-01, -8.7898e-02],
[ 1.2217e-01, -5.0002e+04, -1.3385e-02,  ...,  9.2867e-02,
1.5681e-01, -6.3884e-02],
[-3.1093e-02, -1.2758e-02, -5.0002e+04,  ...,  1.1031e-01,
-7.0500e-01, -6.7539e-01],
...,
[-6.4484e-02, -1.0909e+00,  7.4930e-03,  ...,  3.0774e-02,
6.0454e-02, -1.0151e+00],
[ 4.0968e-02,  3.0282e-02,  1.4719e-01,  ..., -7.6498e-02,
-1.0360e+00, -5.6562e-01],
[-1.1170e+00, -6.8563e-01, -1.0466e+00,  ..., -4.8035e-02,
8.5094e-02, -6.5407e-02]],

...,

[[-5.0002e+04, -9.9101e-01, -6.7734e-01,  ..., -9.2610e-01,
-1.2709e+00, -7.3165e-01],
[ 9.9512e-02, -5.0002e+04,  3.1806e-02,  ..., -7.6206e-01,
-5.8745e-01,  9.7768e-03],
[-1.0824e+00, -6.5788e-01, -5.0002e+04,  ..., -5.9225e-01,
-6.3531e-01, -7.3157e-01],
...,
[-1.1562e-03,  7.0130e-03, -6.1178e-01,  ..., -1.1010e-03,
4.2915e-02,  5.8910e-02],
[-7.2236e-02, -7.0673e-01, -1.6531e-02,  ..., -1.2631e-02,
4.9684e-02, -1.0097e+00],
[-1.4670e+00, -7.9756e-01, -1.0041e+00,  ...,  4.1251e-02,
-5.4923e-01,  1.1196e-01]],

[[-5.0002e+04, -2.2864e-02, -7.0810e-01,  ..., -8.8877e-02,
-3.5036e-02,  8.6752e-02],
[-2.3829e-02, -5.0002e+04, -7.4363e-02,  ..., -7.0153e-01,
-5.9906e-01,  8.5919e-02],
[-1.5706e-02, -7.4920e-02, -5.0002e+04,  ..., -7.0312e-01,
-3.6608e-03, -8.4188e-02],
...,
[-6.4937e-01, -7.5320e-01, -6.7659e-01,  ..., -1.9999e-02,
-3.0025e-02, -9.6157e-02],
[-7.8045e-02, -2.7767e-02, -3.8730e-02,  ..., -6.3076e-02,
2.4677e-02, -1.0538e+00],
[-3.8923e-02, -1.2101e+00, -6.1684e-01,  ..., -7.3212e-02,
-4.1560e-02, -7.2177e-03]],

[[-5.0002e+04,  1.5744e-02, -3.6656e-02,  ...,  7.9805e-03,
6.9037e-02, -6.4532e-01],
[ 1.6047e-02, -5.0002e+04, -3.9217e-02,  ..., -6.0678e-02,
-1.0407e-01, -6.9506e-01],
[-1.1336e+00, -3.6765e-02, -5.0002e+04,  ...,  3.8987e-02,
-1.0660e+00, -9.1082e-02],
...,
[-6.3117e-01,  1.7853e-03, -6.1776e-01,  ..., -6.1914e-01,
-6.6772e-01, -8.4740e-02],
[-1.0818e+00, -7.2100e-01,  6.8703e-03,  ..., -6.4826e-01,
5.2051e-02,  8.0880e-02],
[-1.0235e+00,  8.3285e-03, -7.2880e-01,  ..., -6.3390e-01,
-6.5563e-01, -9.3100e-01]]],

[[[-5.0002e+04, -1.0287e+00,  6.8346e-03,  ...,  2.0748e+01,
-7.4605e-01,  2.0731e+01],
[-1.0326e+00, -5.0002e+04,  9.4427e-02,  ...,  2.0679e+01,
2.0148e-03, -7.3524e-02],
[ 6.7672e-03,  9.8932e-02, -5.0002e+04,  ...,  2.0788e+01,
2.0776e+01,  2.0755e+01],
...,
[-6.8427e-01, -7.6995e-01,  6.8449e-03,  ..., -6.7159e-01,
2.0742e+01, -1.8058e-02],
[ 1.0916e-01, -7.6654e-02, -7.8409e-01,  ..., -7.7711e-01,
2.0800e+01,  6.0956e-03],
[-4.2894e-02, -7.6302e-01, -6.0873e-01,  ...,  1.2333e-01,
4.8868e-02, -6.7437e-01]],

[[-5.0002e+04, -7.4121e-01,  6.8539e-03,  ..., -7.7586e-01,
-1.2751e-01, -6.7393e-01],
[-7.4487e-01, -5.0002e+04, -5.8602e-01,  ..., -1.0587e+00,
-6.9275e-01,  9.9422e-02],
[ 7.3751e-03, -9.9150e-01, -5.0002e+04,  ..., -1.0861e+00,
-9.0756e-02, -7.3040e-01],
...,
[-1.1303e+00,  1.2959e-01, -1.3156e-01,  ..., -8.1061e-03,
-5.8722e-01,  1.6247e-02],
[-2.1086e-02,  4.1514e-02,  4.0395e-04,  ..., -1.0630e+00,
-1.5798e+00, -6.9583e-01],
[-6.8207e-01, -6.7544e-01, -1.0706e+00,  ..., -6.4096e-01,
-2.5246e-02, -4.0463e-02]],

[[-5.0002e+04, -6.2382e-01, -7.4519e-01,  ...,  1.8341e-02,
-7.4715e-01, -7.3090e-01],
[-6.2022e-01, -5.0002e+04, -1.0710e+00,  ..., -6.7639e-01,
-5.5210e-02, -8.7723e-03],
[-5.8139e-02, -1.3569e+00, -5.0002e+04,  ..., -1.3842e-01,
-5.7647e-02, -7.6138e-01],
...,
[-1.2874e+00, -3.7172e-02, -1.1971e-01,  ..., -1.1135e+00,
-5.5599e-01,  4.4560e-02],
[-6.2956e-01, -6.5243e-01, -6.1504e-01,  ...,  3.3182e-02,
-1.3783e-02, -6.9928e-01],
[-1.0058e-02, -6.9272e-01, -7.2070e-01,  ..., -6.2283e-01,
-6.7011e-01, -1.2668e+00]],

...,

[[-5.0002e+04,  1.7537e-02, -7.5630e-02,  ..., -6.0353e-01,
-4.8025e-03, -6.2234e-01],
[ 1.6636e-02, -5.0002e+04, -6.2238e-01,  ..., -6.3683e-01,
-6.7922e-01, -9.3318e-02],
[-7.6349e-02, -6.1783e-01, -5.0002e+04,  ..., -6.6223e-01,
5.2067e-02, -7.7739e-01],
...,
[-7.0994e-01,  3.1005e-02, -5.0515e-02,  ...,  2.2233e-02,
1.0445e-02, -6.2186e-01],
[-4.8771e-02, -8.4967e-02, -7.1187e-02,  ..., -1.0790e+00,
2.7524e-02, -1.0127e+00],
[-7.0821e-01, -1.0360e+00, -1.1156e+00,  ..., -1.0150e+00,
3.9711e-02,  4.6916e-02]],

[[-5.0002e+04, -6.6504e-01, -7.0476e-01,  ..., -1.8005e-02,
-6.4876e-01, -2.3274e-02],
[ 2.9005e-02, -5.0002e+04, -4.0386e-02,  ...,  2.6925e-02,
1.0708e-02, -1.1037e+00],
[-7.0423e-01, -3.7335e-02, -5.0002e+04,  ..., -8.9376e-02,
-5.9300e-01, -7.5755e-01],
...,
[-1.3852e-01,  4.8598e-02, -1.0691e+00,  ..., -7.3187e-01,
-1.1090e+00,  8.3399e-02],
[ 4.7208e-02, -5.5907e-01, -1.0945e+00,  ...,  9.3991e-03,
-1.1883e-01, -1.1545e+00],
[-6.1052e-01, -6.6075e-01, -6.6978e-02,  ..., -6.1715e-02,
-4.9534e-02, -3.5774e-02]],

[[-5.0002e+04, -1.1524e+00, -3.8361e-03,  ...,  8.2180e-02,
7.1498e-02, -1.0261e+00],
[-1.1526e+00, -5.0002e+04, -6.5538e-01,  ...,  9.0787e-02,
-2.6950e-02,  8.9988e-02],
[-3.2501e-03, -6.6129e-01, -5.0002e+04,  ..., -7.3576e-02,
-6.7899e-01, -6.6380e-01],
...,
[ 3.9641e-02, -1.8851e-02, -8.0709e-02,  ..., -1.0730e+00,
1.0194e-01, -6.4881e-02],
[-7.9839e-01, -5.5824e-01, -8.1107e-01,  ..., -6.1297e-01,
1.9333e-02,  1.6730e-02],
[-7.0130e-02,  6.4639e-03, -6.0883e-01,  ..., -6.6599e-01,
-8.3970e-02, -1.4761e-02]]],

[[[-5.0002e+04, -7.0643e-01,  1.8945e-02,  ...,  2.0638e+01,
2.0691e+01,  2.0777e+01],
[-7.0638e-01, -5.0002e+04,  3.6565e-02,  ..., -7.9589e-01,
2.0858e+01, -2.3841e-04],
[ 1.7914e-02,  3.4687e-02, -5.0002e+04,  ...,  2.0742e+01,
2.0621e+01,  2.0714e+01],
...,
[ 4.6797e-02, -6.9868e-01,  8.5909e-02,  ...,  2.0669e+01,
7.8236e-02,  2.0897e+01],
[ 3.1428e-02, -7.2612e-01, -8.3142e-03,  ..., -6.7803e-01,
2.0717e+01,  4.5941e-02],
[-1.0209e+00, -5.9870e-01,  2.0145e-02,  ...,  2.0809e+01,
-9.1042e-03,  2.0818e+01]],

[[-5.0002e+04, -1.1261e+00, -6.7612e-01,  ..., -7.1700e-01,
2.3992e-02, -6.2569e-01],
[-7.2246e-01, -5.0002e+04,  6.1423e-02,  ..., -6.6325e-01,
7.7635e-02, -7.4096e-01],
[-6.7703e-01, -6.3860e-01, -5.0002e+04,  ...,  3.2460e-02,
-7.8134e-01, -1.4316e-02],
...,
[ 5.3191e-02,  2.2016e-02,  1.4697e-02,  ..., -1.0586e+00,
-8.0766e-01, -8.0925e-01],
[-5.9089e-01, -1.0910e+00, -1.0396e-02,  ...,  4.3137e-02,
-3.1703e-02,  1.9830e-01],
[-2.9503e-02,  3.5661e-02, -1.0629e-01,  ..., -1.1272e+00,
1.3117e-01, -1.5163e-02]],

[[-5.0002e+04, -7.1261e-04, -1.0960e+00,  ...,  4.6192e-02,
3.9001e-02, -6.4270e-02],
[-7.2621e-04, -5.0002e+04, -5.9984e-01,  ..., -1.0438e+00,
-5.8060e-01,  1.7381e-02],
[-1.0961e+00, -6.0384e-01, -5.0002e+04,  ..., -7.6103e-01,
3.6515e-02, -7.4593e-01],
...,
[ 3.7372e-02, -7.8854e-01, -6.5888e-01,  ..., -9.8734e-01,
-1.5392e-02, -4.7270e-02],
[ 1.0730e-01,  6.2754e-02, -1.1074e+00,  ..., -7.4467e-01,
2.4133e-02, -1.3177e+00],
[-1.6706e-02, -6.1856e-02,  6.7825e-02,  ..., -6.8134e-01,
-6.5658e-01, -6.5555e-01]],

...,

[[-5.0002e+04, -7.6860e-01, -3.9312e-02,  ..., -7.2845e-01,
6.8510e-02,  6.1583e-02],
[-1.4567e+00, -5.0002e+04,  2.9249e-02,  ..., -1.0235e+00,
-6.3884e-02, -7.4344e-01],
[-3.9045e-02,  3.1126e-02, -5.0002e+04,  ...,  4.3207e-02,
6.4832e-02, -7.6816e-01],
...,
[-7.0602e-01, -1.0803e+00, -6.8090e-03,  ..., -6.6567e-01,
3.3728e-02, -5.9617e-01],
[-1.1273e+00, -6.8315e-01, -6.5739e-01,  ..., -6.4077e-02,
-8.0328e-01, -6.8134e-01],
[-6.8594e-01, -4.0942e-02, -1.1023e+00,  ...,  1.3437e-01,
1.2066e-01, -6.2484e-02]],

[[-5.0002e+04, -6.6004e-02,  7.6812e-02,  ..., -1.7134e-02,
-2.7026e-02, -7.5400e-01],
[-6.6043e-02, -5.0002e+04, -4.1537e-02,  ..., -4.9382e-02,
-6.3732e-01, -3.7253e-02],
[ 7.1348e-02, -7.3171e-01, -5.0002e+04,  ...,  5.1813e-02,
-9.0516e-02, -2.4009e-02],
...,
[-1.0931e-01, -6.2518e-01, -7.7963e-01,  ..., -3.7894e-02,
-8.5473e-02, -7.6068e-01],
[-8.2336e-02,  2.4009e-02, -1.1192e+00,  ..., -6.7947e-01,
-2.0608e-02,  1.7448e-02],
[-9.8855e-01,  1.2772e-02, -7.5337e-01,  ..., -2.1447e-04,
1.0220e-02, -6.5317e-01]],

[[-5.0002e+04, -7.0534e-01, -1.0543e+00,  ...,  3.8087e-02,
3.0223e-02, -1.1817e-02],
[-1.1970e-02, -5.0002e+04,  6.5715e-02,  ..., -7.4209e-03,
2.7954e-02, -1.1809e+00],
[-6.5327e-01,  6.0199e-02, -5.0002e+04,  ...,  2.2354e-02,
-6.1508e-01,  1.7708e-02],
...,
[ 1.3893e-02, -1.6235e+00,  2.4286e-02,  ..., -6.4143e-01,
-1.5737e-01, -1.1110e+00],
[-1.1153e+00, -3.6544e-02, -1.9669e-02,  ..., -6.0147e-03,
-6.2389e-01, -6.8890e-01],
[-1.0333e+00, -7.9172e-01, -6.5686e-01,  ..., -6.1925e-01,
-6.6389e-02,  3.2094e-02]]],

...,

[[[-5.0002e+04,  5.2400e-02,  4.0179e-02,  ..., -7.2231e-01,
2.0599e+01,  2.0778e+01],
[ 5.0721e-02, -5.0002e+04, -7.6406e-01,  ...,  1.6523e-01,
2.6825e-02,  2.0738e+01],
[ 4.2541e-02, -1.4639e+00, -5.0002e+04,  ..., -6.9885e-01,
-1.8322e-02,  2.0849e+01],
...,
[ 1.0936e-02, -3.8569e-03,  3.8444e-02,  ...,  7.6564e-02,
-3.5330e-02, -3.7757e-02],
[ 1.5892e-02, -2.9148e-02, -8.3444e-02,  ...,  3.1712e-02,
2.0789e+01,  2.0783e+01],
[ 2.7258e-02, -1.1054e+00, -1.3504e+00,  ..., -7.6170e-01,
2.0742e+01,  2.0647e+01]],

[[-5.0002e+04,  1.3080e-03,  8.1217e-02,  ..., -1.0570e+00,
-1.1043e+00,  3.2203e-02],
[ 1.2573e-03, -5.0002e+04,  1.3325e-02,  ...,  1.9905e-02,
-6.6156e-02,  1.3769e-02],
[-6.2069e-01,  1.2368e-02, -5.0002e+04,  ..., -4.0843e-02,
2.0159e-02, -5.7616e-01],
...,
[-8.2025e-01, -4.8631e-02, -1.2133e+00,  ...,  1.0454e-02,
-6.0751e-01, -6.9148e-01],
[ 1.7746e-01,  3.9400e-02,  5.5643e-02,  ..., -1.0628e+00,
-4.7920e-02, -7.6982e-01],
[-2.8525e-02,  8.3297e-04, -7.1022e-01,  ..., -5.9549e-01,
-1.2106e-01, -2.2919e-02]],

[[-5.0002e+04, -1.1456e-02, -7.1848e-01,  ..., -3.9141e-02,
5.7795e-02, -2.4669e-02],
[-1.1195e-02, -5.0002e+04, -8.5146e-02,  ...,  6.2385e-02,
-7.1808e-01,  9.2060e-02],
[-2.7466e-02, -9.4474e-02, -5.0002e+04,  ..., -1.1161e+00,
-6.8552e-02, -6.5372e-01],
...,
[-6.9378e-01, -5.8283e-01, -7.2303e-01,  ..., -1.1824e-01,
-6.9322e-01, -2.0603e-02],
[-7.6447e-01,  5.2700e-02, -6.6736e-01,  ..., -4.6526e-02,
-6.9072e-01, -6.5179e-01],
[ 1.6763e-02, -7.0631e-01, -3.0940e-02,  ..., -3.0100e-02,
-6.7885e-01, -1.6511e-01]],

...,

[[-5.0002e+04, -5.5700e-03, -3.5600e-03,  ..., -1.1313e-01,
-6.7945e-01,  5.3541e-02],
[-5.8371e-03, -5.0002e+04, -6.8315e-01,  ..., -7.4106e-02,
-1.7017e-01, -1.1960e-04],
[-3.8548e-03, -6.8282e-01, -5.0002e+04,  ..., -4.0129e-02,
-6.2009e-01,  1.3382e-02],
...,
[ 3.6557e-02, -1.1954e+00, -5.5640e-01,  ...,  4.2153e-02,
-6.4036e-01, -2.4292e-03],
[-1.2519e-02, -1.1392e+00,  6.1505e-02,  ..., -6.0619e-01,
7.6575e-02, -7.6867e-01],
[-7.4817e-01, -7.0911e-01,  5.0159e-02,  ..., -7.8706e-02,
1.3180e-02,  4.9392e-02]],

[[-5.0002e+04,  1.0073e-02, -1.0350e+00,  ...,  1.0790e-02,
7.2497e-02, -6.2705e-01],
[-6.8405e-01, -5.0002e+04,  3.0868e-02,  ..., -8.5882e-04,
-6.6511e-01, -6.5233e-01],
[-1.0362e+00,  3.3509e-02, -5.0002e+04,  ..., -5.5677e-01,
2.5940e-02, -8.0767e-02],
...,
[-7.3452e-02, -6.2962e-02,  5.7764e-02,  ...,  3.3967e-02,
-7.1557e-01,  9.3924e-03],
[ 2.1821e-02,  1.0805e-02,  2.8415e-02,  ..., -6.7669e-01,
2.5231e-02, -1.1014e+00],
[-9.8223e-01, -4.6196e-02, -6.0867e-01,  ..., -9.9995e-01,
-2.9909e-02, -6.2392e-01]],

[[-5.0002e+04, -6.4706e-01,  9.6263e-03,  ...,  2.2836e-02,
-1.0565e+00, -1.1900e+00],
[ 5.6961e-02, -5.0002e+04,  1.5624e-01,  ..., -1.0936e-02,
-7.5872e-01, -5.6278e-01],
[ 1.1487e-02,  1.5085e-01, -5.0002e+04,  ...,  2.9795e-02,
-7.2889e-01,  2.3513e-02],
...,
[-5.9941e-03, -5.8503e-01, -7.1425e-01,  ..., -6.7522e-01,
-2.6663e-02, -6.6142e-01],
[ 1.3411e-02, -1.1007e+00,  2.1675e-02,  ...,  6.0577e-02,
-3.4467e-02, -5.5980e-01],
[-9.8314e-02, -1.0674e+00,  4.0643e-02,  ...,  7.4771e-02,
-7.7645e-02, -1.0807e+00]]],

[[[-5.0002e+04,  8.8070e-02, -7.7905e-01,  ...,  2.0628e+01,
2.0665e+01, -7.6268e-01],
[ 9.4692e-02, -5.0002e+04, -7.5558e-01,  ..., -7.6037e-02,
2.0707e+01,  2.0690e+01],
[-9.5035e-02, -1.1628e+00, -5.0002e+04,  ..., -8.3059e-03,
2.0753e+01,  2.0778e+01],
...,
[-1.0943e+00, -7.2827e-01, -8.2381e-02,  ...,  2.0671e+01,
2.0755e+01, -7.5985e-01],
[-6.7897e-01, -1.2315e-02,  1.4186e-01,  ...,  2.0691e+01,
2.0799e+01,  3.2438e-02],
[ 9.1267e-02, -5.0751e-03, -6.0676e-02,  ...,  2.0686e+01,
-4.8605e-02,  2.0722e+01]],

[[-5.0002e+04, -9.1360e-02,  3.6774e-02,  ..., -4.5187e-02,
-9.9560e-01,  3.7443e-02],
[-8.9719e-02, -5.0002e+04, -1.0807e+00,  ..., -6.1840e-01,
1.1114e-01, -2.1489e-02],
[ 3.6256e-02, -1.0807e+00, -5.0002e+04,  ..., -1.1177e+00,
-1.0720e+00, -6.6439e-01],
...,
[-6.5939e-01,  1.1382e-01, -5.5178e-01,  ..., -6.0889e-01,
-9.7330e-01, -6.8352e-01],
[-6.8595e-01,  1.1712e-01, -7.2646e-03,  ...,  8.4488e-02,
-7.0226e-01, -7.3382e-01],
[-1.0390e-02, -3.0717e-02, -7.7255e-02,  ..., -1.0751e+00,
-1.0663e-01, -3.2555e-03]],

[[-5.0002e+04, -7.7553e-01, -6.2018e-01,  ..., -7.3105e-01,
2.0386e-02,  3.8393e-02],
[-8.6760e-02, -5.0002e+04,  3.7554e-02,  ...,  1.1413e-01,
-7.6764e-02, -1.0876e+00],
[-6.2147e-01,  3.5030e-02, -5.0002e+04,  ..., -5.7204e-01,
-7.1048e-01, -6.2546e-01],
...,
[-6.8105e-01, -1.1926e+00,  4.4543e-03,  ...,  3.3833e-03,
7.4808e-02, -7.4500e-01],
[-1.0620e+00, -7.8984e-01, -7.3705e-01,  ..., -1.1305e+00,
-6.3342e-03, -3.2360e-03],
[ 5.8125e-02, -1.3288e-01, -5.9944e-01,  ..., -6.8650e-01,
-7.2090e-01,  2.9946e-02]],

...,

[[-5.0002e+04,  4.3551e-02, -1.0962e+00,  ..., -6.8317e-01,
-7.8139e-01, -9.3639e-02],
[ 4.6974e-02, -5.0002e+04,  4.3626e-03,  ..., -7.1435e-01,
-4.5532e-02, -6.9462e-01],
[-6.9041e-01,  4.4991e-03, -5.0002e+04,  ..., -1.3844e-02,
-6.9408e-01, -8.3989e-01],
...,
[ 3.9229e-02,  5.6270e-02, -6.7296e-01,  ..., -2.6244e-02,
1.7100e-02, -6.7346e-01],
[-7.3858e-01, -6.0334e-01,  4.6726e-04,  ...,  4.2693e-02,
7.6321e-02, -1.1988e+00],
[-5.0949e-03,  1.0270e-01,  1.0649e-01,  ..., -5.3066e-03,
5.2132e-02,  1.8803e-03]],

[[-5.0002e+04, -1.7331e-03, -8.4754e-01,  ..., -6.1742e-01,
-7.2971e-01, -7.6575e-02],
[-1.9064e-03, -5.0002e+04, -6.2878e-01,  ...,  6.9941e-02,
-6.7593e-01,  3.5147e-02],
[-1.5143e-01, -6.3576e-01, -5.0002e+04,  ..., -2.0979e-02,
-5.4055e-03,  8.4016e-02],
...,
[-7.5724e-01, -4.9299e-02, -6.8022e-01,  ..., -1.1208e-01,
-6.7211e-01,  4.6864e-02],
[-7.8377e-01, -7.7395e-01, -1.0539e+00,  ..., -2.9844e-02,
-8.4063e-01,  1.3819e-01],
[-9.6225e-02, -5.9687e-01,  7.6544e-02,  ..., -1.2057e-02,
1.1503e-01, -7.0326e-01]],

[[-5.0002e+04, -7.2567e-01, -6.8366e-02,  ..., -6.7452e-01,
-7.7318e-01, -6.8616e-01],
[-7.2622e-01, -5.0002e+04, -7.2297e-01,  ..., -1.0318e+00,
2.7554e-02, -1.3620e+00],
[-7.6179e-01, -1.1281e+00, -5.0002e+04,  ..., -1.0515e+00,
-6.8863e-01, -1.1271e+00],
...,
[-1.1460e-01, -5.9372e-01, -1.1546e+00,  ...,  2.2698e-02,
-1.0151e+00, -6.9752e-01],
[-7.4110e-01, -9.6526e-03,  2.2601e-02,  ..., -6.6647e-01,
-6.2969e-01, -6.2080e-02],
[-6.8643e-01, -5.5423e-03, -7.5430e-01,  ..., -9.9779e-01,
-6.3925e-01,  2.6166e-02]]],

[[[-5.0002e+04, -7.0775e-01, -7.4425e-01,  ...,  2.0819e+01,
2.0683e+01,  2.0665e+01],
[-1.4056e-02, -5.0002e+04, -1.0334e-01,  ...,  2.0730e+01,
-5.0211e-03, -1.7995e-02],
[-7.3941e-01, -9.7168e-02, -5.0002e+04,  ..., -6.9308e-01,
-2.1837e-02,  2.0897e+01],
...,
[-6.5692e-01,  1.7358e-01, -6.7614e-01,  ..., -1.1155e+00,
-2.7077e-02,  2.0594e+01],
[-3.2807e-02, -1.0641e+00, -6.9466e-01,  ...,  5.9205e-02,
2.0657e+01,  7.3325e-02],
[-7.0752e-01,  1.4747e-02, -1.1124e+00,  ...,  1.1635e-01,
2.0619e+01,  2.0606e+01]],

[[-5.0002e+04, -7.8155e-01,  4.2360e-02,  ..., -6.7046e-03,
-6.2317e-01, -6.8764e-02],
[-7.7636e-01, -5.0002e+04, -6.6514e-01,  ...,  1.6600e-02,
-6.0266e-01,  6.9248e-02],
[ 4.4054e-02, -6.6220e-01, -5.0002e+04,  ..., -4.0247e-02,
-1.0447e+00, -2.4753e-02],
...,
[-1.1650e+00, -6.5193e-01, -5.2799e-02,  ..., -7.1232e-01,
6.2070e-02, -5.7036e-01],
[ 6.6437e-02, -6.3776e-01, -6.2095e-01,  ..., -7.5672e-02,
-7.2539e-01,  3.8178e-03],
[-2.0841e-02, -7.0457e-01, -7.1138e-01,  ..., -8.5316e-01,
-6.8849e-01, -1.1017e+00]],

[[-5.0002e+04, -1.0052e-01, -1.4052e-02,  ..., -6.1361e-01,
-7.6047e-01, -6.4266e-01],
[-7.9946e-01, -5.0002e+04, -7.0874e-01,  ..., -1.3187e-02,
-7.2766e-01, -6.8447e-01],
[-1.4337e-02, -7.0819e-01, -5.0002e+04,  ...,  6.2653e-02,
8.2925e-02, -6.1827e-03],
...,
[ 5.5368e-02, -7.9218e-02, -4.9271e-02,  ..., -1.0849e+00,
-1.1037e+00, -6.8339e-01],
[-2.3466e-02, -1.2710e+00, -6.9268e-01,  ..., -1.2275e-01,
3.8920e-02, -7.5020e-01],
[ 4.9185e-03,  7.8125e-02, -1.0954e+00,  ...,  1.8172e-02,
-9.7184e-03, -7.6607e-01]],

...,

[[-5.0002e+04, -5.5086e-02,  7.3198e-03,  ..., -7.2712e-01,
-6.9818e-01, -5.7679e-01],
[-5.2186e-02, -5.0002e+04,  1.5717e-01,  ..., -6.6108e-01,
-5.1954e-02, -7.7383e-02],
[ 7.2771e-03,  1.6493e-01, -5.0002e+04,  ..., -1.7657e-02,
-5.9511e-01, -8.8096e-02],
...,
[ 5.3069e-03, -2.2161e-02, -1.0259e+00,  ..., -6.4013e-02,
-1.1779e+00, -7.8668e-01],
[ 1.0662e-01,  4.2206e-02, -5.7852e-01,  ..., -7.2492e-01,
-6.3088e-01, -5.4970e-01],
[-3.0151e-02,  1.5905e-01, -6.5035e-01,  ..., -5.0565e-02,
-1.2373e-01, -1.1663e+00]],

[[-5.0002e+04, -4.1103e-02, -2.3522e-02,  ..., -6.5312e-01,
-5.2362e-06,  8.0098e-04],
[-7.4138e-01, -5.0002e+04, -9.3284e-01,  ...,  7.8361e-02,
-5.8864e-01, -7.6385e-01],
[-1.1233e+00, -9.5003e-01, -5.0002e+04,  ..., -6.1614e-01,
-5.3328e-02, -7.3725e-01],
...,
[-1.4016e+00, -3.7399e-02, -7.3395e-01,  ..., -6.1674e-01,
-6.8356e-02, -7.4714e-01],
[-1.1417e+00, -6.6967e-01, -7.7276e-01,  ..., -6.3851e-01,
-6.8503e-01, -6.3605e-01],
[-4.3367e-02,  1.3107e-01, -6.3032e-01,  ...,  2.9499e-03,
-7.3189e-01, -6.3194e-03]],

[[-5.0002e+04,  1.6120e-01, -1.0919e+00,  ...,  3.7141e-02,
-6.4391e-01, -6.6716e-01],
[ 1.6676e-01, -5.0002e+04,  2.1009e-02,  ..., -1.4134e-02,
-2.4726e-02,  1.4608e-01],
[-6.8621e-01, -6.7221e-01, -5.0002e+04,  ...,  3.3562e-02,
-1.3796e-02, -8.1103e-02],
...,
[-5.9770e-01, -7.5384e-01, -6.6301e-01,  ..., -6.2909e-01,
2.6888e-02, -6.9892e-01],
[ 8.2479e-02,  1.5413e-02,  1.0300e-01,  ..., -7.2118e-01,
2.8926e-02,  7.7242e-03],
[-6.9020e-01,  7.0585e-02, -6.1188e-01,  ..., -5.4735e-01,
-1.3714e-01, -2.1106e-02]]]])

del mask
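The padding-mask steps above (reorder with st, add the look-back chunk, combine the q and k masks, fill) can be followed end to end on a toy example. This sketch makes several assumptions:

- a sequence of 4 tokens in 2 chunks of 2, with a made-up sort order st;
- max_neg_value returns the most negative finite float, which is consistent with the -3.4028e+38 entries printed in this section;
- look_one_back is re-declared as a hypothetical stand-in so the block runs on its own.

The masks are combined with `&`, which for boolean tensors gives the same result as the `*` used above:

```python
import torch

def look_one_back(x):
    # Hypothetical stand-in: append the previous chunk (wrapping around) along dim 2
    x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
    return torch.cat([x, x_extra], dim=2)

bs, sl, n_chunks = 1, 4, 2
input_mask = torch.tensor([[True, True, True, False]])  # last token is padding
st = torch.tensor([[2, 0, 3, 1]])                       # toy sort order of positions

mq = input_mask.gather(1, st).reshape((bs, n_chunks, -1))  # [1, 2, 2] mask for q
mkv = look_one_back(mq)                                     # [1, 2, 4] mask for k
mask = mq[:, :, :, None] & mkv[:, :, None, :]               # [1, 2, 2, 4]

dots = torch.zeros(bs, n_chunks, 2, 4)
masked_value = -torch.finfo(dots.dtype).max                 # assumed max_neg_value
dots.masked_fill_(~mask, masked_value)
print(mask[0, 1])  # rows: [False, False, False, False], [False, True, True, True]
```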


Add causal masking so the model can't peek at future tokens: a position is masked if the key comes after the query in the input sequence.

bq_t[0,55,:]

tensor([220, 226, 255, 261, 268, 274, 290, 313, 316, 317, 322, 323, 338, 356,
368, 407, 410, 437, 446, 447, 500,   1,   7,  24,  25,  27,  46,  47,
73,  96, 104, 109])
bkv_t[0,55,:]

tensor([220, 226, 255, 261, 268, 274, 290, 313, 316, 317, 322, 323, 338, 356,
368, 407, 410, 437, 446, 447, 500,   1,   7,  24,  25,  27,  46,  47,
73,  96, 104, 109, 283, 301, 302, 331, 337, 344, 361, 421, 432, 456,
464, 465, 470, 476, 497, 501,   6,   9,  10,  11,  30,  32,  66,  69,
72,  79,  90, 122, 157, 170, 190, 206])
#        if self.causal:
mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]

torch.Size([64, 96, 32, 64])

Since the tokens are sorted before masking, the mask is not triangular as in the base transformer. For the first query in this chunk (position 220), every key after position 220 should be masked (True):

[(t,m) for t,m in zip(bkv_t[0,55,:], mask[0,55,0,:])]

[(tensor(220), tensor(False)),
(tensor(226), tensor(True)),
(tensor(255), tensor(True)),
(tensor(261), tensor(True)),
(tensor(268), tensor(True)),
(tensor(274), tensor(True)),
(tensor(290), tensor(True)),
(tensor(313), tensor(True)),
(tensor(316), tensor(True)),
(tensor(317), tensor(True)),
(tensor(322), tensor(True)),
(tensor(323), tensor(True)),
(tensor(338), tensor(True)),
(tensor(356), tensor(True)),
(tensor(368), tensor(True)),
(tensor(407), tensor(True)),
(tensor(410), tensor(True)),
(tensor(437), tensor(True)),
(tensor(446), tensor(True)),
(tensor(447), tensor(True)),
(tensor(500), tensor(True)),
(tensor(1), tensor(False)),
(tensor(7), tensor(False)),
(tensor(24), tensor(False)),
(tensor(25), tensor(False)),
(tensor(27), tensor(False)),
(tensor(46), tensor(False)),
(tensor(47), tensor(False)),
(tensor(73), tensor(False)),
(tensor(96), tensor(False)),
(tensor(104), tensor(False)),
(tensor(109), tensor(False)),
(tensor(283), tensor(True)),
(tensor(301), tensor(True)),
(tensor(302), tensor(True)),
(tensor(331), tensor(True)),
(tensor(337), tensor(True)),
(tensor(344), tensor(True)),
(tensor(361), tensor(True)),
(tensor(421), tensor(True)),
(tensor(432), tensor(True)),
(tensor(456), tensor(True)),
(tensor(464), tensor(True)),
(tensor(465), tensor(True)),
(tensor(470), tensor(True)),
(tensor(476), tensor(True)),
(tensor(497), tensor(True)),
(tensor(501), tensor(True)),
(tensor(6), tensor(False)),
(tensor(9), tensor(False)),
(tensor(10), tensor(False)),
(tensor(11), tensor(False)),
(tensor(30), tensor(False)),
(tensor(32), tensor(False)),
(tensor(66), tensor(False)),
(tensor(69), tensor(False)),
(tensor(72), tensor(False)),
(tensor(79), tensor(False)),
(tensor(90), tensor(False)),
(tensor(122), tensor(False)),
(tensor(157), tensor(False)),
(tensor(170), tensor(False)),
(tensor(190), tensor(False)),
(tensor(206), tensor(False))]
#mask = mask & (bkv_t[:, :, None, :] < query_len)
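A toy chunk shows why the causal mask is not triangular. With made-up sorted positions, comparing the original query and key positions (rather than their order inside the chunk) masks exactly the keys that come later in the original sequence:

```python
import torch

# Toy original positions after sorting by bucket: 3 queries, 5 keys (including look-back)
bq_t = torch.tensor([7, 2, 9])
bkv_t = torch.tensor([7, 2, 9, 4, 11])

# True where the key comes after the query in the original sequence, i.e. masked
causal_mask = bq_t[:, None] < bkv_t[None, :]
print(causal_mask.int())  # rows: [0, 0, 1, 0, 1], [1, 0, 1, 1, 1], [0, 0, 0, 0, 1]
```

Query 2 masks key 7 even though key 7 sits before it in the chunk, so masked entries appear on both sides of the diagonal.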


The paper states: "While attention to the future is not allowed, typical implementations of the Transformer do allow a position to attend to itself. Such behavior is undesirable in a shared-QK formulation because the dot-product of a query vector with itself will almost always be greater than the dot product of a query vector with a vector at another position. We therefore modify the masking to forbid a token from attending to itself, except in situations where a token has no other valid attention targets (e.g. the first token in a sequence)."
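One way to see how the exception for tokens without other targets can come about is to compare the fill values. The self-attention entries get a large but finite negative value, while the other masks use the most negative float. This sketch assumes TOKEN_SELF_ATTN_VALUE = -5e4 (consistent with the -5.0000e+04 entries printed below) and max_neg_value = -torch.finfo(float32).max. A token that can see other keys gives its own entry almost no weight, while a token whose other keys are all masked ends up attending to itself:

```python
import torch

TOKEN_SELF_ATTN_VALUE = -5e4                    # assumed self-attention fill value
masked_value = -torch.finfo(torch.float32).max  # assumed max_neg_value

# Toy scores: row 0 = a token with other valid keys, row 1 = the first token in a sequence
dots = torch.tensor([[0.3, 0.1, -0.2],
                     [0.5, 0.2, 0.4]])
self_mask = torch.tensor([[True, False, False],
                          [True, False, False]])   # key 0 is the query itself
causal_mask = torch.tensor([[False, False, False],
                            [False, True, True]])  # row 1 may not see keys 1 and 2

dots.masked_fill_(causal_mask, masked_value)
dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE)
attn = torch.softmax(dots, dim=-1)
print(attn)  # row 0: ~0 weight on itself; row 1: all weight on itself
```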

Recall that bq_t and bkv_t are the lookup indices we use to reorder the input qk and v into sorted, chunked order. For example, the diagonal of the first half of k is always identical to q. But the same token can also reappear among the look-back keys (here 20 and 410 appear twice), so we have to compare all elements with each other (32*64):

bq_t[0,0,:], bkv_t[0,0,:]

(tensor([ 20,  29,  87, 112, 118, 168, 177, 182, 202, 271, 288, 319, 345, 357,
365, 366, 374, 377, 383, 410, 419, 420, 447, 469,   5,  23,  46,  55,
66,  93,  96, 102]),
tensor([ 20,  29,  87, 112, 118, 168, 177, 182, 202, 271, 288, 319, 345, 357,
365, 366, 374, 377, 383, 410, 419, 420, 447, 469,   5,  23,  46,  55,
66,  93,  96, 102, 368, 396, 410, 423, 425, 446, 456, 472, 491, 494,
20,  38, 100, 105, 135, 189, 210, 241, 242, 285, 286, 304, 306, 322,
346, 367, 404, 450, 480, 485, 498, 500]))

We achieve this by adding the appropriate unit axes (see section below):

bq_t[:, :, :, None].shape, bkv_t[:, :, None, :].shape

(torch.Size([64, 96, 32, 1]), torch.Size([64, 96, 1, 64]))
self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]

(torch.Size([64, 96, 32, 64]), tensor(34))
dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE)

tensor([[[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.4651e-01, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.1882e-01,  2.4822e-02, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[ 8.9638e-02, -7.2742e-01, -7.8459e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-8.0705e-02,  3.9972e-02, -9.7325e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-7.0885e-01, -1.1693e+00, -1.1712e+00,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ...,  6.0113e-02,
-6.6529e-01, -7.0489e-01],
[-2.0178e-02, -5.0000e+04, -3.4028e+38,  ..., -7.3383e-01,
3.2632e-02, -1.1272e+00],
[ 3.3639e-02, -1.0865e+00, -5.0000e+04,  ..., -1.1597e+00,
-1.4506e-01, -6.3097e-01],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -2.0544e-02,
-6.6849e-01, -3.1145e-03],
[-1.1505e+00, -5.0000e+04, -3.4028e+38,  ..., -5.2326e-02,
-6.3007e-01, -2.3459e-02],
[ 4.2065e-02, -5.8718e-01, -5.0000e+04,  ..., -1.1978e-01,
-9.9706e-01, -4.5898e-02],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -5.5354e-02,
2.8623e-02, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -7.3655e-01,
-2.4440e-02, -3.4028e+38],
[ 3.8226e-02,  5.8523e-02, -7.1844e-01,  ...,  1.1530e-01,
7.8461e-02, -8.4067e-02]],

...,

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -5.7121e-01,
-3.8492e-03,  1.0117e-01],
[-7.0567e-01, -5.0000e+04, -3.4028e+38,  ..., -9.9979e-01,
-6.8849e-01, -7.3988e-02],
[-6.3723e-01,  1.8232e-02, -5.0000e+04,  ...,  3.3746e-02,
-6.0018e-01, -1.1584e+00],
...,
[ 6.1799e-02,  6.9101e-03,  3.7042e-02,  ...,  2.6486e-02,
-6.1869e-01, -7.4292e-01],
[-3.8282e-02,  7.3337e-02, -1.0109e-01,  ...,  5.5641e-02,
-6.4948e-01,  5.4328e-02],
[-9.8530e-02, -3.5202e-02, -1.5972e+00,  ..., -1.2660e-01,
-7.4450e-01, -1.0500e+00]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -6.6170e-01,
-9.9171e-01, -6.4411e-02],
[ 2.8565e-02, -5.0000e+04, -3.4028e+38,  ...,  5.4141e-02,
-3.5999e-02,  3.9631e-02],
[-1.2061e+00, -3.8697e-02, -5.0000e+04,  ...,  4.6656e-02,
-6.1919e-01, -7.5294e-01],
...,
[-1.4103e+00, -7.2197e-01, -6.3220e-01,  ..., -6.5529e-01,
-7.4984e-01,  4.3277e-02],
[-2.8961e-02, -7.4635e-01,  1.7832e-02,  ..., -5.7393e-02,
-6.4106e-01,  1.1850e-02],
[-6.2059e-02, -3.3901e-02, -1.1619e-02,  ...,  1.4928e-02,
5.3628e-04, -7.0925e-01]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -1.9251e-02,
5.2760e-02, -6.8958e-02],
[-6.8157e-01, -5.0000e+04, -3.4028e+38,  ..., -5.8809e-01,
4.5345e-03, -5.4893e-01],
[-5.7760e-01, -6.0494e-01, -5.0000e+04,  ..., -6.8532e-01,
-7.2322e-01,  5.7745e-02],
...,
[ 7.2948e-03, -6.7351e-01,  6.9916e-02,  ..., -6.7309e-01,
-7.1030e-01,  1.0441e-02],
[-6.2357e-01, -7.3820e-01, -6.2938e-01,  ..., -6.8691e-02,
-1.3458e+00, -8.4144e-01],
[-8.7684e-01, -6.4115e-01, -6.4251e-01,  ...,  7.9770e-02,
-3.5258e-02,  1.1269e-01]]],

[[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.6076e-01, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.2430e-01, -1.2219e-01, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-8.6926e-01,  6.5723e-02, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.4740e-02, -1.1260e+00, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.6764e-01, -8.3105e-02, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ...,  7.7333e-02,
3.1600e-02,  1.8005e-02],
[-6.7779e-01, -5.0000e+04, -3.4028e+38,  ..., -1.0047e+00,
-7.1437e-01, -1.2170e-02],
[-1.1878e-02,  1.3605e-01, -5.0000e+04,  ..., -2.3003e-02,
1.3730e-02, -7.3811e-01],
...,
[-7.4399e-01, -7.2079e-01,  1.6003e-02,  ..., -1.2630e+00,
-4.9323e-02,  7.9365e-02],
[-1.3635e+00, -6.7924e-01, -2.3513e-02,  ..., -1.3940e+00,
-3.1175e-02, -2.2558e-02],
[-6.5811e-01, -1.0604e+00, -9.5996e-01,  ..., -7.6033e-01,
-5.5959e-01,  2.1492e-02]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -6.1800e-01,
-6.2581e-01,  5.1135e-02],
[-9.6606e-01, -5.0000e+04, -3.4028e+38,  ..., -6.0455e-01,
-7.3208e-01, -4.9317e-02],
[-3.2136e-02, -7.2623e-03, -5.0000e+04,  ...,  1.4396e-01,
-6.6859e-01, -6.1585e-01],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ...,  1.5643e-02,
-3.4028e+38, -3.4028e+38]],

...,

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ...,  1.5109e-01,
-6.7132e-01, -7.1930e-01],
[-6.4984e-01, -5.0000e+04, -3.4028e+38,  ...,  1.3963e-03,
-1.1509e+00, -4.5689e-02],
[-1.6850e-02, -1.3503e+00, -5.0000e+04,  ...,  8.6320e-02,
-6.6129e-01, -6.6410e-01],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -1.1687e+00,
-7.9722e-01,  5.0385e-02],
[ 4.2493e-02, -5.0000e+04, -3.4028e+38,  ..., -7.5008e-01,
1.0685e-01, -6.0812e-01],
[-6.5271e-01, -1.1316e+00, -5.0000e+04,  ..., -8.8241e-03,
1.7627e-01, -3.8736e-02],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ...,  1.0364e-01,
7.5225e-02,  7.8665e-02],
[-1.0867e-01, -5.0000e+04, -3.4028e+38,  ..., -1.5266e-03,
-6.3571e-01,  3.4974e-02],
[ 1.1809e-02, -1.0740e+00, -5.0000e+04,  ..., -6.9976e-03,
-6.9689e-01, -1.1085e+00],
...,
[-1.0437e-01, -1.2239e-02, -6.6529e-01,  ...,  2.5524e-02,
3.4716e-02,  8.2917e-02],
[-2.2990e-02, -6.3193e-01, -1.1988e+00,  ..., -7.3525e-01,
-6.0445e-02, -4.4580e-02],
[-6.6057e-01, -3.8727e-02, -5.7678e-02,  ..., -1.1243e+00,
6.1690e-02, -6.3970e-01]]],

[[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-7.0107e-01, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.7432e-01, -5.8993e-01, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-1.0163e+00, -5.7144e-02,  6.1773e-02,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.7664e-01, -6.7368e-01, -1.1599e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[ 2.0432e-02, -1.1016e+00,  5.7354e-02,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -1.0858e+00,
3.1600e-02, -1.0709e+00],
[-5.6724e-01, -5.0000e+04, -3.4028e+38,  ...,  1.1074e-03,
-1.1556e+00,  1.6300e-02],
[-1.0745e+00, -6.8645e-01, -5.0000e+04,  ..., -6.0015e-01,
-7.3519e-01,  7.4664e-02],
...,
[-1.0592e+00, -1.0839e-01, -8.5495e-02,  ..., -1.1811e-01,
-6.6579e-01, -1.0507e+00],
[-7.0449e-01,  2.2444e-02, -6.2353e-01,  ...,  8.6373e-02,
3.3366e-02, -7.6842e-01],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -7.4734e-01],
[-7.3401e-01, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -7.3220e-01],
[-1.1460e+00, -9.8449e-02, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -7.0000e-01],
...,
[ 5.0304e-02,  2.1384e-02, -5.5987e-02,  ..., -5.6092e-02,
-3.4028e+38,  6.2510e-02],
[ 7.9292e-03, -5.9112e-01, -1.1575e+00,  ...,  6.3428e-02,
-1.9939e-02, -6.2542e-01],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -1.7875e-01]],

...,

[[-5.0000e+04, -3.4028e+38, -7.1105e-01,  ...,  2.1309e-02,
-6.9780e-01,  1.1153e-01],
[-1.2436e-01, -5.0000e+04, -4.3738e-02,  ...,  6.7087e-02,
-7.3648e-01, -1.0334e+00],
[-3.4028e+38, -3.4028e+38, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-8.3788e-01, -7.0348e-01, -1.0948e+00,  ..., -6.6368e-01,
-1.3585e+00,  6.5881e-02],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-9.0959e-03, -3.0965e-02],
[-1.0970e+00, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-1.0928e-01, -6.8574e-01],
[-6.4475e-01, -6.4413e-02, -5.0000e+04,  ..., -3.4028e+38,
-1.1016e+00, -7.0929e-01],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-1.1587e+00, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-7.8249e-01, -7.3240e-01],
[-6.3891e-01,  3.5043e-02, -1.1211e+00,  ..., -3.4028e+38,
-4.2181e-03, -6.6783e-01]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ...,  6.6128e-02,
5.7813e-04, -6.3975e-01],
[-1.1157e+00, -5.0000e+04, -3.4028e+38,  ...,  5.5402e-02,
-6.8001e-01, -4.1885e-02],
[-3.1970e-02, -6.2763e-01, -5.0000e+04,  ...,  8.5007e-02,
-8.2336e-02, -6.7272e-01],
...,
[-7.1669e-01, -2.6079e-02, -1.0298e-01,  ..., -1.1081e+00,
-1.0416e+00, -1.1158e+00],
[-7.3221e-02, -1.8096e-02, -5.7792e-01,  ..., -8.2817e-01,
-5.4756e-02,  6.0642e-02],
[-6.2527e-01, -7.3235e-01,  3.6161e-02,  ..., -6.5685e-01,
-1.0227e+00,  4.3928e-02]]],

...,

[[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[ 2.6721e-03, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-2.7559e-02, -1.0628e+00, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-5.6986e-01, -1.1886e-02, -7.5952e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.0231e+00, -1.7792e-02, -5.0195e-02,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.1377e+00, -7.1819e-01, -6.3855e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -1.0323e+00,
-9.5445e-01, -1.1035e+00],
[-8.1883e-01, -5.0000e+04, -3.4028e+38,  ..., -1.1960e-01,
1.3267e-01,  3.6285e-02],
[-3.4274e-02,  5.2932e-02, -5.0000e+04,  ...,  3.8055e-02,
-1.3925e-02,  2.2453e-02],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -6.2002e-01,
8.9151e-03,  1.0790e-01],
[-2.7135e-02, -5.0000e+04, -3.4028e+38,  ..., -6.6650e-01,
-9.2817e-01, -2.2317e-02],
[ 8.0496e-02, -6.8004e-01, -5.0000e+04,  ...,  4.2166e-02,
-3.6741e-02, -5.3699e-01],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

...,

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -5.5951e-01,
3.6174e-02, -6.4231e-01],
[-7.4560e-01, -5.0000e+04, -3.4028e+38,  ..., -6.3169e-01,
-1.4407e+00, -7.1516e-01],
[-5.7026e-01, -7.8466e-01, -5.0000e+04,  ..., -9.9972e-03,
-6.0130e-02,  1.1593e-01],
...,
[-7.3289e-01,  2.0110e-02, -7.2580e-01,  ..., -1.0738e-02,
-1.1695e-01, -1.0707e+00],
[-1.0875e+00,  3.9373e-02,  1.4219e-02,  ..., -6.3023e-01,
5.5321e-02, -5.9529e-01],
[-5.9247e-01,  1.2369e-01, -7.3387e-02,  ..., -7.6871e-02,
-7.2903e-01, -6.1954e-02]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -1.4231e-04,
2.5220e-02, -1.0773e+00],
[-6.3603e-01, -5.0000e+04, -3.4028e+38,  ..., -1.0874e+00,
-4.9859e-02, -6.4749e-01],
[-6.4161e-01, -7.4738e-01, -5.0000e+04,  ..., -5.7791e-01,
-6.2543e-03, -1.1642e+00],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -7.0094e-01,
4.2724e-03,  5.4235e-02],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -2.3470e-02,
-1.0217e-01, -6.4626e-01],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -7.0674e-01,
-6.3122e-01, -1.0745e-02]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -6.2641e-01,
4.7402e-02,  8.4206e-02],
[-7.0791e-01, -5.0000e+04, -3.4028e+38,  ..., -4.3338e-02,
-2.8378e-02,  1.2988e-01],
[ 1.0670e-04,  8.5781e-02, -5.0000e+04,  ..., -6.3421e-01,
-1.2349e+00, -7.5487e-01],
...,
[-2.5026e-03, -6.0923e-01,  9.6376e-02,  ...,  1.4107e-01,
1.2839e-02, -5.6102e-01],
[-6.8853e-01, -9.3420e-02, -5.8252e-02,  ...,  2.7106e-02,
-6.8703e-01,  8.3605e-02],
[-7.6307e-01, -6.7374e-01, -2.8049e-02,  ..., -1.0166e-01,
-6.2191e-01,  2.8614e-02]]],

[[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.0153e+00, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-8.1250e-02, -7.3463e-01, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-6.7312e-01, -1.0561e-02, -6.1704e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[ 3.9926e-02,  4.2183e-03, -7.0787e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.4681e-01, -7.6523e-01, -7.3376e-01,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38,  2.7195e-02,  ..., -5.2110e-02,
-6.0999e-02,  4.4449e-02],
[-1.3868e-01, -5.0000e+04, -4.3554e-02,  ..., -3.3023e-02,
1.8760e-02, -7.5677e-01],
[-3.4028e+38, -3.4028e+38, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[ 5.7223e-04, -3.4028e+38, -6.2095e-01,  ..., -7.2350e-01,
-7.8402e-01, -6.9741e-01],
[-8.4171e-01, -3.4028e+38, -5.2057e-02,  ..., -4.2653e-02,
3.0755e-02, -6.8030e-02],
[-6.5412e-01, -3.4028e+38, -9.7690e-01,  ..., -6.2359e-01,
-1.6402e-02, -5.3178e-02]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.9276e-02, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-6.8574e-01, -5.7047e-01, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-7.4750e-01,  1.0921e-01, -1.0828e+00,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[ 4.8119e-02, -5.7358e-01, -1.0990e+00,  ...,  5.4826e-02,
-3.4028e+38, -3.4028e+38],
[-6.2310e-01, -1.0785e+00, -1.0339e+00,  ..., -1.3479e-02,
2.0068e-01, -3.4028e+38]],

...,

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ...,  1.0200e-02,
-6.7591e-01, -1.1870e+00],
[ 1.5608e-02, -5.0000e+04, -3.4028e+38,  ..., -5.0493e-01,
-9.9869e-01, -1.4343e+00],
[-8.0914e-02,  2.9625e-02, -5.0000e+04,  ..., -7.0351e-01,
-6.4617e-01,  5.6308e-03],
...,
[-8.9133e-03, -1.3096e-01, -4.7320e-04,  ..., -7.2473e-01,
-1.6868e-02,  2.8787e-02],
[-7.9361e-01, -6.4984e-01,  7.5883e-02,  ..., -4.4323e-03,
-6.8407e-02, -1.0846e-01],
[ 5.3698e-02, -7.0342e-01, -8.5320e-02,  ..., -4.1425e-02,
-7.9851e-01, -6.7153e-01]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[ 1.1630e-01, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[ 7.7609e-02,  1.4753e-02, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-6.1882e-02, -1.4174e-02, -1.3796e+00,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-2.5393e-02, -7.5303e-01,  1.0076e-01,  ...,  6.1855e-02,
-7.9291e-01, -3.4028e+38],
[-6.2580e-01, -7.6510e-01,  1.8015e-02,  ..., -6.0075e-01,
-8.3052e-01, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -6.4054e-01,
-8.7310e-02,  5.3704e-02],
[-6.4269e-01, -5.0000e+04, -3.4028e+38,  ..., -8.2971e-02,
-1.4855e-01, -8.9776e-02],
[ 2.3660e-02, -7.1981e-01, -5.0000e+04,  ...,  9.7790e-02,
-4.8083e-02, -6.1394e-01],
...,
[-1.3635e+00, -6.5168e-01, -1.2835e-02,  ..., -6.8055e-01,
3.6372e-02, -6.5923e-01],
[ 5.5290e-02, -7.6812e-01, -1.1189e+00,  ..., -6.3953e-01,
-6.5224e-01, -7.0943e-01],
[ 1.3939e-02, -8.7076e-02, -6.2809e-01,  ..., -6.5883e-01,
-8.7338e-02, -7.4225e-01]]],

[[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.4639e+00, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-1.0586e+00, -1.1582e+00, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[ 5.8896e-02, -5.7658e-01, -7.5433e-01,  ...,  2.0647e+01,
-2.6651e-02,  2.0653e+01],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-7.2978e-01, -1.0459e+00,  6.2408e-02,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-7.2156e-01,  1.8013e-02],
[-1.5039e-02, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
-3.7105e-02, -5.8817e-01],
[-3.9265e-02, -6.1889e-02, -5.0000e+04,  ..., -3.4028e+38,
-3.4280e-02, -5.5754e-01],
...,
[-7.2015e-01, -8.7395e-02, -1.2293e-02,  ..., -3.4028e+38,
-6.6509e-01,  8.0814e-02],
[ 4.6665e-02, -1.0344e+00, -7.1668e-01,  ..., -3.4028e+38,
-1.0477e+00, -6.8540e-01],
[-4.9705e-01,  7.4967e-03, -1.0317e+00,  ..., -3.4028e+38,
-5.8730e-02,  4.6460e-02]],

[[-5.0000e+04, -3.4028e+38, -8.5118e-01,  ..., -3.2423e-03,
-4.3115e-02,  4.8089e-02],
[-1.0768e+00, -5.0000e+04, -1.3421e-02,  ..., -4.7596e-02,
-5.9070e-02,  8.5101e-02],
[-3.4028e+38, -3.4028e+38, -5.0000e+04,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
...,
[-3.4028e+38, -3.4028e+38, -6.0726e-02,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38,  2.6122e-02,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -4.6727e-02,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

...,

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-2.7874e-02, -1.1563e+00],
[-5.7874e-01, -5.0000e+04, -3.4028e+38,  ..., -3.4028e+38,
5.7971e-02, -3.7483e-02],
[ 7.5618e-03, -6.5969e-01, -5.0000e+04,  ..., -3.4028e+38,
-7.3267e-01,  1.2819e-02],
...,
[-6.5722e-01,  1.0743e-01,  3.1043e-02,  ..., -3.4028e+38,
-5.9182e-01,  1.4952e-01],
[-6.9334e-01, -1.0488e+00, -6.3029e-02,  ..., -3.4028e+38,
-6.6509e-01,  8.8972e-03],
[-8.1567e-01,  9.3357e-03, -5.8275e-02,  ..., -3.4028e+38,
7.6341e-02,  1.6951e-01]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -4.5265e-03,
-7.1776e-01, -7.0639e-01],
[-6.6058e-01, -5.0000e+04, -3.4028e+38,  ...,  4.5980e-02,
-4.9668e-02, -1.0652e-02],
[-7.0612e-01, -1.0613e+00, -5.0000e+04,  ..., -7.4379e-01,
5.2316e-02, -6.9348e-01],
...,
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38],
[-3.4028e+38, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38,
-3.4028e+38, -3.4028e+38]],

[[-5.0000e+04, -3.4028e+38, -3.4028e+38,  ..., -3.6865e-02,
-6.0356e-01, -3.1227e-02],
[ 6.7016e-02, -5.0000e+04, -3.4028e+38,  ..., -1.2225e+00,
-1.3129e-02, -1.1694e+00],
[-6.3382e-01, -6.9378e-01, -5.0000e+04,  ..., -1.0572e+00,
3.5980e-02, -6.9805e-01],
...,
[-7.4151e-01, -1.0349e+00, -2.2069e-02,  ..., -9.5842e-02,
3.5878e-03, -6.5626e-01],
[-6.8698e-01, -3.7117e-02, -1.1274e+00,  ..., -6.7698e-01,
-3.0499e-02, -2.4080e-02],
[-1.1675e-01, -7.4351e-03,  1.2485e-01,  ...,  9.3894e-02,
1.1018e-01, -7.3857e-01]]]])
##### Comparing all elements by adding unit axes

By adding appropriate unit axes, we can compare every element of the final dimension of one tensor with every element of another. Broadcasting expands each unit axis to match the other tensor:

a, b = torch.arange(2), torch.arange(5)
a.shape, b.shape

(torch.Size([2]), torch.Size([5]))
a, b, a[None,:]+b[:,None], a[:,None]+b[None,:]

(tensor([0, 1]),
tensor([0, 1, 2, 3, 4]),
tensor([[0, 1],
[1, 2],
[2, 3],
[3, 4],
[4, 5]]),
tensor([[0, 1, 2, 3, 4],
[1, 2, 3, 4, 5]]))
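The bucket mask further down relies on the same trick. Here is a minimal sketch with made-up bucket ids for 4 queries and 6 keys. The tensors `q_ids` and `k_ids` are illustrative and do not come from the notebook. Adding a unit axis on each side makes broadcasting compare every query id with every key id:

```python
import torch

# Made-up bucket ids for 4 queries and 6 keys (illustrative only)
q_ids = torch.tensor([0, 0, 1, 2])
k_ids = torch.tensor([0, 1, 1, 2, 2, 0])

# q_ids[:, None] has shape [4, 1] and k_ids[None, :] has shape [1, 6].
# Broadcasting compares every query id with every key id -> shape [4, 6]
same_bucket = q_ids[:, None] == k_ids[None, :]
print(same_bucket.shape)  # torch.Size([4, 6])

# Row 0 (query in bucket 0) is True only where the key is also in bucket 0
print(same_bucket[0].tolist())
```

Swapping `==` for `!=` gives a mask of the pairs that should be blocked. The `bucket_mask` below uses that form.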

#### Mask out attention to other hash buckets

Note: the paper suggests NOT attending across buckets: "Now we turn to LSH attention, which we can think of in terms of restricting the set P_i of target items a query position i can attend to, by only allowing attention within a single hash bucket." Lucidrains' implementation, however, sets attend_across_buckets to True by default.

We only run this part of the code if we want to prevent attention across buckets.

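First we recover the hash bucket ids by integer-dividing sbuckets_and_t by sl. This works because buckets_and_t = sl * buckets + position. Note that bucket ids are consecutive across hash rounds, so the ids of different rounds do not overlap. We also reshape into n_chunks chunks: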

#if not self._attend_across_buckets:
bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // sl, (bs, n_chunks, -1))
bq_buckets.shape

torch.Size([64, 96, 32])
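The following toy sketch shows why integer division by sl recovers the bucket id. It uses one hash round, sl = 4, and made-up bucket ids. It calls plain `torch.sort` in place of the notebook's `sort_key_val` helper. Because `ticker` here is just `arange(sl)`, the indices returned by the sort play the role of `sticker`:

```python
import torch

# Toy example (made-up numbers): one hash round, sequence length sl = 4
sl = 4
buckets = torch.tensor([[1, 0, 1, 0]])   # bucket id for each position
ticker = torch.arange(sl).unsqueeze(0)   # position index for each element

# Same encoding as in the notebook: bucket id in the "high" part, position in the "low" part
buckets_and_t = sl * buckets + (ticker % sl)           # [[4, 1, 6, 3]]

# Sorting groups positions by bucket, then by position within each bucket
sbuckets_and_t, sticker = buckets_and_t.sort(dim=-1)   # [[1, 3, 4, 6]], [[1, 3, 0, 2]]

# Integer division by sl strips the position part and leaves the sorted bucket ids
sbuckets = sbuckets_and_t // sl                        # [[0, 0, 1, 1]]
```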

E.g. attention chunk 1 contains a mix of buckets 1 and 2:

bq_buckets[0,1,:]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 2, 2, 2, 2, 2])

We append the previous chunk to the keys:

bkv_buckets = look_one_back(bkv_buckets)
bkv_buckets.shape

torch.Size([64, 96, 64])
bkv_buckets[0,1,:]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
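The output above shows what look_one_back does. Each chunk is followed by the previous chunk, and chunk 0 wraps around to the last chunk. The following is a minimal re-implementation that reproduces this behaviour. It is a sketch and may not be identical to the helper used in this notebook:

```python
import torch

def look_one_back(x):
    """Sketch: append the previous chunk (dim 1) to each chunk along dim 2.

    Chunk 0 wraps around and gets the last chunk appended.
    """
    # Roll the chunks by one, so slot n holds chunk n-1 (slot 0 holds the last chunk)
    x_prev = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
    # Concatenate along the chunk-size dimension: [bs, n_chunks, 2*chunk_size, ...]
    return torch.cat([x, x_prev], dim=2)

# Toy input: batch of 1, 3 chunks of size 2, chunk n filled with the value n
x = torch.arange(3).repeat_interleave(2).reshape(1, 3, 2)
y = look_one_back(x)
print(y.shape)        # torch.Size([1, 3, 4])
print(y[0].tolist())  # [[0, 0, 2, 2], [1, 1, 0, 0], [2, 2, 1, 1]]
```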
bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
bucket_mask.shape

torch.Size([64, 96, 32, 64])
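For example, the last query in chunk 1 is in bucket 2. In its row of bucket_mask, only the keys from bucket 2 stay unmasked (False):

bq_buckets[0,1,-1], bkv_buckets[0,1,:]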

(tensor(2),
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]))
bucket_mask[0,1,-1,:]

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
True,  True,  True,  True,  True,  True,  True, False, False, False,
False, False,  True,  True,  True,  True,  True,  True,  True,  True,
True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
True,  True,  True,  True])
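We then fill the masked positions of dots with masked_value, so that each query only attends to keys from its own bucket:

dots.masked_fill_(bucket_mask, masked_value)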
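Filling with a huge negative number removes those keys from the softmax. The following sketch shows the effect. It assumes that masked_value is the most negative finite float32, which matches the -3.4028e+38 entries in the printout above. The scores and mask are made up:

```python
import torch

# Made-up attention scores for one query against 4 keys
scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
# Hypothetical mask: True marks keys that sit in a different bucket
mask = torch.tensor([[False, True, False, True]])

# Assumed masked_value: the most negative finite float32 (-3.4028e+38)
masked_value = -torch.finfo(scores.dtype).max
masked = scores.masked_fill(mask, masked_value)

# After the softmax, the masked keys get weight 0 and the rest still sum to 1
weights = torch.softmax(masked, dim=-1)
print(weights)
```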


#### Softmax

We take the softmax using the log-sum-exp trick. We also keep dots_logsumexp, because it is needed later to weight the hash rounds:

dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
dots = torch.exp(dots - dots_logsumexp).type_as(dots)
dots.shape

torch.Size([64, 96, 32, 64])
dots[0,0,0,:].sum()

tensor(1.)

Finally, we compute the self-attention output by multiplying the attention weights with the values:

bo = torch.einsum('bnsz,bnzd->bnsd',
dots,                  # [bs, n_chunks, chunk_size, chunk_size*2]
bv)                    # [bs, n_chunks, chunk_size*2, model_dim]
bo.shape                                 # [bs, n_chunks, chunk_size, model_dim]

torch.Size([64, 96, 32, 256])
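The einsum string 'bnsz,bnzd->bnsd' is a batched matrix multiplication over the leading bs and n_chunks dimensions. The following sketch uses small random tensors with arbitrary shapes to show that it agrees with `@`:

```python
import torch

# Arbitrary small shapes
bs, n_chunks, chunk_size, dim = 2, 3, 4, 5
dots = torch.softmax(torch.randn(bs, n_chunks, chunk_size, 2 * chunk_size), dim=-1)
bv = torch.randn(bs, n_chunks, 2 * chunk_size, dim)

bo_einsum = torch.einsum('bnsz,bnzd->bnsd', dots, bv)
bo_matmul = dots @ bv   # batched matmul over the leading (bs, n_chunks) dims

print(bo_einsum.shape)  # torch.Size([2, 3, 4, 5])
```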

#### Unsorting

The final step is to convert the batched, chunked and sorted output back to our original representation. First we reshape the self-attention output to remove the n_chunks dimension. It's still sorted though:

so = torch.reshape(bo, (bs, -1, dim))                 # [bs, seqlen*n_rounds, model_dim]
so.shape

torch.Size([64, 3072, 256])

Then we unsort `so` by indexing it with the unsort indices undo_sort:

o = batched_index_select(so, undo_sort)
o.shape

torch.Size([64, 3072, 256])
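undo_sort works because sorting the sort indices (sticker) yields the inverse permutation. The sketch below checks this. It uses a simple gather-based batched_index_select, which is our own hypothetical version. We assume the notebook's helper also gathers along dimension 1:

```python
import torch

def batched_index_select(values, indices):
    """Hypothetical sketch: gather rows of values [bs, n, d] at indices [bs, m] -> [bs, m, d]."""
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))

x = torch.randn(2, 6, 3)               # [bs, seq, dim]
keys = torch.randint(0, 10, (2, 6))    # arbitrary sort keys

_, sticker = keys.sort(dim=-1)         # permutation that sorts each row
_, undo_sort = sticker.sort(dim=-1)    # inverse of that permutation

sx = batched_index_select(x, sticker)              # rows in sorted order
restored = batched_index_select(sx, undo_sort)     # rows back in the original order
```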

And reshape it to include an n_rounds dimension:

o = torch.reshape(o, (bs, n_rounds, sl, dim))    # [bs, n_rounds, sl, dim]
o.shape

torch.Size([64, 6, 512, 256])

Then we apply the same steps to the logits (dots_logsumexp):

slogits = torch.reshape(dots_logsumexp, (bs, -1,))    # [bs, seqlen*n_rounds]
slogits.shape

torch.Size([64, 3072])
logits = slogits.gather(1, undo_sort)
logits.shape                                    # [bs, seqlen*n_rounds]

torch.Size([64, 3072])
logits = torch.reshape(logits, (bs, n_rounds, sl, 1))
logits.shape                                       # [bs, n_rounds, sl, 1]

torch.Size([64, 6, 512, 1])

We take the softmax of the logits over the n_rounds dimension. This gives each hashing round a weight, so the final output is a weighted average of the per-round self-attention outputs.

probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
probs.shape           # [bs, n_rounds, sl, 1]

torch.Size([64, 6, 512, 1])

So summing over the n_rounds dimension gives (approximately) 1:

probs[0,:,0,0].sum()

tensor(0.9988)

And finally, the self-attention output, weighted by the contribution of each round:

out = torch.sum(o * probs, dim=1)      # [bs, sl, dim]
out.shape, qk.shape, v.shape

(torch.Size([64, 512, 256]),
torch.Size([64, 512, 256]),
torch.Size([64, 512, 256]))
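The rounds are weighted by a softmax of their logsumexp values for a specific reason. Suppose each round attends to a disjoint set of keys. In that idealized case, this weighting exactly reproduces a single softmax over all the keys. The sketch below checks this with toy data: 8 keys split into two made-up "rounds". In real LSH attention the rounds overlap, so treat this as intuition rather than an exact identity:

```python
import torch

torch.manual_seed(0)
n_keys, dim = 8, 4
scores = torch.randn(n_keys)        # scores of one query against all keys
values = torch.randn(n_keys, dim)

# Full softmax attention over all keys
full = torch.softmax(scores, dim=0) @ values

# Split the keys into two disjoint "rounds" (toy assumption)
groups = [torch.arange(0, 4), torch.arange(4, 8)]
outs, lses = [], []
for g in groups:
    lse = torch.logsumexp(scores[g], dim=0)              # per-round logit (like dots_logsumexp)
    outs.append(torch.exp(scores[g] - lse) @ values[g])  # per-round attention output
    lses.append(lse)
o = torch.stack(outs)                  # [n_rounds, dim]
logits = torch.stack(lses)[:, None]    # [n_rounds, 1]

# Same combination as above: softmax of the logits over the rounds dimension
probs = torch.exp(logits - torch.logsumexp(logits, dim=0, keepdim=True))
combined = torch.sum(o * probs, dim=0)
```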

### LSHAttention - minimal implementation

```python
# torch, nn and F come from `from fastai.vision.all import *`.
# hash_vectors, sort_key_val, batched_index_select and look_one_back are assumed to be defined earlier in this post.
class LSHAttention(nn.Module):
    def __init__(self, bucket_size=64, n_hashes=8):
        super().__init__()
        self.bucket_size = bucket_size
        self.n_hashes = n_hashes

    def forward(self, qk, v, **kwargs):
        batch_size, seqlen, dim, device = *qk.shape, qk.device

        assert seqlen % (self.bucket_size * 2) == 0, \
            f'Sequence length ({seqlen}) needs to be divisible by target bucket size x 2 - {self.bucket_size * 2}'

        # Get buckets. We use the above method
        n_buckets = seqlen // self.bucket_size
        buckets = hash_vectors(qk, n_buckets, self.n_hashes)   # buckets: [bs, seqlen*n_hashes]

        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        # ticker: [bs, seqlen*n_hashes], where ticker[0, :] = [0, 1, 2, ..., seqlen*n_hashes-1]
        ticker = torch.arange(self.n_hashes * seqlen, device=device).unsqueeze(0).expand_as(buckets)

        # ticker % seqlen = [0...seqlen-1, 0...seqlen-1, ...] repeated n_hashes times.
        # Adding the bucket id scaled by seqlen lets us sort by bucket id first,
        # then by index in the sequence. Shape: [bs, seqlen*n_hashes]
        buckets_and_t = seqlen * buckets + (ticker % seqlen)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)  # [bs, seqlen*n_hashes]
        _, undo_sort = sticker.sort(dim=-1)                                    # indices to undo the sort
        del ticker

        st = (sticker % seqlen)             # index in [0..seqlen-1] for each hash round, [bs, seqlen*n_hashes]
        sqk = batched_index_select(qk, st)  # sorted qk, [bs, seqlen*n_hashes, model_dim]
        sv = batched_index_select(v, st)    # sorted v,  [bs, seqlen*n_hashes, model_dim]

        # Split off a "bin" axis so that attention only occurs within chunks.
        # Get the qk and v chunks, plus the position indices used for masking
        n_chunks = self.n_hashes * n_buckets
        bq_t = bkv_t = torch.reshape(st, (batch_size, n_chunks, -1))  # [bs, n_chunks, chunk_size]
        bqk = torch.reshape(sqk, (batch_size, n_chunks, -1, dim))     # [bs, n_chunks, chunk_size, model_dim]
        bv = torch.reshape(sv, (batch_size, n_chunks, -1, dim))       # [bs, n_chunks, chunk_size, model_dim]

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = F.normalize(bqk, p=2, dim=-1).type_as(bq)

        # Attend to the previous chunk as well: append it along dim=2 (the chunk_size dimension)
        bk = look_one_back(bk)        # [bs, n_chunks, chunk_size*2, model_dim]
        bv = look_one_back(bv)        # [bs, n_chunks, chunk_size*2, model_dim]
        bkv_t = look_one_back(bkv_t)  # [bs, n_chunks, chunk_size*2]

        # Dot-product attention
        dots = torch.einsum('bnsd,bnzd->bnsz',
                            bq,   # [bs, n_chunks, chunk_size, model_dim]
                            bk    # [bs, n_chunks, chunk_size*2, model_dim]
                            ) * (dim ** -0.5)  # dots: [bs, n_chunks, chunk_size, chunk_size*2]

        # Mask out attention to self except when no other targets are available.
        # -5e4 (instead of the most negative float) keeps self-attention as a fallback.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots.masked_fill_(self_mask, -5e4)

        # Softmax
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp).type_as(dots)

        # Self-attention output (attn * values)
        bo = torch.einsum('bnsz,bnzd->bnsd', dots, bv)                   # [bs, n_chunks, chunk_size, model_dim]

        # Unchunk, unsort and reshape self-attention
        so = torch.reshape(bo, (batch_size, -1, dim))                    # [bs, seqlen*n_hashes, model_dim]
        o = batched_index_select(so, undo_sort)                          # [bs, seqlen*n_hashes, model_dim]
        o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, dim))   # [bs, n_hashes, seqlen, model_dim]

        # Unchunk, unsort and reshape logits
        slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))              # [bs, seqlen*n_hashes]
        logits = slogits.gather(1, undo_sort)                                   # [bs, seqlen*n_hashes]
        logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))  # [bs, n_hashes, seqlen, 1]

        # Weight each hash round (dim 1) by its softmax probability and sum the rounds
        probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))  # [bs, n_hashes, seqlen, 1]
        out = torch.sum(o * probs, dim=1)                                         # [bs, seqlen, model_dim]

        # Return output and bucket distribution
        return out, buckets

qk = torch.randn(64, 512, 128)
v = torch.rand(64, 512, 128)

lsh_att = LSHAttention()
out, buckets = lsh_att(qk, v)
out.shape, buckets.shape
```

(torch.Size([64, 512, 128]), torch.Size([64, 4096]))