from fastai.text.all import *
from reformer_fastai.expscript import get_twin_sequence_dataloaders, get_lshlm_model, get_synthetic_learner
from reformer_fastai.data import MaskTargCallback

Get dataset

Settings from the experiment. Note: the LSH models were run with seed 42; the full-attention model was run with seed 1234 because it didn't converge with seed 42.

# Experiment settings shared by the dataloaders and the models below.
bs = 64
sl = 1024  # also used as the models' max_seq_len
train_sz = 12800
valid_sz = 1280
n_epochs = 750
seed = 42  # seed the LSH models were run with
dls = get_twin_sequence_dataloaders(
    bs=bs,
    sl=sl,
    train_sz=train_sz,
    valid_sz=valid_sz,
    seed=seed,
)

Load lsh-learners

# Model hyperparameters (read as globals by load_learner below).
n_hashes = 1
bucket_size = 64  # suggested in trax
vocab_sz = 128    # specific for the synthetic task
d_model = 256
n_layers = 1      # specified in paper
n_heads = 4
d_ff = 256

# All dropout is switched off.
attn_dropout = 0
ff_dropout = 0
emb_dropout = 0

max_seq_len = sl
causal = True
use_lsh = True
def load_learner(n_hashes, use_lsh, fn, d_ff):
    """Build an LSH language model from the notebook settings and load saved weights.

    The model hyperparameters, ``seed`` and the dataloaders ``dls`` are read
    from module-level globals, so ``dls``/``seed`` must be (re)assigned before
    calling this for a model trained with different data settings.

    NOTE: this name shadows the ``load_learner`` brought in by
    ``from fastai.text.all import *``.

    Args:
        n_hashes: number of hash rounds passed to ``get_lshlm_model``.
        use_lsh: whether the model uses LSH attention (False -> full attention).
        fn: checkpoint name passed to ``Learner.load``.
        d_ff: passed through to ``get_lshlm_model`` as ``d_ff``.

    Returns:
        The synthetic-task learner with the weights from ``fn`` loaded.
    """
    shared_kwargs = dict(
        vocab_sz=vocab_sz,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        max_seq_len=max_seq_len,
        bucket_size=bucket_size,
        causal=causal,
        seed=seed,
        attn_dropout=attn_dropout,
        ff_dropout=ff_dropout,
        emb_dropout=emb_dropout,
    )
    model = get_lshlm_model(n_hashes=n_hashes, use_lsh=use_lsh, d_ff=d_ff, **shared_kwargs)
    return get_synthetic_learner(dls, model).load(fn)
# Load the LSH learners trained with 1, 2 and 4 hash rounds.
learn_lsh1 = load_learner(fn=fn1, n_hashes=1, use_lsh=True, d_ff=256)
learn_lsh2 = load_learner(fn=fn2, n_hashes=2, use_lsh=True, d_ff=256)
learn_lsh4 = load_learner(fn=fn4, n_hashes=4, use_lsh=True, d_ff=256)
# Sanity check: each model reports the n_hashes it was built with.
(learn_lsh1.model.n_hashes, learn_lsh2.model.n_hashes, learn_lsh4.model.n_hashes)
(1, 2, 4)

Load the full-attention model (trained with a different seed)

This model was trained with a different seed, since it didn't converge with the seed used for the LSH models. Note that n_hashes=6 is set in the config, but it is not used when use_lsh=False.

# The full-attention model was run with seed 1234, so rebuild the dataloaders
# with that seed first (load_learner reads the global ``dls`` and ``seed``).
seed = 1234
dls = get_twin_sequence_dataloaders(
    bs=bs,
    sl=sl,
    train_sz=train_sz,
    valid_sz=valid_sz,
    seed=seed,
)
# n_hashes=6 comes from the config; it is not used when use_lsh=False.
learn_full = load_learner(fn=fn_full, n_hashes=6, use_lsh=False, d_ff=256)
learn_full.model.use_lsh
False

Validate LSH-models with changing n_hashes:

Validate each LSH model with n_hashes = 1, 2, 4 and 8.

# For each LSH-trained model, evaluate with 8, 4, 2 and 1 hash rounds and
# record (train config, eval config, masked accuracy).
# The training n_hashes is restored in a `finally`. Without it, an exception
# mid-sweep would leave the learner at an eval setting, and a re-run of this
# cell would then record the wrong "LSH-{train_hashes}" label.
res = []
for learner in [learn_lsh4, learn_lsh2, learn_lsh1]:
    train_hashes = learner.model.n_hashes
    try:
        for eval_hashes in [8, 4, 2, 1]:
            learner.model.n_hashes = eval_hashes
            _, m_acc = learner.validate(cbs=MaskTargCallback)
            res.append((f'LSH-{train_hashes}', f'LSH-{eval_hashes}', m_acc))
    finally:
        learner.model.n_hashes = train_hashes  # reset n_hashes

Evaluate the LSH models with full attention:

# Evaluate each LSH-trained model with full attention by switching LSH off.
# The original use_lsh value is restored per learner, inside the loop and in a
# `finally`, so that every model is reset, not just the last one iterated.
for learner in [learn_lsh4, learn_lsh2, learn_lsh1]:
    orig_use_lsh = learner.model.use_lsh
    learner.model.use_lsh = False
    try:
        _, m_acc = learner.validate(cbs=MaskTargCallback)
        res.append((f'LSH-{learner.model.n_hashes}', 'Full Attention', m_acc))
    finally:
        learner.model.use_lsh = orig_use_lsh  # reset

Validate full attention model

Validate the model trained with full attention, evaluating it both with full attention and with LSH attention (n_hashes = 1, 2, 4, 8).

# Baseline: the full-attention model evaluated with full attention.
_, m_acc = learn_full.validate(cbs=MaskTargCallback)
res.append(('Full Attention', 'Full Attention', m_acc))

# Validate the same model with LSH attention at 8, 4, 2 and 1 hash rounds,
# then restore its original configuration so later cells see it as loaded.
# The loop variable is named eval_hashes so it does not overwrite the
# module-level n_hashes setting.
orig_use_lsh = learn_full.model.use_lsh
orig_hashes = learn_full.model.n_hashes
learn_full.model.use_lsh = True
try:
    for eval_hashes in [8, 4, 2, 1]:
        learn_full.model.n_hashes = eval_hashes
        _, m_acc = learn_full.validate(cbs=MaskTargCallback)
        res.append(('Full Attention', f'LSH-{eval_hashes}', m_acc))
finally:
    learn_full.model.use_lsh = orig_use_lsh
    learn_full.model.n_hashes = orig_hashes

Summarize results

# Collect the results into a Train x Eval table of masked accuracy in percent.
cols = ['Train', 'Eval', 'Masked_Accuracy']
df = pd.DataFrame(res, columns=cols)
# Scale to percent before rounding. Rounding the fraction first and then
# multiplying by 100 can reintroduce floating-point noise in the result.
df['Masked_Accuracy'] = (df['Masked_Accuracy'] * 100).round(2)
df = df.pivot_table(index=cols[0], columns=cols[1], values=cols[2])
# Order rows and columns by label, not by position. pivot_table sorts labels,
# so positional iloc indices silently pick the wrong rows/columns if the set
# of evaluated configurations changes. .loc raises KeyError instead.
row_order = ['Full Attention', 'LSH-4', 'LSH-2', 'LSH-1']
col_order = ['Full Attention', 'LSH-8', 'LSH-4', 'LSH-2', 'LSH-1']
df = df.loc[row_order, col_order]
df
Eval Full Attention LSH-8 LSH-4 LSH-2 LSH-1
Train
Full Attention 100.00 1.37 1.85 3.00 4.56
LSH-4 46.54 99.71 99.77 93.05 77.62
LSH-2 75.94 96.60 97.45 97.08 86.06
LSH-1 70.65 76.61 79.68 79.34 56.09

table2