Fit deep models using very long sequences
n_epochs = 4
bs = 1          # batch size per step; the effective batch size is set by gradient accumulation below
sl = 2**14      # sequence length: 16,384 tokens
n_layers = 3
seed = 2
Make sure you have wandb installed and that you are logged in.
Set up experiment tracking with Weights & Biases:
import wandb
import pandas as pd
from fastai.text.all import *
from reformer_fastai.all import *  # ByteTextTokenizer, ReformerLM, etc.; exact module path may vary
WANDB_NAME = f'n_layers-{n_layers}_enwik8_sl-{sl}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
GROUP = 'TEST'
NOTES = 'ReformerLM on enwik8 sl=16k'
CONFIG = {}
TAGS = ['lm','reformer','enwik8', 'test']
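CONFIG is left empty above; if you also want the hyperparameters tracked in the W&B run, you could populate it first, e.g.:
CONFIG = dict(n_layers=n_layers, seq_len=sl, bs=bs, n_epochs=n_epochs, seed=seed)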
path = untar_data('http://mattmahoney.net/dc/enwik8.zip', dest='/data')  # download and extract enwik8
df = pd.DataFrame({'text':read_lines(path)})
df.head()
btt = ByteTextTokenizer(is_lm=True, add_bos=False, add_eos=False)
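ByteTextTokenizer maps each byte of the text to an integer id, so the vocabulary is tiny and no tokenizer training is needed. A quick look at its output (the exact ids depend on the tokenizer's special-token offsets):
btt('hello')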
%%time
df['toks'] = df['text'].apply(btt)
df['lens'] = df['toks'].apply(len)
df['lens_cum_sum'] = df.lens.cumsum()
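enwik8 is the first 100M bytes of an English Wikipedia dump, so with one token per byte the total should come out close to 10**8:
df.lens.sum()  # expect roughly 100M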
train_cutoff = df.lens.sum() - 10_000_000  # reserve the last 10M characters for validation and test
# indices of all lines that fall entirely before the cutoff
train_idxs = df.loc[df['lens_cum_sum'] < train_cutoff].index.values
n_train = len(train_idxs)
train_idxs = list(range(n_train))
# split the remaining lines evenly between validation and test
remaining = len(df) - n_train
validation_idxs = list(range(n_train, n_train + remaining // 2))
test_idxs = list(range(n_train + remaining // 2, len(df)))
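A few quick assertions confirm that the three splits are contiguous, disjoint, and together cover every line:
assert train_idxs[-1] + 1 == validation_idxs[0]
assert validation_idxs[-1] + 1 == test_idxs[0]
assert len(train_idxs) + len(validation_idxs) + len(test_idxs) == len(df)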
splits = [train_idxs, validation_idxs]  # test_idxs are held back for final evaluation
tfms = [attrgetter("text"), btt]  # pull out the text column, then byte-tokenize it
dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)
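You can spot-check a processed sample; each item is a tuple holding one tokenized line:
dsets.train[0][0][:20]  # first 20 byte-token ids of the first training line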
%%time
dl_kwargs = [{'lens': df['lens'].values[train_idxs]},
             {'val_lens': df['lens'].values[validation_idxs]}]
dls = dsets.dataloaders(bs=bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True, n_workers=2)
dls.show_batch()
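It is worth checking how many batches one epoch will contain at this batch size and sequence length:
len(dls.train), len(dls.valid)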
vocab_sz = btt.vocab_size  # byte-level vocab: 256 values plus any special tokens
xb, yb = dls.one_batch()
xb.shape, yb.shape
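Both tensors should be (bs, sl), i.e. torch.Size([1, 16384]), with yb being xb shifted one token to the right, as usual for language modelling:
assert xb.shape == yb.shape == (bs, sl)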
wandb.init(reinit=True, project="reformer-fastai", entity="fastai_community",
           name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)
pad_id = btt.pad_token_id
config = NLayersConfig(n_layers=n_layers, max_seq_len=sl, pad_idx=pad_id, seed=seed)
config
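LSH attention chunks the sequence into fixed-size buckets, which is why batches are padded to a multiple of the bucket size by PadBatchCallback below. A quick sanity check, assuming the config exposes bucket_size (it is read from config further down):
assert sl % config.bucket_size == 0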
learn = Learner(dls, ReformerLM.from_config(config),
                loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
                cbs=[GradientAccumulation(n_acc=8), GradientClip(1.0),
                     PadBatchCallback(bucket_size=config.bucket_size,
                                      val=pad_id, y_val=pad_id)],
                metrics=[accuracy, perplexity, bpc])
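With bs=1 and GradientAccumulation(n_acc=8), the optimizer steps once every 8 sequences, so the effective batch size is 8. Before launching a long run it can also help to know the model size; a minimal sketch:
n_params = sum(p.numel() for p in learn.model.parameters())
print(f'{n_params:,} parameters')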
learn.fit(n_epochs, cbs=WandbCallback(log_model=False, log_preds=False))
learn.recorder.plot_loss()
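Finally, close the W&B run explicitly so the last metrics are flushed (useful in notebooks, where the kernel keeps running):
wandb.finish()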