Open In Colab

from fastai.text.all import *
from reformer_fastai.all import *

from my_timesaver_utils.all import MyProfileCallback
from pt_memprofile.core import MemStatsCallback

Prepare data

# Load Tiny Shakespeare, turn it into one example per non-empty line, tokenize at the
# byte level and build fastai Datasets with an 80/20 train/validation split.
import pandas as pd
from datasets import load_dataset

dataset = load_dataset('tiny_shakespeare')
train_ds = dataset['train']
# Split the raw text into a 'line' column (the original 'text' column is dropped).
train_ds = train_ds.map(splitlines, batched=True, remove_columns=['text'])
train_ds = train_ds.filter(lambda x: x['line'] != '')
# Build the frame from the dataset's column view instead of `train_ds.data` (the backing
# Arrow table): when `filter` is implemented through an indices mapping, the backing
# table still contains the rows removed above, while column access applies the filter.
df = pd.DataFrame({'line': train_ds['line']})
bte = ByteTextTokenizer(is_lm=True, add_bos=True, add_eos=True)
vocab_sz = bte.vocab_size
# First 80% of the rows train, the remaining 20% validate.
cut = int(len(df)*0.8)
# Slice the index list of the full frame for BOTH halves. `range_of(df[cut:])` would
# yield 0..len(df)-cut-1, i.e. a validation set taken from the start of the training rows.
splits = range_of(df)[:cut], range_of(df)[cut:]
# The same transforms for both fields: pull the 'line' text, then byte-tokenize it.
tfms = [attrgetter("line"), bte]
dsets = Datasets(df, [tfms, tfms], splits=splits, dl_type=LMDataLoader)

To highlight the Reformer's advantages, we use deep 12-layer models and long input sequences (4096 tokens).

# Settings shared by every model in this notebook: 12 layers of width 512,
# trained on batches holding a single 4096-token sequence.
d_model, n_layers = 512, 12
bs = 1
sl = 4096
# Batch padding for fields 0 and 1 of each item (presumably input and target)
# using the tokenizer's pad id.
pad_seq2seq = partial(pad_input, pad_idx=bte.pad_token_id, pad_fields=[0, 1])

dls = dsets.dataloaders(bs=bs, seq_len=sl, before_batch=pad_seq2seq)

TransformerLM

# Baseline: a standard Transformer LM with the shared settings, Adafactor optimizer,
# flattened cross-entropy loss, converted with `.to_fp16()`.
learn = Learner(
    dls,
    TransformerLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
    loss_func=CrossEntropyLossFlat(),
    opt_func=adafactor,
    metrics=[accuracy, perplexity, bpc],
).to_fp16()
# Displayed as the cell result: parameter count and a flag reported by total_params.
total_params(learn.model)
(40034051, True)

Memory

# Memory profile for TransformerLM: attach MemStatsCallback, then run one epoch that
# ShortEpochCallback appears to cut short (the log shows no losses and a ~12 s run).
# NOTE(review): MemStatsCallback presumably stays attached to `learn` for the timing
# fit further down, so its overhead may be included in those timings — confirm.
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss accuracy perplexity bpc time
0 00:12
# Plot the statistics MemStatsCallback recorded in `learn.mem_stats` during that run.
learn.mem_stats.plot()

Timing

# Timing for TransformerLM: enable the my_timesaver_utils profiler, train one full epoch,
# then print per-phase call counts with max/avg durations.
learn.to_my_profile()
learn.fit(1)
epoch train_loss valid_loss accuracy perplexity bpc time
0 3.395151 3.352084 0.145182 28.562201 4.836035 07:37
learn.my_profile.print_stats()
fit  called 1 times. max: 457.169 avg: 457.169
   epoch  called 1 times. max: 457.160 avg: 457.160
      train  called 1 times. max: 394.421 avg: 394.421
         train_batch  called 202 times. max: 1.904 avg: 1.844
            train_pred  called 202 times. max: 1.256 avg: 1.204
            train_loss  called 202 times. max: 0.002 avg: 0.002
            train_backward  called 202 times. max: 0.553 avg: 0.550
            train_step  called 202 times. max: 0.117 avg: 0.085
            train_zero_grad  called 202 times. max: 0.006 avg: 0.004
      valid  called 1 times. max: 62.737 avg: 62.737
         valid_batch  called 48 times. max: 1.243 avg: 1.180
            valid_pred  called 48 times. max: 1.242 avg: 1.179
            valid_loss  called 48 times. max: 0.002 avg: 0.001

ReversibleLM

# ReversibleLM with the shared settings; loss, optimizer, metrics and `.to_fp16()`
# conversion match the TransformerLM baseline so the runs are comparable.
learn = Learner(
    dls,
    ReversibleLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
    loss_func=CrossEntropyLossFlat(),
    opt_func=adafactor,
    metrics=[accuracy, perplexity, bpc],
).to_fp16()
# Displayed as the cell result: parameter count and a flag reported by total_params.
total_params(learn.model)
(40035075, True)

Memory

# Memory profile for ReversibleLM: attach MemStatsCallback, then run one epoch that
# ShortEpochCallback appears to cut short (the log shows no losses and a ~19 s run).
# NOTE(review): MemStatsCallback presumably stays attached to `learn` for the timing
# fit further down, so its overhead may be included in those timings — confirm.
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss accuracy perplexity bpc time
0 00:19
# Plot the statistics MemStatsCallback recorded in `learn.mem_stats` during that run.
learn.mem_stats.plot()

Timing

# Timing for ReversibleLM: enable the profiler, train one full epoch, print per-phase stats.
# NOTE(review): the profiler stats accumulate across every learner profiled so far — the
# report reads "fit called 2 times" and its avg is the mean of this run and TransformerLM's.
# Difference the numbers or reset the stats before comparing models (confirm reset API).
learn.to_my_profile()
learn.fit(1)
epoch train_loss valid_loss accuracy perplexity bpc time
0 2.489643 2.441111 0.273004 11.485793 3.521779 10:16
learn.my_profile.print_stats()
fit  called 2 times. max: 616.622 avg: 536.896
   epoch  called 2 times. max: 616.611 avg: 536.886
      train  called 2 times. max: 553.057 avg: 473.739
         train_batch  called 404 times. max: 2.766 avg: 2.235
            train_pred  called 404 times. max: 1.330 avg: 1.214
            train_loss  called 404 times. max: 0.002 avg: 0.001
            train_backward  called 404 times. max: 1.406 avg: 0.930
            train_step  called 404 times. max: 0.117 avg: 0.087
            train_zero_grad  called 404 times. max: 0.006 avg: 0.004
      valid  called 2 times. max: 63.552 avg: 63.144
         valid_batch  called 96 times. max: 1.258 avg: 1.188
            valid_pred  called 96 times. max: 1.257 avg: 1.187
            valid_loss  called 96 times. max: 0.002 avg: 0.001

Using irreversible blocks

# ReversibleLM configured to use irreversible blocks. Use the shared `d_model` (not a
# hard-coded 512) so this run stays in sync with the other models if the width changes.
# NOTE(review): `rev_thres=sl+1` presumably puts the reversibility threshold above the
# sequence length, so the layers run as ordinary (irreversible) blocks — confirm against
# ReversibleLM.
learn = Learner(dls, ReversibleLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl, rev_thres=sl+1),
                loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
                metrics=[accuracy, perplexity, bpc]).to_fp16()
# Displayed as the cell result: parameter count and a flag reported by total_params.
total_params(learn.model)
(40035075, True)

Memory

# Memory profile for ReversibleLM with irreversible blocks: attach MemStatsCallback, then
# run one epoch that ShortEpochCallback appears to cut short (no losses logged, ~12 s).
# NOTE(review): MemStatsCallback presumably stays attached to `learn` for the timing
# fit further down, so its overhead may be included in those timings — confirm.
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss accuracy perplexity bpc time
0 00:12
# Plot the statistics MemStatsCallback recorded in `learn.mem_stats` during that run.
learn.mem_stats.plot()

Timing

# Timing for ReversibleLM with irreversible blocks: enable the profiler, train one epoch,
# print per-phase stats.
# NOTE(review): the profiler stats accumulate across every learner profiled so far (the
# report reads "fit called 3 times"), so max/avg mix this model with the earlier ones.
# Difference the numbers or reset the stats before comparing models (confirm reset API).
learn.to_my_profile()
learn.fit(1)
epoch train_loss valid_loss accuracy perplexity bpc time
0 2.487123 2.461037 0.259202 11.716952 3.550525 07:38
learn.my_profile.print_stats()
fit  called 3 times. max: 616.622 avg: 510.696
   epoch  called 3 times. max: 616.611 avg: 510.684
      train  called 3 times. max: 553.057 avg: 447.661
         train_batch  called 606 times. max: 2.766 avg: 2.106
            train_pred  called 606 times. max: 1.330 avg: 1.210
            train_loss  called 606 times. max: 0.002 avg: 0.001
            train_backward  called 606 times. max: 1.406 avg: 0.804
            train_step  called 606 times. max: 0.117 avg: 0.087
            train_zero_grad  called 606 times. max: 0.006 avg: 0.004
      valid  called 3 times. max: 63.552 avg: 63.021
         valid_batch  called 144 times. max: 1.258 avg: 1.186
            valid_pred  called 144 times. max: 1.257 avg: 1.185
            valid_loss  called 144 times. max: 0.002 avg: 0.001

LSHLM

# LSHLM with the shared settings. Unlike the full-attention models, this learner also
# gets PadBatchCallback (presumably pads batches to a length the LSH attention accepts —
# confirm).
learn = Learner(
    dls,
    LSHLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
    loss_func=CrossEntropyLossFlat(),
    opt_func=adafactor,
    metrics=[accuracy, perplexity, bpc],
    cbs=PadBatchCallback(),
).to_fp16()
# Displayed as the cell result: parameter count and a flag reported by total_params.
total_params(learn.model)
(36888323, True)

Memory

# Memory profile for LSHLM: attach MemStatsCallback, then run one epoch that
# ShortEpochCallback appears to cut short (the log shows no losses and a ~10 s run).
# NOTE(review): MemStatsCallback presumably stays attached to `learn` for the timing
# fit further down, so its overhead may be included in those timings — confirm.
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss accuracy perplexity bpc time
0 00:10
# Plot the statistics MemStatsCallback recorded in `learn.mem_stats` during that run.
learn.mem_stats.plot()

Timing

# Timing for LSHLM: enable the profiler, train one full epoch, print per-phase stats.
# NOTE(review): the profiler stats accumulate across every learner profiled so far (the
# report reads "fit called 4 times"), so max/avg mix this model with the earlier ones.
# Difference the numbers or reset the stats before comparing models (confirm reset API).
learn.to_my_profile()
learn.fit(1)
epoch train_loss valid_loss accuracy perplexity bpc time
0 3.396859 3.371909 0.145092 29.134102 4.864637 06:59
learn.my_profile.print_stats()
fit  called 4 times. max: 616.622 avg: 487.851
   epoch  called 4 times. max: 616.611 avg: 487.841
      train  called 4 times. max: 553.057 avg: 427.078
         train_batch  called 808 times. max: 2.766 avg: 2.004
            train_pred  called 808 times. max: 1.330 avg: 1.166
            train_loss  called 808 times. max: 0.009 avg: 0.003
            train_backward  called 808 times. max: 1.406 avg: 0.744
            train_step  called 808 times. max: 0.119 avg: 0.087
            train_zero_grad  called 808 times. max: 0.006 avg: 0.004
      valid  called 4 times. max: 63.552 avg: 60.761
         valid_batch  called 192 times. max: 1.258 avg: 1.140
            valid_pred  called 192 times. max: 1.257 avg: 1.139
            valid_loss  called 192 times. max: 0.007 avg: 0.001

ReformerLM

# ReformerLM with the shared settings and default hashing; PadBatchCallback is attached
# as for LSHLM (presumably to pad batches for LSH attention — confirm).
learn = Learner(
    dls,
    ReformerLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
    loss_func=CrossEntropyLossFlat(),
    opt_func=adafactor,
    metrics=[accuracy, perplexity, bpc],
    cbs=PadBatchCallback(),
).to_fp16()
# Displayed as the cell result: parameter count and a flag reported by total_params.
total_params(learn.model)
(36888323, True)

Memory

# Memory profile for ReformerLM: attach MemStatsCallback, then run one epoch that
# ShortEpochCallback appears to cut short (the log shows no losses and a ~14 s run).
# NOTE(review): MemStatsCallback presumably stays attached to `learn` for the timing
# fit further down, so its overhead may be included in those timings — confirm.
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss accuracy perplexity bpc time
0 00:14
# Plot the statistics MemStatsCallback recorded in `learn.mem_stats` during that run.
learn.mem_stats.plot()

Timing

# Timing for ReformerLM: enable the profiler, train one full epoch, print per-phase stats.
# NOTE(review): the profiler stats accumulate across every learner profiled so far (the
# report reads "fit called 5 times"), so max/avg mix this model with the earlier ones.
# Difference the numbers or reset the stats before comparing models (confirm reset API).
learn.to_my_profile()
learn.fit(1)
epoch train_loss valid_loss accuracy perplexity bpc time
0 3.374069 2.589408 0.257878 13.321885 3.735726 10:03
learn.my_profile.print_stats()
fit  called 5 times. max: 616.622 avg: 510.974
   epoch  called 5 times. max: 616.611 avg: 510.964
      train  called 5 times. max: 553.057 avg: 451.563
         train_batch  called 1010 times. max: 2.766 avg: 2.125
            train_pred  called 1010 times. max: 1.330 avg: 1.139
            train_loss  called 1010 times. max: 0.009 avg: 0.002
            train_backward  called 1010 times. max: 1.513 avg: 0.893
            train_step  called 1010 times. max: 0.119 avg: 0.087
            train_zero_grad  called 1010 times. max: 0.007 avg: 0.004
      valid  called 5 times. max: 63.552 avg: 59.398
         valid_batch  called 240 times. max: 1.258 avg: 1.113
            valid_pred  called 240 times. max: 1.257 avg: 1.112
            valid_loss  called 240 times. max: 0.007 avg: 0.001

ReformerLM with 2 hashing rounds (n_hashes=2)

# ReformerLM with the shared settings but two hashing rounds (n_hashes=2); everything
# else, including PadBatchCallback, matches the previous ReformerLM run.
learn = Learner(
    dls,
    ReformerLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl, n_hashes=2),
    loss_func=CrossEntropyLossFlat(),
    opt_func=adafactor,
    metrics=[accuracy, perplexity, bpc],
    cbs=PadBatchCallback(),
).to_fp16()
# Displayed as the cell result: parameter count and a flag reported by total_params.
total_params(learn.model)
(36888323, True)

Memory

# Memory profile for ReformerLM (n_hashes=2): attach MemStatsCallback, then run one epoch
# that ShortEpochCallback appears to cut short (the log shows no losses and a ~7 s run).
# NOTE(review): MemStatsCallback presumably stays attached to `learn` for the timing
# fit further down, so its overhead may be included in those timings — confirm.
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss accuracy perplexity bpc time
0 00:07
# Plot the statistics MemStatsCallback recorded in `learn.mem_stats` during that run.
learn.mem_stats.plot()

Timing

# Timing for ReformerLM (n_hashes=2): enable the profiler, train one full epoch,
# print per-phase stats.
# NOTE(review): the profiler stats accumulate across every learner profiled so far (the
# report reads "fit called 6 times"), so max/avg mix this model with the earlier ones.
# Difference the numbers or reset the stats before comparing models (confirm reset API).
learn.to_my_profile()
learn.fit(1)
epoch train_loss valid_loss accuracy perplexity bpc time
0 3.185392 2.701723 0.278278 14.905396 3.897763 04:34
learn.my_profile.print_stats()
fit  called 6 times. max: 616.622 avg: 471.562
   epoch  called 6 times. max: 616.611 avg: 471.551
      train  called 6 times. max: 553.057 avg: 418.046
         train_batch  called 1212 times. max: 2.766 avg: 1.959
            train_pred  called 1212 times. max: 1.330 avg: 1.014
            train_loss  called 1212 times. max: 0.009 avg: 0.002
            train_backward  called 1212 times. max: 1.513 avg: 0.853
            train_step  called 1212 times. max: 0.119 avg: 0.086
            train_zero_grad  called 1212 times. max: 0.007 avg: 0.004
      valid  called 6 times. max: 63.552 avg: 53.503
         valid_batch  called 288 times. max: 1.258 avg: 0.992
            valid_pred  called 288 times. max: 1.257 avg: 0.991
            valid_loss  called 288 times. max: 0.007 avg: 0.001