from fastai.text.all import *
from reformer_fastai.all import *

Experiment Tracking

Make sure you have wandb and are logged in.

Load Experiment Tracking with Weights & Biases:

import wandb
from reformer_fastai.tracking import WandbCallback

WANDB_NAME = 'lm_enwik8_base_af'
GROUP = None   # optional wandb run group; set a name to group related runs
NOTES = 'Baseline Transformer LM on enwik8 sl 4096'
TAGS = ['lm','test','enwik8']
CONFIG = {}    # placeholder: add any hyperparameters you want logged to wandb

Download and Unpack enwik8 Data

Download and unzip enwik8 data

path = untar_data('', dest='/data')

Prepare Data

df = pd.DataFrame({'text':read_lines(path)})
0 <mediawiki xmlns="" xmlns:xsi="" xsi:schemaLocation="" version="0.3" xml:lang="en">\n
1 <siteinfo>\n
2 <sitename>Wikipedia</sitename>\n
3 <base></base>\n
4 <generator>MediaWiki 1.6alpha</generator>\n
btt = ByteTextTokenizer(is_lm=True, add_bos=False, add_eos=False)
df['toks'] = df['text'].apply(btt)
df['lens'] = df['toks'].apply(len)
df['lens_cum_sum'] = df.lens.cumsum()
CPU times: user 2min 24s, sys: 2.93 s, total: 2min 27s
Wall time: 2min 26s
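The cell above turns each line into byte-level tokens, so lens is a per-line token count and lens_cum_sum its running total over the whole file. A minimal sketch of that bookkeeping, assuming ByteTextTokenizer emits one id per UTF-8 byte (the real tokenizer may offset ids to reserve special tokens; byte_tokenize is a hypothetical stand-in), shows why a non-ASCII character counts as more than one token:

```python
from itertools import accumulate


def byte_tokenize(text: str) -> list[int]:
    """Hypothetical stand-in for a byte-level tokenizer: one id per UTF-8 byte."""
    return list(text.encode("utf-8"))


toy_lines = ["<siteinfo>\n", "<sitename>Wikipedia</sitename>\n", "Kraków\n"]
toy_lens = [len(byte_tokenize(line)) for line in toy_lines]  # tokens per line
toy_cum = list(accumulate(toy_lens))                          # running total, like lens_cum_sum
print(toy_lens, toy_cum)  # → [11, 31, 8] [11, 42, 50]
```

"Kraków\n" has 7 characters but 8 tokens, because "ó" takes two bytes in UTF-8.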
train_cutoff = df.lens.sum() - 10_000_000  # hold out the last 10M tokens for validation and test
# Lines whose cumulative length stays below the cutoff go to train
train_end = int(df.loc[df['lens_cum_sum'] < train_cutoff].index.max()) + 1  # first line index not in train
train_idxs = list(range(0, train_end))

# Split the remaining lines evenly between validation and test, without overlapping indices
remaining_idxs = len(df) - train_end
valid_end = train_end + remaining_idxs // 2
validation_idxs = list(range(train_end, valid_end))
test_idxs = list(range(valid_end, len(df)))
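The split is contiguous: train takes the leading lines until the cumulative length reaches the cutoff, and the remaining lines are halved between validation and test. A pure-Python sketch of the same rule on toy lengths (contiguous_splits is a hypothetical helper, not part of reformer_fastai) makes it easy to check that the three ranges cover every line exactly once:

```python
def contiguous_splits(lens, holdout):
    """Split line indices into contiguous, non-overlapping train/valid/test ranges.

    Train keeps every leading line whose cumulative length is below
    sum(lens) - holdout; the remaining lines are halved between valid and test.
    """
    cutoff = sum(lens) - holdout
    cum, train_end = 0, 0             # train_end: first index NOT in train
    for i, n in enumerate(lens):
        cum += n
        if cum >= cutoff:
            break
        train_end = i + 1
    valid_end = train_end + (len(lens) - train_end) // 2
    return (list(range(0, train_end)),
            list(range(train_end, valid_end)),
            list(range(valid_end, len(lens))))


toy_lens = [10] * 8                   # 8 toy lines, 80 tokens in total
train, valid, test = contiguous_splits(toy_lens, holdout=30)
print(train, valid, test)  # → [0, 1, 2, 3] [4, 5] [6, 7]
```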

splits = [train_idxs, validation_idxs]
tfms = [attrgetter("text"), btt]
dsets = Datasets(df, [tfms], splits=splits, dl_type=LMDataLoader)
CPU times: user 1.1 s, sys: 40.1 ms, total: 1.14 s
Wall time: 1.14 s
bs, sl = 2, 4096
# Precomputed token lengths per split, so LMDataLoader does not need to measure each line itself
dl_kwargs = [{'lens':df['lens'].values[train_idxs]},
             {'lens':df['lens'].values[validation_idxs]}]
dls = dsets.dataloaders(bs=bs, val_bs=2*bs, seq_len=sl, dl_kwargs=dl_kwargs, shuffle_train=True, num_workers=2)
CPU times: user 31.9 s, sys: 1.19 s, total: 33.1 s
Wall time: 32.7 s
text text_
0 == External links ==\n <revision>\n* [ &quot;The Goths in Greater Poland&quot; by Tadeusz Makiewicz]\n\n <username>RuM</username>\n*Audio interview on [[Australian]] radio station [[JJJ]], on [[January 26]], [[2006]] : [ MP3 Link].\n[[Category:Obsolete list of encyclopedia topics]]</text>\n&lt;/tr&gt;\n\n[[Category:Sibling duos|Ertegun brothers]]\nIn [[1920]], Russell travelled to [[Russia]] as part of an official delegation sent by the British government to investigate the effects of the [[Russian Revolution of 1917|Russian Revolution]]. Russell's lover [[Dora Black]] also visited Russia independently at the same time - she was enthusiastic about the revolution, but Russell's experiences destroyed his previous tentative support for it.\n* [[kaffir lime]] leaves\nThe Commission negotiates international [[trade]] agreements (in the [[World Trade Organization]]) and other international agreements on behalf of the EU. It closely co-operates in this with the [[Council of the European Union]].\n <page>\nSpink, Walter M. "The Achievement of Ajanta," ''The Age of the Vakatakas'', ed. Shastri,
1 found in Obadiah 10-21 which [[Book of Jeremiah|Jeremiah]] does not quote, and which, had he had it laid out before him, would have suited his purpose admirably. Despite everything, however, there are a number scholars who support both dates and even some who support dates other than the two major possibilities presented. Therefore, any date for the composition Obadiah must be held tentatively.\n\n\n\n <revision>\n <id>8092</id>\n <id>41285331</id>\n\n\n|input = Two 4-panel [[dance pad]]s, six buttons\nThe term '''blue biotechnology''' has also been used to describe the marine and aquatic applications of biotechnology, but its use is relatively rare.\n[[Category:Armed leftist groups]]\n| bgcolor=&quot;#ffd000&quot; | [[Fermium|Fm]]&lt;br/&gt;1.3\n[[Image:columnchromatography.gif|thumb|right|400px|A picture of a standard column chromatography and a flash column chromatography setup]]\n </contributor>\n: [[Central America]] and the Caribbean\n===Written accents===\n| align=&quot;RIGHT&quot; nowrap | 360 MYA\n&lt;br&gt;''domestic:''\n\n:The Father uncreate, the Son uncreate : and the Holy
vocab_sz = btt.vocab_size
xb, yb = dls.one_batch()
xb.shape, yb.shape
(torch.Size([2, 4096]), torch.Size([2, 4096]))
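Inputs and targets are both (bs, sl) = (2, 4096) because a language-model loader cuts one long token stream into bs parallel rows, reads seq_len columns at a time, and uses the same window shifted by one token as the target. The NumPy sketch below illustrates that general scheme; lm_batches is a hypothetical helper, and unlike fastai's LMDataLoader it does no shuffling and simply drops any leftover columns.

```python
import numpy as np


def lm_batches(tokens, bs, seq_len):
    """Yield (x, y) language-model batches from one long token stream."""
    tokens = np.asarray(tokens)
    row_len = (len(tokens) - 1) // bs        # keep one spare token for the shifted target
    stream = tokens[: row_len * bs + 1]
    xs = stream[:-1].reshape(bs, row_len)    # bs parallel rows of inputs
    ys = stream[1:].reshape(bs, row_len)     # same rows, shifted one token ahead
    for start in range(0, row_len - seq_len + 1, seq_len):
        yield xs[:, start:start + seq_len], ys[:, start:start + seq_len]


batches = list(lm_batches(np.arange(21), bs=2, seq_len=4))
x, y = batches[0]
print(x.tolist(), y.tolist())  # → [[0, 1, 2, 3], [10, 11, 12, 13]] [[1, 2, 3, 4], [11, 12, 13, 14]]
```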


Initialise wandb logging. Please do not change project or entity, so that everything gets logged to the same place.

wandb.init(reinit=True, project="reformer-fastai", entity="fastai_community", 
           name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG)
config = TransformerLMConfigEnwik8(max_seq_len=4096, axial_shape=(64,64))
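learn = Learner(dls, TransformerLM.from_config(config),
                loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
                # n_acc counts samples, not batches: with bs=2 the optimizer steps every 4 batches
                cbs=[GradientAccumulation(n_acc=8), GradientClip()],
                metrics=[accuracy, perplexity, bpc]).to_fp16()
# Assumption: the original training call was lost; one epoch at the Learner's default lr is shown here
learn.fit(1, cbs=WandbCallback(log_model=False, log_preds=False))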
epoch train_loss valid_loss accuracy perplexity bpc time
0 1.280827 1.160497 0.657974 3.191520 1.674244 3:24:15
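The reported perplexity and bpc are consistent with being derived from the validation loss: perplexity is the exponential of the mean cross-entropy (in nats), and bits per character divides that cross-entropy by ln 2. A quick check with the numbers from the table (variable names chosen so they do not shadow the perplexity and bpc metrics):

```python
import math

valid_loss = 1.160497                  # mean cross-entropy in nats, from the table above
ppl = math.exp(valid_loss)             # perplexity
bits_per_char = valid_loss / math.log(2)  # nats -> bits
print(f"{ppl:.4f} {bits_per_char:.4f}")  # → 3.1915 1.6742
```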