from fastai.text.all import *
from reformer_fastai.all import *
Experiment tracking with Weights & Biases: make sure you have wandb installed and are logged in.
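If wandb isn't set up in this environment yet, a minimal sketch of the setup (standard wandb CLI usage, not specific to this repo; `wandb login` will prompt for your API key):
!pip install wandb
!wandb login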
import wandb
from reformer_fastai.tracking import WandbCallback
WANDB_NAME = 'enc_lm_enwik8_reversible_af'
GROUP = 'TEST'
NOTES = 'ReversibleLM on enwik8 sl 4096'
CONFIG = {}
TAGS = ['lm','rev','enwik8']
# Download and extract enwik8 (the first 100M bytes of an English Wikipedia dump)
path = untar_data('http://mattmahoney.net/dc/enwik8.zip', dest='./data')
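To confirm the download, you can list what was extracted (`ls` is fastai's `Path` convenience method; the exact layout depends on how untar_data unpacks the archive):
path.ls()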
df = pd.DataFrame({'text':read_lines(path)})  # one row per line of the raw text
df.head()
btt = ByteTextTokenizer(is_lm=True, add_bos=False, add_eos=False)
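ByteTextTokenizer encodes text at the byte level, so the vocabulary stays tiny. A quick sanity check (a sketch; the exact ids depend on the tokenizer's special-token offsets):
btt('hello world'), btt.vocab_size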
%%time
df['toks'] = df['text'].apply(btt)
df['lens'] = df['toks'].apply(len)
df['lens_cum_sum'] = df.lens.cumsum()
train_cutoff = df.lens.sum() - 10_000_000 # keep all but 10M characters for val and test
train_idxs = df.loc[df['lens_cum_sum'] < train_cutoff].index.values
train_idxs = list(range(0, max(train_idxs)))  # contiguous block of training rows
remaining_idxs = len(df) - len(train_idxs)
# split the remaining rows evenly between validation and test, avoiding the
# off-by-one overlap that starting each range at max(prev_idxs) would introduce
validation_idxs = list(range(max(train_idxs) + 1, max(train_idxs) + 1 + remaining_idxs // 2))
test_idxs = list(range(max(validation_idxs) + 1, len(df)))
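A quick check that the three splits are contiguous and disjoint (assumes the split code above):
assert train_idxs[-1] + 1 == validation_idxs[0]
assert validation_idxs[-1] + 1 == test_idxs[0]
assert test_idxs[-1] == len(df) - 1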
splits = [train_idxs, validation_idxs]
tfms = [attrgetter("text"), btt]
dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)
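Datasets applies the transforms lazily; peeking at one training sample should show a tensor of byte-token ids (a sketch, following fastai's Datasets indexing conventions):
len(dsets.train), len(dsets.valid), dsets.train[0][0][:10]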
%%time
bs, sl = 2, 4096
# pass precomputed lengths so each LMDataLoader can skip recomputing them
# (the keyword is 'lens' for the validation loader too)
dl_kwargs = [{'lens':df['lens'].values[train_idxs]},
             {'lens':df['lens'].values[validation_idxs]}]
dls = dsets.dataloaders(bs=bs, val_bs=2*bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True, n_workers=2)
dls.show_batch(max_n=2)
vocab_sz = btt.vocab_size
xb, yb = dls.one_batch()
xb.shape, yb.shape
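Both tensors should come back as (bs, sl), with yb simply xb shifted one token ahead, as is standard for fastai language-model dataloaders:
assert xb.shape == yb.shape == (bs, sl)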
Initialise wandb logging. Please do not change project or entity (so that everything gets logged to the same place).
wandb.init(reinit=True, project="reformer-fastai", entity="fastai_community",
name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)
opt_func = adafactor  # Adafactor keeps optimizer state small, which helps at long sequence lengths
learn = Learner(dls, ReversibleLM(vocab_sz, 1024, n_layers=3, max_seq_len=sl, rev_thres=4097), # rev_thres > sl, so these 4096-token batches take the faster non-reversible path
                loss_func=CrossEntropyLossFlat(), opt_func=opt_func,
                metrics=[accuracy, perplexity, bpc],
                cbs=[GradientAccumulation(n_acc=8), GradientClip(), TerminateOnNaNCallback()]).to_fp16()
learn.fit(1, cbs=WandbCallback(log_model=False, log_preds=False))
learn.recorder.plot_loss()