from fastai.text.all import *
from reformer_fastai.expscript import get_twin_sequence_dataloaders, get_lshlm_model, get_synthetic_learner
from reformer_fastai.data import MaskTargCallback
Settings from the experiment. Note: the LSH models were run with seed 42, the full-attention model with seed 1234 (it didn't converge with 42).
bs=64
sl=1024
train_sz=12800
valid_sz=1280
n_epochs=750
seed=42
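# The synthetic twin-sequence task from the Reformer paper: each example has the
# form 0w0w, and the model must learn to copy the first occurrence of w when
# predicting the second half, so only the duplicated positions are scored.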
dls = get_twin_sequence_dataloaders(bs=bs, sl=sl, train_sz=train_sz, valid_sz=valid_sz, seed=seed)
n_hashes=1
bucket_size=64 # suggested in trax
vocab_sz=128 # specific for the synthetic task
d_model=256
n_layers=1 # specified in paper
n_heads=4
d_ff=256
attn_dropout=0
ff_dropout=0
emb_dropout=0
max_seq_len=sl
causal=True
use_lsh=True
def load_learner(n_hashes, use_lsh, fn, d_ff):
    "Build a model with the given attention settings and load its saved weights."
    model = get_lshlm_model(vocab_sz=vocab_sz, d_model=d_model, n_layers=n_layers, n_heads=n_heads,
                            max_seq_len=max_seq_len, bucket_size=bucket_size, n_hashes=n_hashes, causal=causal,
                            use_lsh=use_lsh, seed=seed, attn_dropout=attn_dropout, ff_dropout=ff_dropout,
                            emb_dropout=emb_dropout, d_ff=d_ff)
    learn = get_synthetic_learner(dls, model)
    learn = learn.load(fn)
    return learn
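fn1, fn2, fn4 and fn_full are the filenames of the saved model checkpoints, defined in a cell not shown here. A minimal sketch with hypothetical names (replace with the filenames from your own training runs):
fn1 = 'lsh_1_hash'          # hypothetical: LSH model trained with n_hashes=1
fn2 = 'lsh_2_hashes'        # hypothetical: LSH model trained with n_hashes=2
fn4 = 'lsh_4_hashes'        # hypothetical: LSH model trained with n_hashes=4
fn_full = 'full_attention'  # hypothetical: full-attention model (seed 1234)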
learn_lsh1 = load_learner(n_hashes=1, use_lsh=True, fn=fn1, d_ff=256)
learn_lsh2 = load_learner(n_hashes=2, use_lsh=True, fn=fn2, d_ff=256)
learn_lsh4 = load_learner(n_hashes=4, use_lsh=True, fn=fn4, d_ff=256)
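# sanity check: each model should report the n_hashes it was trained with -> (1, 2, 4)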
learn_lsh1.model.n_hashes, learn_lsh2.model.n_hashes, learn_lsh4.model.n_hashes
This model was trained with a different seed, since it didn't converge with the one used for the LSH models. Note that n_hashes=6 is set in the config, but it is not used when use_lsh=False.
seed=1234
dls = get_twin_sequence_dataloaders(bs=bs, sl=sl, train_sz=train_sz, valid_sz=valid_sz, seed=seed)
learn_full = load_learner(n_hashes=6, use_lsh=False, fn=fn_full, d_ff=256)
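# confirm LSH is disabled for the full-attention model -> False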
learn_full.model.use_lsh
Validate the LSH models, evaluating each with n_hashes set to 8, 4, 2 and 1:
res = []
for learner in [learn_lsh4, learn_lsh2, learn_lsh1]:
    train_hashes = learner.model.n_hashes
    for eval_hashes in [8, 4, 2, 1]:
        learner.model.n_hashes = eval_hashes
        _, m_acc = learner.validate(cbs=MaskTargCallback)
        res.append((f'LSH-{train_hashes}', f'LSH-{eval_hashes}', m_acc))
    learner.model.n_hashes = train_hashes  # reset n_hashes
Evaluate the LSH models with full attention:
for learner in [learn_lsh4, learn_lsh2, learn_lsh1]:
    learner.model.use_lsh = False
    _, m_acc = learner.validate(cbs=MaskTargCallback)
    res.append((f'LSH-{learner.model.n_hashes}', 'Full Attention', m_acc))
    learner.model.use_lsh = True  # reset
Validate the model trained with full attention, evaluating with both full attention and LSH:
_,m_acc = learn_full.validate(cbs=MaskTargCallback)
res.append(('Full Attention','Full Attention', m_acc))
# validate with LSH-8, 4, 2 and 1
learn_full.model.use_lsh = True
for n_hashes in [8, 4, 2, 1]:
    learn_full.model.n_hashes = n_hashes
    _, m_acc = learn_full.validate(cbs=MaskTargCallback)
    res.append(('Full Attention', f'LSH-{learn_full.model.n_hashes}', m_acc))
cols = ['Train', 'Eval', 'Masked_Accuracy']
df = pd.DataFrame(res, columns=cols)
df['Masked_Accuracy'] = df['Masked_Accuracy'].round(4)*100  # convert to percent
df = df.pivot_table(index=cols[0],
                    columns=cols[1],
                    values=cols[2])
# reorder rows (Full Attention, LSH-4, LSH-2, LSH-1) and columns (Full Attention, LSH-8, LSH-4, LSH-2, LSH-1)
df = df.iloc[[0, 3, 2, 1], [0, 4, 3, 2, 1]]
df