import sys
import multiprocessing
from fastcore.all import *
from fastai.basics import *
from fastai.text.all import *
from fastai.distributed import *
from reformer_fastai.all import *
Overview
The functions used for the entire data pipeline, model loading and model training can be found here. Click on the "[source]" links to see the full code. The full source code for the experiment script itself is shown below.
Running the script
Experiments are run with this script by specifying:
1) The task to run, i.e. the synthetic task, language modelling or translation
2) Optionally, overrides of the default parameters for the dataloaders, models, training loop and logging
To run the training script, call run_exp from within the reformer_fastai repo. For example:
run_exp 'synt' --lr=1e-4 --bs=32
To run the experiment script on multiple GPUs, use fastai.launch:
python -m fastai.launch [--gpus 1,2] expscript.py [args]
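For example, a hypothetical two-GPU launch of the reversible language model experiment (the GPU ids and argument values here are placeholders, adjust them to your setup):
python -m fastai.launch --gpus 0,1 expscript.py 'lm_rev' --bs=4 --max_seq_len=4096 --distrib=True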
Experiment Configs
All model and training hyperparameters used in training can be found in Experiments/Configs.
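The config classes can also be created, saved and reloaded directly from Python. A minimal sketch using SyntheticConfig and the same save/from_file pattern as the 'test_cfg' branch of the script below (the file name is just an example):
from reformer_fastai.all import *

# Create a config, overriding a couple of defaults, save it, then reload it
config = SyntheticConfig(warn=False, n_hashes=4, use_lsh=True)
config.save('my_synthetic_config')   # save under a custom name
config2 = SyntheticConfig.from_file('my_synthetic_config')
print(config2)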
Training Script
Command line arguments
Only arguments used to vary between different experiment runs are passed to the model from the command line. For example, for the Synthetic experiment only n_hashes and use_lsh can be changed from the command line; all other model parameters are fixed by SyntheticConfig.
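For example, a synthetic run that turns on LSH attention with 4 hashes, leaving everything else at the SyntheticConfig defaults (the values here are illustrative):
run_exp 'synt' --use_lsh=True --n_hashes=4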
@call_parse
def run_exp(task:Param(help="Task options: 'synt', 'lm_base', 'lm_rev', 'lm_shared_qk', 'n_hashes', 'n_layers', 'wmt_rev', 'wmt_base'", type=str),
data_path:Param(help="Path to data folder", type=str, default='./data'),
n_epochs:Param(help="Number of epochs", type=int, default=1),
lr:Param(help="Learning rate", type=float, default=1e-3),
bs:Param(help="Batch size", type=int, default=64),
train_sz:Param(help="TwinSequence train size", type=int, default=12800),
valid_sz:Param(help="TwinSequence valid size", type=int, default=1280),
n_layers:Param(help="Number of layers", type=int, default=3),
n_hashes:Param(help="Number of LSH Attention hashes", type=int, default=1),
use_lsh:Param(help="Use LSH Attention", type=bool_arg, default=False),
max_seq_len:Param(help="Max sequence length for model embedding and dataloader", type=int, default=2048),
do_wandb_logging:Param(help="Use weights and biases logging", type=bool_arg, default=False),
run_name:Param(help="Run name for wandb tracking and model filename", type=str, default=''),
wandb_group:Param(help="wandb group", type=str, default='TEST'),
wandb_notes:Param(help="wandb notes", type=str, default='My experiment notes'),
wandb_tags:Param(help="wandb tags, add tags in a single string, space separated", type=str, default='test'),
save_model:Param(help="Save model locally in /models", type=bool_arg, default=False),
grad_accum:Param(help="Gradient Accumulation, set greater than 1 to implement", type=int, default=1),
clip:Param(help="Gradient Clipping, will be set if > 0.0", type=float, default=0.0),
cuda_id:Param(help="Which cuda device to use", type=int, default=0),
seed:Param(help="Set seed for reproducibility, passing anything except 0 will use fastai's set_seed", type=int, default=0),
distrib:Param(help="Set to True if using distributed training", type=bool_arg, default=False),
verbose:Param(help="Print script logs", type=bool_arg, default=True),
tiny:Param(help="Use 5% of data, for quick iteration and testing", type=bool_arg, default=False),
precision:Param(help="0: fp16, 1: non-native fp16, 2: fp32", type=int, default=0)
):
"""Task options: 'synt','lm_base','lm_rev',lm_shared_qk, trans"""
#Set up distributed training
# _wrapper = rank0_first if distrib else partial
# if distrib: cuda_id = None
torch.cuda.set_device(cuda_id)
# Callbacks used for training
cbs = []
if save_model: cbs.append(SaveModelCallback(every_epoch=True, with_opt=True))
#random seeds
if seed!=0:
set_seed(seed, reproducible=True) # this also sets the torch.backends.cudnn determinism flags
else:
seed = None # this is passed to LSH and data generator. They expect None or int
if task == 'synt':
"Model + Data Args than can be changed from command line: train_sz, valid_sz, n_hashes, use_lsh, seed"
if run_name == '':
if use_lsh: run_name = f'{task}_lsh-{n_hashes}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
else: run_name = f'{task}_full-attn_bs-{bs}_n_eps-{n_epochs}'
print('Getting model ...')
config = SyntheticConfig(warn=False, verbose=verbose, n_hashes=n_hashes, use_lsh=use_lsh, seed=seed)
if verbose: print(config)
config.save(run_name, add_tstmp=True)
model = LSHLM.from_config(config)
print('done!')
print('Getting dataloaders ...')
if train_sz != 12800: print(f'Note, "train_sz" changed from recommended 12800 to {train_sz}')
dls = get_twin_sequence_dataloaders(bs=bs, sl=config['max_seq_len'], train_sz=train_sz,
valid_sz=valid_sz, seed=seed)
print('done!')
print('Getting learner ...')
learn = get_synthetic_learner(dls, model, precision)
print('done!')
# Set up Weights & Biases logging, if needed
if do_wandb_logging and rank_distrib()==0:
wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
# Append training callbacks needed
cbs.append(MaskTargCallback())
# Start training
print('Starting training...')
learn.fit_one_cycle(n_epochs, lr, cbs=cbs)
print('done!')
# Close wandb logging for this run
if do_wandb_logging: wandb_run.finish()
# Save model weights if needed, saved in /models relative to where script is run
if save_model:
now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
learn.save(f'{task}_{run_name}_{now}')
elif 'lm' in task:
"Model args that can be changed from command line: axial_shape, max_seq_len"
axial_shape = get_axial_shape(max_seq_len)
if task == 'lm_base':
if run_name == '': run_name = f'{task}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
config = TransformerLMConfigEnwik8(warn=False, verbose=verbose,
axial_shape=axial_shape, max_seq_len=max_seq_len)
print('Getting model ...')
model = TransformerLM.from_config(config)
print('done!')
elif task == 'lm_rev':
if run_name == '': run_name = f'{task}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
config = ReversibleLMConfigEnwik8(warn=False, verbose=verbose,
axial_shape=axial_shape, max_seq_len=max_seq_len)
print('Getting model ...')
model = ReversibleLM.from_config(config)
print('done!')
elif task == 'lm_shared_qk':
if run_name == '': run_name = f'{task}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
config = TransformerLMConfigEnwik8(warn=False, verbose=verbose, shared_qk=True,
axial_shape=axial_shape, max_seq_len=max_seq_len)
print('Getting model ...')
model = TransformerLM.from_config(config)
print('done!')
if verbose: print(config)
config.save(run_name, add_tstmp=True)
print('Checking data')
# _wrapper(download_enwik8_data, data_path=data_path)
# if distrib: rank0_first(download_enwik8_data, data_path=data_path)
download_enwik8_data(data_path=data_path)
print('done')
print('Getting dataloaders ...')
dls = get_enwik8_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len,
verbose=verbose, tiny=tiny)
print('done')
print('Getting learner ...')
learn = get_lm_learner(dls, model, opt_func=adafactor, precision=precision)
print('done!')
# CALLBACKS
## Gradient Clipping
if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
## Gradient Accumulation
if grad_accum > 1:
print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
cbs.append(GradientAccumulation(n_acc=grad_accum))
run_name = run_name + f'_grad-accum-{grad_accum}'
# Set up Weights & Biases logging, if needed
if do_wandb_logging and rank_distrib()==0:
wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
# Start training
print('Starting training...')
learn.fit(n_epochs, cbs=cbs)
print('done!')
# Close wandb logging for this run
if do_wandb_logging: wandb_run.finish()
# Save model weights if needed, saved in /models relative to where script is run
if save_model:
now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
learn.save(f'{task}_{run_name}_{now}')
elif task == 'n_hashes':
"Model args that can be changed from command line: n_hashes, seed"
if run_name == '': run_name = f'{task}-{n_hashes}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
print('Checking data')
# _wrapper(download_enwik8_data, data_path=data_path)
# if distrib: rank0_first(download_enwik8_data, data_path=data_path)
download_enwik8_data(data_path=data_path)
print('done')
print('Getting dataloaders ...')
dls = get_enwik8_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len,
verbose=verbose, tiny=tiny)
print('done')
pad_id = dls.byte_text_tokenizer.pad_token_id
config = NHashesConfig(warn=False, verbose=verbose, n_hashes=n_hashes,
seed=seed, pad_idx=pad_id)
print('Getting model ...')
model = LSHLM.from_config(config)
print('done!')
if verbose: print(config)
config.save(run_name, add_tstmp=True)
print('Getting learner ...')
learn = get_lm_learner(dls, model, opt_func=adafactor, precision=precision)
print('done!')
# CALLBACKS
## Gradient Clipping
if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
## Gradient Accumulation
if grad_accum > 1:
print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
cbs.append(GradientAccumulation(n_acc=grad_accum))
run_name = run_name + f'_grad-accum-{grad_accum}'
#LSH-specific callback
if config.use_lsh: cbs.append(PadBatchCallback(bucket_size=config.bucket_size,
val=pad_id, y_val=pad_id))
# Set up Weights & Biases logging, if needed
if do_wandb_logging and rank_distrib()==0:
wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
# Start training
print('Starting training...')
learn.fit(n_epochs, cbs=cbs)
print('done!')
# Close wandb logging for this run
if do_wandb_logging: wandb_run.finish()
# Save model weights if needed, saved in /models relative to where script is run
if save_model:
now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
learn.save(f'{task}_{run_name}_{now}')
elif task == 'n_layers':
"Model args that can be changed from command line: n_hashes, seed"
if run_name == '': run_name = f'{task}-{n_layers}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
print('Checking data')
# _wrapper(download_enwik8_data, data_path=data_path)
# if distrib: rank0_first(download_enwik8_data, data_path=data_path)
download_enwik8_data(data_path=data_path)
print('done')
print('Getting dataloaders ...')
dls = get_enwik8_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len,
verbose=verbose, tiny=tiny)
print('done')
pad_id = dls.byte_text_tokenizer.pad_token_id
config = NLayersConfig(warn=False, verbose=verbose, n_layers=n_layers,
max_seq_len=max_seq_len, seed=seed, pad_idx=pad_id)
print('Getting model ...')
model = ReformerLM.from_config(config)
print('done!')
if verbose: print(config)
config.save(run_name, add_tstmp=True)
print('Getting learner ...')
learn = get_reformerlm_learner(dls, model, opt_func=adafactor, precision=precision)
print('done!')
# CALLBACKS
## Gradient Clipping
if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
## Gradient Accumulation
if grad_accum > 1:
print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
cbs.append(GradientAccumulation(n_acc=grad_accum))
run_name = run_name + f'_grad-accum-{grad_accum}'
#LSH-specific callback
if config.use_lsh: cbs.append(PadBatchCallback(bucket_size=config.bucket_size,
val=pad_id, y_val=pad_id))
# Set up Weights & Biases logging, if needed
if do_wandb_logging and rank_distrib()==0:
wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
# Start training
print('Starting training...')
learn.fit(n_epochs, cbs=cbs)
print('done!')
# Close wandb logging for this run
if do_wandb_logging: wandb_run.finish()
# Save model weights if needed, saved in /models relative to where script is run
if save_model:
now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
learn.save(f'{task}_{run_name}_{now}')
elif 'wmt' in task:
"Model args that can be changed from command line: n_layers, max_seq_len"
axial_shape = get_axial_shape(max_seq_len)
if run_name == '': run_name = f'{task}_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
print('Checking data')
download_wmt14_data(data_path=data_path)
print('done')
print('Getting dataloaders and tokenizer ...')
dls, tok = get_wmt14_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len,
verbose=verbose, tiny=tiny)
print('done')
print('Getting model ...')
if task == 'wmt_rev':
config = ReversibleTransformerConfigWMT(warn=False, verbose=verbose,
enc_vocab_sz=tok.vocab_size, dec_vocab_sz=tok.vocab_size, pad_idx=tok.PAD_ID,
n_enc_layers=n_layers, n_dec_layers=n_layers)
model = ReversibleTransformer.from_config(config)
elif task == 'wmt_base':
config = TransformerConfigWMT(warn=False, verbose=verbose,
enc_vocab_sz=tok.vocab_size, dec_vocab_sz=tok.vocab_size, pad_idx=tok.PAD_ID,
n_enc_layers=n_layers, n_dec_layers=n_layers)
model = Transformer.from_config(config)
print('done!')
if verbose: print(config)
config.save(run_name, add_tstmp=True)
print('Getting learner ...')
learn = get_seq2seq_learner(dls, model, tok, precision)
print('done!')
# CALLBACKS
cbs += [CombineInputOutputCallback(), LossTargetShiftCallback(), RemoveEOSCallback(eos_idx=tok.EOS_ID)]
## Gradient Clipping Callback
if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
## Gradient Accumulation Callback
if grad_accum > 1:
print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
cbs.append(GradientAccumulation(n_acc=grad_accum))
run_name = run_name + f'_grad-accum-{grad_accum}'
# Set up Weights & Biases logging, if needed
if do_wandb_logging:
wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
# Start training
print('Starting training...')
learn.fit_one_cycle(n_epochs, lr, cbs=cbs)
print('done!')
# Close wandb logging for this run
if do_wandb_logging: wandb_run.finish()
# Save model weights if needed, saved in /models relative to where script is run
if save_model:
now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
learn.save(f'{task}_{run_name}_{now}')
elif task == 'test_cfg':
print('Locals ', locals())
print()
config = SyntheticConfig(verbouse=True, **locals())
print(config)
config.save('test')
config2 = SyntheticConfig.from_file('test')
print(config2)
elif task == 'test':
print('testing testing :)')
print(verbose)
else:
print('No task run')
Running the Script
Example commands for running full-scale experiments. Note that run_name can be passed, but if it is not, it will be constructed automatically from the task and the relevant arguments.
Running the Synthetic Experiment:
run_exp 'synt' \
--n_epochs=750 \
--lr=1e-4 \
--bs=128 \
--use_lsh=True \
--n_hashes=1 \
--train_sz=12800 \
--valid_sz=1280 \
--seed=1234 \
--wandb_group='Synthetic' \
--wandb_tags='synthetic_task lsh lm test' \
--run_name='synth_lsh_1_hash'
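If --run_name were omitted here, the naming logic in the script would construct it automatically; with these arguments (and use_lsh=True) it would produce synt_lsh-1_bs-128_n_eps-750_seed-1234.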
Running the Reversible Language Model experiment:
For the full 60k steps with a sequence length of 65536, the number of epochs can be calculated as follows: with sl == 2**16, one epoch of enwik8 contains 172 batches, so n_epochs == 60000/172 ≈ 349.
run_exp 'lm_rev' \
--n_epochs=3 \
--lr=1e-4 \
--bs=8 \
--max_seq_len=4096 \
--do_wandb_logging=True \
--wandb_group='enwik8_lm_rev' \
--wandb_tags='lm_rev lm exp' \
--wandb_notes='This is a test' \
--grad_accum=4
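The epoch arithmetic above as a quick standalone sketch (the numbers are taken from the note above, not computed by the script):
# Back-of-the-envelope epoch count for the full 60k-step run at sl == 2**16
total_steps       = 60_000   # target number of training steps
batches_per_epoch = 172      # one epoch of enwik8 at this sequence length (see note above)
n_epochs = round(total_steps / batches_per_epoch)
print(n_epochs)              # 349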