from fastai.text.all import *
from reformer_fastai.all import *
from my_timesaver_utils.all import MyProfileCallback
from pt_memprofile.core import MemStatsCallback
from datasets import load_dataset
dataset = load_dataset('tiny_shakespeare')
train_ds = dataset['train']
def splitlines(batch):
    # Split each raw text into its individual lines; with batched=True one
    # 'text' entry may expand into many 'line' entries
    return {'line': [line for text in batch['text'] for line in text.splitlines()]}
train_ds = train_ds.map(splitlines, batched=True, remove_columns=['text'])
train_ds = train_ds.filter(lambda x: x['line'] != '')
df = train_ds.data.to_pandas()
bte = ByteTextTokenizer(is_lm=True, add_bos=True, add_eos=True)
vocab_sz = bte.vocab_size
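As a quick sanity check, the tokenizer can be applied directly to a string, just as the transform pipeline below will do; the sample text here is arbitrary.
sample = bte('FIRST CITIZEN:')  # byte-level ids, with BOS/EOS added
sample.shape, sample[:5], vocab_sz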
cut = int(len(df)*0.8)
splits = (range_of(df)[:cut], range_of(df)[cut:])
tfms = [attrgetter("line"), bte]
dsets = Datasets(df, [tfms, tfms], splits=splits, dl_type=LMDataLoader)
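With dl_type=LMDataLoader the tokenized lines are concatenated into a single stream and re-chunked into contiguous seq_len-sized sequences, with targets offset by one token.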
To highlight the Reformer's advantages we use deep 12-layer models and long input sequences (4,096 tokens).
d_model = 512
n_layers = 12
bs, sl = 1, 4096
pad_seq2seq = partial(pad_input, pad_idx=bte.pad_token_id, pad_fields=[0,1])
dls = dsets.dataloaders(bs=bs, seq_len=sl, before_batch=pad_seq2seq)
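First, a baseline: a standard TransformerLM with full attention. At sl = 4096 each attention layer materialises a 4096 × 4096 score matrix per head (about 16.8M entries), which is exactly the quadratic cost the Reformer components are designed to avoid. Each experiment below follows the same recipe: a truncated epoch under ShortEpochCallback to record memory statistics, then a full epoch under the profiler for timings.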
learn = Learner(dls, TransformerLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
metrics=[accuracy, perplexity, bpc]).to_fp16()
total_params(learn.model)
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
learn.mem_stats.plot()
learn.to_my_profile()
learn.fit(1)
learn.my_profile.print_stats()
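Next, the ReversibleLM. It swaps standard residual connections for reversible ones, so intermediate activations can be recomputed during the backward pass instead of being stored for every layer; activation memory should therefore be roughly independent of depth.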
learn = Learner(dls, ReversibleLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
metrics=[accuracy, perplexity, bpc]).to_fp16()
total_params(learn.model)
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
learn.mem_stats.plot()
learn.to_my_profile()
learn.fit(1)
learn.my_profile.print_stats()
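As a control we rerun the same ReversibleLM with rev_thres set just above the sequence length. Sequences shorter than the threshold bypass the reversible computation, so this run should show what the identical architecture costs when activations are stored as usual.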
learn = Learner(dls, ReversibleLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl, rev_thres=sl+1),
loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
metrics=[accuracy, perplexity, bpc]).to_fp16()
total_params(learn.model)
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
learn.mem_stats.plot()
learn.to_my_profile()
learn.fit(1)
learn.my_profile.print_stats()
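The LSHLM keeps standard residual connections but replaces full attention with LSH attention: queries and keys are grouped into buckets by locality-sensitive hashing and attention is computed only within buckets, reducing the quadratic cost to roughly O(L log L). The PadBatchCallback pads each batch to a length the LSH bucketing can split evenly.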
learn = Learner(dls, LSHLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
metrics=[accuracy, perplexity, bpc],
cbs=PadBatchCallback()).to_fp16()
total_params(learn.model)
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
learn.mem_stats.plot()
learn.to_my_profile()
learn.fit(1)
learn.my_profile.print_stats()
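The ReformerLM combines both pieces, reversible residual layers and LSH attention, giving the full Reformer architecture.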
learn = Learner(dls, ReformerLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl),
loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
metrics=[accuracy, perplexity, bpc],
cbs=PadBatchCallback()).to_fp16()
total_params(learn.model)
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
learn.mem_stats.plot()
learn.to_my_profile()
learn.fit(1)
learn.my_profile.print_stats()
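Finally we vary the number of hashing rounds. More rounds make it likelier that similar items land in the same bucket, so n_hashes trades attention quality for speed and memory; here we drop to 2 rounds.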
learn = Learner(dls, ReformerLM(vocab_sz, d_model, n_layers=n_layers, max_seq_len=sl, n_hashes=2),
loss_func=CrossEntropyLossFlat(), opt_func=adafactor,
metrics=[accuracy, perplexity, bpc],
cbs=PadBatchCallback()).to_fp16()
total_params(learn.model)
learn = learn.add_cb(MemStatsCallback())
learn.fit(1, cbs=ShortEpochCallback())
learn.mem_stats.plot()
learn.to_my_profile()
learn.fit(1)
learn.my_profile.print_stats()