The experiment script, together with all functions used for the data pipeline, model loading and model training, can be found here.
import sys
import multiprocessing

from fastcore.all import *
from fastai.basics import *
from fastai.text.all import *
from fastai.distributed import *

from reformer_fastai.all import *
 

Overview

The functions used for the data pipeline, model loading and model training can be found here. Click on the "[source]" links to see the full code. The full source code for the experiment script itself can be seen below.

Running the script

Experiments are run with this script by specifying:

1. The task to run, i.e. the synthetic task, language modelling or translation
2. (Optionally) overrides of the default parameters for the dataloaders, models, training loop and logging

To run the training script, call run_exp from within the reformer_fastai repo. For example:

run_exp 'synt' --lr=1e-4 --bs=32

To run the experiment script on multiple GPUs, use fastai.launch:

python -m fastai.launch [--gpus 1,2] expscript.py [args]

Experiment Configs

All model and training hyperparameters used in training can be found in Experiments/Configs.
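The configs follow a simple save/load pattern that the script also exercises in its test_cfg task below. A minimal sketch, with the filename chosen purely for illustration:

from reformer_fastai.all import *

# Build a config with its default hyperparameters
config = SyntheticConfig(warn=False)
print(config)

# Save the config for this run, then reload it later
config.save('test')
config2 = SyntheticConfig.from_file('test')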

Data

enwik8 Data Download

download_enwik8_data[source]

download_enwik8_data(data_path='./data')

WMT-14 Data Download

download_wmt14_data[source]

download_wmt14_data(data_path='./data')

Dataloaders

Twin Sequence Dataloader

get_twin_sequence_dataloaders[source]

get_twin_sequence_dataloaders(bs:int=32, sl:int=1024, train_sz:int=500, valid_sz:int=100, seed=None)
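For reference, this is roughly how the synthetic branch of the script below wires the dataloader up; sl is taken from the model config so that data and model agree, the sizes are run_exp's defaults, and the seed matches the synthetic example further down:

config = SyntheticConfig(warn=False, n_hashes=1, use_lsh=True)
dls = get_twin_sequence_dataloaders(bs=64, sl=config['max_seq_len'],
                                    train_sz=12800, valid_sz=1280, seed=1234)
xb, yb = dls.one_batch()   # standard fastai DataLoaders API, handy for a quick sanity check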

enwik8 Dataloader

val_test_chars sets the number of tokens in the combined validation and test set. The validation and test sets will each have val_test_chars / 2 tokens.

get_enwik8_dataloader[source]

get_enwik8_dataloader(data_path='data', bs:int=8, val_bs:int=16, sl:int=1024, n_workers=None, val_test_chars:int=10000000.0, verbose=False, tiny=False, small=False)
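As in the script, the data is downloaded (if needed) before the dataloaders are built. A minimal sketch with the default split, where the 10M validation/test characters are shared equally between the two sets; the pad token id is also grabbed here because the LSH experiments pass it to a padding callback:

download_enwik8_data(data_path='./data')
dls = get_enwik8_dataloader(data_path='./data', bs=8, val_bs=16, sl=4096,
                            val_test_chars=10_000_000)   # 5M validation + 5M test tokens
pad_id = dls.byte_text_tokenizer.pad_token_id            # used by PadBatchCallback in the LSH tasks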

WMT-14 Dataloader

get_wmt14_dataloader[source]

get_wmt14_dataloader(data_path='data', bs:int=8, val_bs:int=8, sl:int=1024, n_workers=None, verbose=False, tiny=False)
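Unlike the enwik8 dataloader, this one returns both the dataloaders and the tokenizer, since the translation models need the tokenizer's vocabulary size and special token ids. A minimal sketch, mirroring the wmt branch of the script below:

download_wmt14_data(data_path='./data')
dls, tok = get_wmt14_dataloader(data_path='./data', bs=8, val_bs=8, sl=1024)
print(tok.vocab_size, tok.PAD_ID, tok.EOS_ID)   # attributes used when building the model and callbacks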

Learner

Synthetic Task Learner

get_synthetic_learner[source]

get_synthetic_learner(dls, model, precision=0)

enwik8 Language Modelling Task Learner

get_lm_learner[source]

get_lm_learner(dls, model, opt_func=adafactor, precision=0)

ReformerLM Learner

get_reformerlm_learner[source]

get_reformerlm_learner(dls, model, opt_func=adafactor, precision=2)

WMT Learner

get_seq2seq_learner[source]

get_seq2seq_learner(dls, model, tok, precision=0)
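To see how the learner getters fit together with the pieces above, here is a minimal sketch of the enwik8 language-modelling case, mirroring the lm_base branch of the script below (the sequence length of 4096 is chosen purely for illustration):

dls = get_enwik8_dataloader(data_path='./data', bs=8, val_bs=8, sl=4096)
config = TransformerLMConfigEnwik8(warn=False, max_seq_len=4096,
                                   axial_shape=get_axial_shape(4096))
model = TransformerLM.from_config(config)
learn = get_lm_learner(dls, model, opt_func=adafactor, precision=0)   # precision=0 selects fp16, as in run_exp
learn.fit(1)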

Logging

init_wandb[source]

init_wandb(cbs:list=[], wandb_name:str='', wandb_group:str='', wandb_notes:str='', wandb_tags:str='test', save_model=False)
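init_wandb returns the wandb run object together with an updated callback list, which is then passed to fit; the run is closed with wandb_run.finish() once training is done. A minimal sketch with illustrative names, following the pattern used in each training branch below:

cbs = []
wandb_run, cbs = init_wandb(cbs, wandb_name='lm_base_enwik8_test',       # illustrative run name
                            wandb_group='TEST', wandb_notes='My experiment notes',
                            wandb_tags='lm test', save_model=False)
learn.fit(1, cbs=cbs)    # learn comes from one of the learner getters above
wandb_run.finish()       # close the wandb run for this training run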

Training Script

Command line arguments

Only arguments that vary between different experiment runs are passed to the model from the command line. For example, for the synthetic experiment only n_hashes and use_lsh can be changed from the command line; all other model parameters are fixed by SyntheticConfig.
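So a command such as run_exp 'synt' --n_hashes=4 --use_lsh=True roughly amounts to the following (every other hyperparameter stays at the SyntheticConfig default):

config = SyntheticConfig(warn=False, n_hashes=4, use_lsh=True)   # only the CLI-overridable arguments are passed in
model = LSHLM.from_config(config)                                # all remaining hyperparameters are fixed by the config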

run_exp[source]

run_exp(task:"Task options: 'synt','lm_base','lm_rev',lm_shared_qk, n_hashes, n_layers, wmt_rev, wmt_base", data_path:"Path to data folder", n_epochs:"Number of epochs", lr:"Learning rate", bs:"Batch size", train_sz:"TwinSequence train size", valid_sz:"TwinSequence valid size", n_layers:"Number of layers", n_hashes:"Number of LSH Attention hashes", use_lsh:"Use LSH Attention", max_seq_len:"Max sequence length for model embedding and dataloader", do_wandb_logging:"Use weights and biases logging", run_name:"Run name for wandb tracking and model filename", wandb_group:"wandb group", wandb_notes:"wandb notes", wandb_tags:"wandb tags, add tags in a single string, space separated", save_model:"Save model locally in /models", grad_accum:"Gradient Accumulation, set greater than 1 to implement", clip:"Gradient Clipping, will be set if > 0.0", cuda_id:"Which cuda device to use", seed:"Set seed for reproducibiltiy, passing anything except 0 will use fastai's set_seed", distrib:"Set to True if using distributed training", verbose:"Print script logs", tiny:"Use 5% of data, for quick iteration and testings", precision:"0:fp16, 1:non native fp16, 2:fp32")

Task options: 'synt', 'lm_base', 'lm_rev', 'lm_shared_qk', 'n_hashes', 'n_layers', 'wmt_rev', 'wmt_base'

@call_parse
def run_exp(task:Param(help="Task options: 'synt', 'lm_base', 'lm_rev', 'lm_shared_qk', 'n_hashes', 'n_layers', 'wmt_rev', 'wmt_base'", type=str),
         data_path:Param(help="Path to data folder", type=str, default='./data'),
         n_epochs:Param(help="Number of epochs", type=int, default=1),
         lr:Param(help="Learning rate", type=float, default=1e-3),               
         bs:Param(help="Batch size", type=int, default=64),
         train_sz:Param(help="TwinSequence train size", type=int, default=12800),
         valid_sz:Param(help="TwinSequence valid size", type=int, default=1280),
         n_layers:Param(help="Number of layers", type=int, default=3),
         n_hashes:Param(help="Number of LSH Attention hashes", type=int, default=1),
         use_lsh:Param(help="Use LSH Attention", type=bool_arg, default=False),
         max_seq_len:Param(help="Max sequence length for model embedding and dataloader", type=int, default=2048),
         do_wandb_logging:Param(help="Use weights and biases logging", type=bool_arg, default=False),
         run_name:Param(help="Run name for wandb tracking and model filename", type=str, default=''),
         wandb_group:Param(help="wandb group", type=str, default='TEST'),
         wandb_notes:Param(help="wandb notes", type=str, default='My experiment notes'),
         wandb_tags:Param(help="wandb tags, add tags in a single string, space separated", type=str, default='test'),
         save_model:Param(help="Save model locally in /models", type=bool_arg, default=False),
         grad_accum:Param(help="Gradient Accumulation, set greater than 1 to implement", type=int, default=1),
         clip:Param(help="Gradient Clipping, will be set if > 0.0", type=float, default=0.0),
         cuda_id:Param(help="Which cuda device to use", type=int, default=0),
         seed:Param(help="Set seed for reproducibiltiy, passing anything except 0 will use fastai's set_seed", type=int, default=0),
         distrib:Param(help="Set to True if using distributed training", type=bool_arg, default=False),
         verbose:Param(help="Print script logs", type=bool_arg, default=True),
         tiny:Param(help="Use 5% of data, for quick iteration and testings", type=bool_arg, default=False),
         precision:Param(help="0:fp16, 1:non native fp16, 2:fp32", type=int, default=0)
        ):

    """Task options: 'synt','lm_base','lm_rev',lm_shared_qk, trans"""
    #Set up distributed training
#     _wrapper = rank0_first if distrib else partial
#     if distrib: cuda_id = None 
    torch.cuda.set_device(cuda_id)
    
    # Callbacks used for training
    cbs = []
    if save_model: cbs.append(SaveModelCallback(every_epoch=True, with_opt=True))
    
    #random seeds
    if seed!=0:
        set_seed(seed, reproducible=True)  # also sets the torch.backends.cudnn flags for reproducibility
    else: 
        seed = None   # this is passed to LSH and data generator. They expect None or int
    
    if task == 'synt':
        "Model + Data Args than can be changed from command line: train_sz, valid_sz, n_hashes, use_lsh, seed"
        
        
        if run_name == '': 
            if use_lsh: run_name = f'{task}_lsh-{n_hashes}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
            else: run_name = f'{task}_full-attn_bs-{bs}_n_eps-{n_epochs}'
        
        print('Getting model ...')
        config = SyntheticConfig(warn=False, verbose=verbose, n_hashes=n_hashes, use_lsh=use_lsh, seed=seed)
        if verbose: print(config)
        config.save(run_name, add_tstmp=True)
        model = LSHLM.from_config(config)
        print('done!')
        
        print('Getting dataloaders ...')
        if train_sz != 12800: print(f'Note, "train_sz" changed from recommended 12800 to {train_sz}')
        dls = get_twin_sequence_dataloaders(bs=bs, sl=config['max_seq_len'], train_sz=train_sz, 
                                            valid_sz=valid_sz, seed=seed)
        print('done!')
        
        print('Getting learner ...')
        learn = get_synthetic_learner(dls, model, precision)
        print('done!')
        
        # Set up Weights & Biases logging, if needed
        if do_wandb_logging and rank_distrib()==0:
            wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
                                        wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
        
        # Append training callbacks needed
        cbs.append(MaskTargCallback())
        
        # Start training
        print('Starting training...')
        learn.fit_one_cycle(n_epochs, lr, cbs=cbs)
        print('done!')
        
        # Close wandb logging for this run
        if do_wandb_logging: wandb_run.finish()
        
        # Save model weights if needed, saved in /models relative to where script is run
        if save_model:
            now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
            learn.save(f'{task}_{run_name}_{now}')
    
    elif 'lm' in task:
        "Model args that can be changed from command line: axial_shape, max_seq_len"
        axial_shape = get_axial_shape(max_seq_len)
        if task == 'lm_base':
            if run_name == '': run_name = f'{task}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
            config = TransformerLMConfigEnwik8(warn=False, verbose=verbose, 
                                               axial_shape=axial_shape, max_seq_len=max_seq_len)
            print('Getting model ...')
            model = TransformerLM.from_config(config)
            print('done!')
        elif task == 'lm_rev':
            if run_name == '': run_name = f'{task}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
            config = ReversibleLMConfigEnwik8(warn=False, verbose=verbose, 
                                              axial_shape=axial_shape, max_seq_len=max_seq_len)
            print('Getting model ...')
            model = ReversibleLM.from_config(config)
            print('done!')
        elif task == 'lm_shared_qk':
            if run_name == '': run_name = f'{task}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
            config = TransformerLMConfigEnwik8(warn=False, verbose=verbose, shared_qk=True, 
                                               axial_shape=axial_shape, max_seq_len=max_seq_len)
            print('Getting model ...')
            model = TransformerLM.from_config(config)
            print('done!')
        
        if verbose: print(config)
        config.save(run_name, add_tstmp=True)

        print('Checking data')
#         _wrapper(download_enwik8_data, data_path=data_path)
#         if distrib: rank0_first(download_enwik8_data, data_path=data_path)
        download_enwik8_data(data_path=data_path)
        print('done')
        
        print('Getting dataloaders ...')
        dls = get_enwik8_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len, 
                                    verbose=verbose, tiny=tiny)
        print('done')
        
        print('Getting learner ...')
        learn = get_lm_learner(dls, model, opt_func=adafactor, precision=precision)
        print('done!')
        
        # CALLBACKS
        ## Gradient Clipping
        if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
        
        ## Gradient Accumulation
        if grad_accum > 1:
            print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
            cbs.append(GradientAccumulation(n_acc=grad_accum))
            run_name = run_name + f'_grad-accum-{grad_accum}'
        
        # Set up Weights & Biases logging, if needed
        if do_wandb_logging and rank_distrib()==0:
            wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
                                        wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
        
        # Start training
        print('Starting training...')
        learn.fit(n_epochs, cbs=cbs)
        print('done!')
        
        # Close wandb logging for this run
        if do_wandb_logging: wandb_run.finish()
        
        # Save model weights if needed, saved in /models relative to where script is run
        if save_model:
            now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
            learn.save(f'{task}_{run_name}_{now}')
    
    elif task == 'n_hashes':
        "Model args that can be changed from command line: n_hashes, seed"
        
        if run_name == '': run_name = f'{task}-{n_hashes}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'

        print('Checking data')
#         _wrapper(download_enwik8_data, data_path=data_path)
#         if distrib: rank0_first(download_enwik8_data, data_path=data_path)
        download_enwik8_data(data_path=data_path)
        print('done')
        
        print('Getting dataloaders ...')
        dls = get_enwik8_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len, 
                                    verbose=verbose, tiny=tiny)
        print('done')
        pad_id = dls.byte_text_tokenizer.pad_token_id
        
        config = NHashesConfig(warn=False, verbose=verbose, n_hashes=n_hashes,
                               seed=seed, pad_idx=pad_id)
        print('Getting model ...')
        model = LSHLM.from_config(config)
        print('done!')
        
        if verbose: print(config)
        config.save(run_name, add_tstmp=True)
        
        print('Getting learner ...')
        learn = get_lm_learner(dls, model, opt_func=adafactor, precision=precision)
        print('done!')
        
        # CALLBACKS
        ## Gradient Clipping
        if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
        
        ## Gradient Accumulation
        if grad_accum > 1: 
            print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
            cbs.append(GradientAccumulation(n_acc=grad_accum))
            run_name = run_name + f'_grad-accum-{grad_accum}'
        #LSH-specific callback
        if config.use_lsh: cbs.append(PadBatchCallback(bucket_size=config.bucket_size,
                                                       val=pad_id, y_val=pad_id))
        # Set up Weights & Biases logging, if needed
        if do_wandb_logging and rank_distrib()==0:
            wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
                                        wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
        
        # Start training
        print('Starting training...')
        learn.fit(n_epochs, cbs=cbs)
        print('done!')
        
        # Close wandb logging for this run
        if do_wandb_logging: wandb_run.finish()
        
        # Save model weights if needed, saved in /models relative to where script is run
        if save_model:
            now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
            learn.save(f'{task}_{run_name}_{now}')
    
    elif task == 'n_layers':
        "Model args that can be changed from command line: n_hashes, seed"
        
        if run_name == '': run_name = f'{task}-{n_layers}_enwik8_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'

        print('Checking data')
#         _wrapper(download_enwik8_data, data_path=data_path)
#         if distrib: rank0_first(download_enwik8_data, data_path=data_path)
        download_enwik8_data(data_path=data_path)
        print('done')
        
        print('Getting dataloaders ...')
        dls = get_enwik8_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len, 
                                    verbose=verbose, tiny=tiny)
        print('done')
        pad_id = dls.byte_text_tokenizer.pad_token_id
        
        config = NLayersConfig(warn=False, verbose=verbose, n_layers=n_layers,
                               max_seq_len=max_seq_len, seed=seed, pad_idx=pad_id)
        print('Getting model ...')
        model = ReformerLM.from_config(config)
        print('done!')
        
        if verbose: print(config)
        config.save(run_name, add_tstmp=True)
        
        print('Getting learner ...')
        learn = get_reformerlm_learner(dls, model, opt_func=adafactor, precision=precision)
        print('done!')
        
        # CALLBACKS
        ## Gradient Clipping
        if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
        
        ## Gradient Accumulation
        if grad_accum > 1: 
            print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
            cbs.append(GradientAccumulation(n_acc=grad_accum))
            run_name = run_name + f'_grad-accum-{grad_accum}'
        #LSH-specific callback
        if config.use_lsh: cbs.append(PadBatchCallback(bucket_size=config.bucket_size,
                                                       val=pad_id, y_val=pad_id))
        # Set up Weights & Biases logging, if needed
        if do_wandb_logging and rank_distrib()==0:
            wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
                                        wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
        
        # Start training
        print('Starting training...')
        learn.fit(n_epochs, cbs=cbs)
        print('done!')
        
        # Close wandb logging for this run
        if do_wandb_logging: wandb_run.finish()
        
        # Save model weights if needed, saved in /models relative to where script is run
        if save_model:
            now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
            learn.save(f'{task}_{run_name}_{now}')
    
    
    elif 'wmt' in task:
        "Model args that can be changed from command line: n_layers, max_seq_len"
        axial_shape = get_axial_shape(max_seq_len)
        if run_name == '': run_name = f'{task}_sl-{max_seq_len}_bs-{bs}_n_eps-{n_epochs}_seed-{seed}'
            
        print('Checking data')
        download_wmt14_data(data_path=data_path)
        print('done')
        
        print('Getting dataloaders and tokenizer ...')
        dls, tok = get_wmt14_dataloader(data_path=data_path, bs=bs, val_bs=bs, sl=max_seq_len, 
                                           verbose=verbose, tiny=tiny)
        print('done')
        
        print('Getting model ...')
        if task == 'wmt_rev':
            config = ReversibleTransformerConfigWMT(warn=False, verbose=verbose, 
                                                    enc_vocab_sz=tok.vocab_size, dec_vocab_sz=tok.vocab_size, pad_idx=tok.PAD_ID,
                                                    n_enc_layers=n_layers, n_dec_layers=n_layers)
            model = ReversibleTransformer.from_config(config)
        elif task == 'wmt_base':
            config = TransformerConfigWMT(warn=False, verbose=verbose, 
                                            enc_vocab_sz=tok.vocab_size, dec_vocab_sz=tok.vocab_size, pad_idx=tok.PAD_ID,
                                            n_enc_layers=n_layers, n_dec_layers=n_layers)

            model = Transformer.from_config(config) 
        print('done!')
        
        if verbose: print(config)
        config.save(run_name, add_tstmp=True)

        print('Getting learner ...')    
        learn = get_seq2seq_learner(dls, model, tok, precision)
        print('done!')
        
        # CALLBACKS
        cbs += [CombineInputOutputCallback(), LossTargetShiftCallback(), RemoveEOSCallback(eos_idx=tok.EOS_ID)]
        
        ## Gradient Clipping Callback
        if clip != 0.0: cbs.append(GradientClip(max_norm=clip))
        
        ## Gradient Accumulation Callback
        if grad_accum > 1:
            print(f'Gradient accumulation on, virtual batch size == {grad_accum}')
            cbs.append(GradientAccumulation(n_acc=grad_accum))
            run_name = run_name + f'_grad-accum-{grad_accum}'
        
        # Set up Weights & Biases logging, if needed
        if do_wandb_logging:
            wandb_run, cbs = init_wandb(cbs, wandb_name=run_name, wandb_group=wandb_group,
                                        wandb_notes=wandb_notes, wandb_tags=wandb_tags, save_model=save_model)
        
        # Start training
        print('Starting training...')
        learn.fit_one_cycle(n_epochs, lr, cbs=cbs)
        print('done!')
        
        # Close wandb logging for this run
        if do_wandb_logging: wandb_run.finish()
        
        # Save model weights if needed, saved in /models relative to where script is run
        if save_model:
            now = time.strftime("_%d_%m_%Y_%H:%M", time.gmtime())
            learn.save(f'{task}_{run_name}_{now}')
    
    
    elif task == 'test_cfg':
        print('Locals ', locals())
        print()
        config = SyntheticConfig(verbouse=True, **locals())
        print(config)
        config.save('test')
        config2 = SyntheticConfig.from_file('test')
        print(config2)
        
    elif task == 'test':
        print('testing testing :)')
        print(verbose)
        
    else:
        print('No task run')

Running the Script

Example commands to run full-scale experiments. Note that run_name can be passed, but if it is not, it will be constructed automatically from the task and the relevant arguments.

Running the Synthetic Experiment:

run_exp 'synt' \
    --n_epochs=750 \
    --lr=1e-4 \
    --bs=128 \
    --use_lsh=True \
    --n_hashes=1 \
    --train_sz=12800 \
    --valid_sz=1280 \
    --seed=1234 \
    --wandb_group='Synthetic' \
    --wandb_tags='synthetic_task lsh lm test' \
    --run_name='synth_lsh_1_hash'

Running the Reversible Language Model experiment:

For the full 60k steps with a sequence length of 65536, the number of epochs can be calculated as follows: with sl == 2**16, one epoch of enwik8 has 172 batches, therefore n_epochs == 60000/172 ≈ 349.
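A rough sanity check of those numbers, assuming the default val_test_chars of 10M leaves about 90M of enwik8's 100M characters for training and the default batch size of 8:

train_chars = 100_000_000 - 10_000_000   # enwik8 characters left for training after the valid/test split
sl, bs = 2**16, 8
print(train_chars / (sl * bs))           # ~171.7, i.e. roughly 172 batches per epoch
print(round(60_000 / 172))               # ~349 epochs for 60k training steps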

run_exp 'lm_rev' \
    --n_epochs=3 \
    --lr=1e-4 \
    --bs=8 \
    --max_seq_len=4096 \
    --do_wandb_logging=True \
    --wandb_group='enwik8_lm_rev' \
    --wandb_tags='lm_rev lm exp' \
    --wandb_notes='This is a test' \
    --grad_accum=4