from fastai.text.all import *
from torch.utils.data import Dataset
from reformer_fastai.reformer import LSHLM
Copied from the synthetic task notebook. Consider adding it to core.
class TwinSequence(Dataset):
    def __init__(self, sl=1024, len=100):
        assert sl%2 == 0
        self.sl = sl
        self.len = len
    def __getitem__(self, idx):
        seq = torch.randint(1,128,(self.sl//2,))             # w: random tokens in [1,127] of length sl//2
        seq[0] = 0                                           # seq = 0w
        seq = torch.cat((seq,seq), -1)                       # seq = 0w0w
        target = torch.cat((seq[1:],torch.tensor([0])), -1)  # target is seq shifted left by one: x=[0,1,2,3], y=[1,2,3,0]
        return (seq, target)
    def __len__(self):
        return self.len
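A quick sanity check of the data format (not part of the original notebook, just for illustration): the second half of each sequence repeats the first, and the target is the input shifted left by one, so the second half is predictable from the history while the first half is pure noise.

ds = TwinSequence(sl=16, len=1)
x, y = ds[0]
assert x[0] == 0 and torch.equal(x[:8], x[8:])  # x = 0w0w: second half repeats the first
assert torch.equal(y[:-1], x[1:])               # y is x shifted left by one and padded with 0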
class MaskTargCallback(Callback):
    def before_batch(self):
        # mask out the first half of the targets so only the repeated second half is scored
        self.y[:, :self.dls.train_ds.sl//2] = -100
def masked_accuracy(inp, targ, ignore=-100):
    pred = inp.argmax(dim=-1)
    # the masked positions are the same for every item in the batch,
    # so the mask can be taken from the first target sequence
    mask = targ[0] != ignore
    return (pred[:,mask] == targ[:,mask]).float().mean()
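A toy illustration (an assumption for exposition, not from the original notebook) of how the masking and the metric interact: the first half of the targets is set to -100, so masked_accuracy only scores predictions on the repeated second half.

x, y = TwinSequence(sl=8, len=1)[0]
y = y.unsqueeze(0).clone()
y[:, :4] = -100                     # what MaskTargCallback does to each batch
logits = torch.zeros(1, 8, 128)     # dummy model output that always predicts token 0
masked_accuracy(logits, y)          # accuracy computed over the unmasked second half only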
Sequence length of 1024, similar to what is used in the paper.
bs, sl = 32,1024
n_epochs = 5
train_sz = 50_000
valid_sz = 10_000
dls = DataLoaders.from_dsets(TwinSequence(sl, train_sz), TwinSequence(sl, valid_sz), bs=bs, shuffle=False, device='cuda')
Total training steps:
len(dls.train)*n_epochs
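With train_sz = 50,000 and bs = 32 that works out to roughly 50,000 / 32 ≈ 1,562 batches per epoch, i.e. on the order of 7,800 optimizer steps over the 5 epochs (the exact number depends on how the final partial batch is handled).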
We set n_hashes to 1. The paper reaches around 80% masked accuracy with this setup. Other hyperparameters aren't specified in the paper.
n_hashes=1
bucket_size = 64
assert sl % (bucket_size * 2) == 0
model = LSHLM(vocab_sz=128, d_model=256, n_layers=1, n_heads=4, max_seq_len=sl,
bucket_size=bucket_size, n_hashes=n_hashes, causal=True)
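The assert above reflects how LSH attention processes the sequence: queries are sorted by hash bucket and attended to in chunks of roughly bucket_size * 2 positions, so the sequence length needs to divide evenly by bucket_size * 2. With sl = 1024 and bucket_size = 64 that gives 1024 / 128 = 8 chunks per sequence.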
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(ignore_index=-100),
metrics=masked_accuracy, cbs=[MaskTargCallback()])
learn.fit_one_cycle(n_epochs, 1e-3)
learn.recorder.plot_loss()
del learn
torch.cuda.empty_cache()
As expected from the paper's results, the model struggles to learn the task.
Let's try our ReformerLM.
from reformer_fastai.reformer import ReformerLM as fastReformerLM
n_hashes=1
bucket_size = 64
assert sl % (bucket_size * 2) == 0
model = fastReformerLM(128, 256, d_ff=256, n_layers=1, n_heads=4, max_seq_len=sl,attn_dropout=0,
ff_dropout=0, bucket_size=bucket_size, n_hashes=n_hashes, causal=True)
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(ignore_index=-100),
metrics=masked_accuracy, cbs=[MaskTargCallback()])
learn.fit_one_cycle(n_epochs, 1e-3)
learn.recorder.plot_loss()
del learn
torch.cuda.empty_cache()
Similar performance to LSHLM.
Testing the lucidrains implementation.
from reformer_pytorch import ReformerLM
n_hashes=1
bucket_size = 64
assert sl % (bucket_size * 2) == 0
model = ReformerLM(num_tokens=128, dim=256, depth=1, max_seq_len = sl, heads=4, bucket_size=bucket_size,
n_hashes=n_hashes, causal=True)
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(ignore_index=-100),
metrics=masked_accuracy, cbs=[MaskTargCallback()])
learn.fit_one_cycle(n_epochs, 1e-3)
learn.recorder.plot_loss()
del learn
torch.cuda.empty_cache()
This implementation has similar performance to our LSHLM.