import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from generative_models.layers import scale, unscale
from torchvision.datasets import MNIST
import torchvision.transforms as T
data_dir = "/media/arto/work/data/mnist"
from pathlib import Path
path = Path(data_dir)
path.exists()
True
train_dl = DataLoader(MNIST(data_dir, transform=T.ToTensor(), train=True), batch_size=32, shuffle=True, drop_last=True)
valid_dl = DataLoader(MNIST(data_dir, transform=T.ToTensor(), train=False), batch_size=32, shuffle=False, drop_last=False)
b = next(iter(valid_dl))
b[0].shape, b[1].shape
(torch.Size([32, 1, 28, 28]), torch.Size([32]))
from generative_models.layers import ConvNet
model = nn.Sequential(
    ConvNet(1),
    nn.Flatten(),
    nn.BatchNorm1d(64*2*2),
    nn.ReLU(),
    nn.Linear(64*2*2, 10)
)
out = model(b[0])
out.shape
torch.Size([32, 10])
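# ConvNet comes from the author's generative_models package and isn't shown
# here. Below is a minimal sketch of a compatible feature extractor, assuming
# four stride-2 conv blocks that take a 1x28x28 image down to the 64x2x2 map
# the head above flattens; ConvNetSketch and conv_block are hypothetical names.
def conv_block(c_in, c_out):
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(c_out),
        nn.ReLU(),
    )

class ConvNetSketch(nn.Module):
    def __init__(self, c_in):
        super().__init__()
        self.net = nn.Sequential(
            conv_block(c_in, 8),   # 28x28 -> 14x14
            conv_block(8, 16),     # 14x14 -> 7x7
            conv_block(16, 32),    # 7x7   -> 4x4
            conv_block(32, 64),    # 4x4   -> 2x2
        )

    def forward(self, x):
        return self.net(x)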
from tqdm.auto import tqdm, trange
def train_step(batch, model, optimizer, loss_func, scheduler=None, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb, yb = batch[0].to(device), batch[1].to(device)
    preds = model(xb)
    loss = loss_func(preds, yb)

    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.)  # clip gradients to unit norm
    optimizer.step()
    if scheduler is not None:
        scheduler.step()  # per-batch LR schedule
    optimizer.zero_grad()
    return loss.item()
def accuracy(pred, targ):
    return (pred.argmax(-1) == targ).float().mean()
def eval_step(batch, model, loss_func, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb, yb = batch[0].to(device), batch[1].to(device)
    with torch.no_grad():
        preds = model(xb)
        loss = loss_func(preds, yb)
    return loss.item(), accuracy(preds, yb).item()
def fit(n_epoch, model, train_dl, valid_dl, train_step, eval_step, optimizer, loss_func, scheduler=None, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses = []
    valid_losses = []
    valid_accs = []
    for e in trange(n_epoch):
        model.train()
        train_pbar = tqdm(train_dl, leave=False)
        for step, batch in enumerate(train_pbar):
            loss = train_step(batch, model, optimizer, loss_func, scheduler, device=device)
            train_losses.append(loss)
            train_pbar.set_description(f"{loss:.2f}")

        model.eval()
        avg_valid_loss = 0.0
        avg_valid_acc = 0.0
        track_acc = False
        for step, batch in enumerate(valid_dl):
            out = eval_step(batch, model, loss_func, device=device)
            # eval_step may return a bare loss or a (loss, accuracy) pair
            if isinstance(out, tuple):
                loss, acc = out
                track_acc = True
                avg_valid_acc += (acc - avg_valid_acc) / (step + 1)
            else:
                loss = out
            avg_valid_loss += (loss - avg_valid_loss) / (step + 1)  # running mean
        valid_losses.append(avg_valid_loss)
        if track_acc:
            valid_accs.append(avg_valid_acc)
    return train_losses, valid_losses, valid_accs
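# train_step calls scheduler.step() after every optimizer.step(), so fit
# accepts any per-batch scheduler. A sketch of wiring up OneCycleLR (not used
# in the runs below; make_one_cycle and max_lr=1e-2 are assumptions):
def make_one_cycle(model, train_dl, n_epoch, max_lr=1e-2):
    opt = torch.optim.Adam(model.parameters())
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=max_lr, total_steps=n_epoch * len(train_dl)
    )
    return opt, sched
# usage: opt, sched = make_one_cycle(model, train_dl, 2)
#        fit(2, model, train_dl, valid_dl, train_step, eval_step, opt,
#            nn.CrossEntropyLoss(), scheduler=sched)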
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

tl, vl, va = fit(2, model, train_dl, valid_dl, train_step, eval_step, optimizer, nn.CrossEntropyLoss())
100%|██████████| 2/2 [00:30<00:00, 15.36s/it]
va
[0.9798259493670884, 0.9860561708860759]
from generative_models.pixelcnn import SimplePixelCNN
model = SimplePixelCNN(ks=5)
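# SimplePixelCNN's internals live in generative_models.pixelcnn and aren't
# reproduced here. The defining ingredient of any PixelCNN is the masked
# convolution, sketched below as a hypothetical stand-in: a type-'A' mask
# hides the centre pixel and everything after it in raster order, so each
# output pixel depends only on pixels above and to its left.
class MaskedConv2d(nn.Conv2d):
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in ("A", "B")
        kH, kW = self.kernel_size
        mask = torch.ones_like(self.weight)
        # zero the centre pixel ('A' only) and everything to its right...
        mask[:, :, kH // 2, kW // 2 + (mask_type == "B"):] = 0
        # ...and every row below the centre
        mask[:, :, kH // 2 + 1:] = 0
        self.register_buffer("mask", mask)

    def forward(self, x):
        self.weight.data *= self.mask  # re-apply the mask before every conv
        return super().forward(x)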
def train_step(batch, model, optimizer, loss_func, scheduler=None, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb = scale(batch[0].to(device))  # ignore the labels; the input is also the target
    preds = model(xb)
    loss = loss_func(preds, xb)

    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    optimizer.zero_grad()
    return loss.item()
def eval_step(batch, model, loss_func, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb = scale(batch[0].to(device))
    with torch.no_grad():
        preds = model(xb)
        loss = loss_func(preds, xb)
    return loss.item()
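# scale and unscale were imported from generative_models.layers at the top;
# their definitions aren't shown. A plausible sketch, assuming they map
# ToTensor's [0, 1] pixel range to [-1, 1] and back (the _sketch names are
# hypothetical, to avoid shadowing the imported functions):
def scale_sketch(x):
    return x * 2. - 1.

def unscale_sketch(x):
    return (x + 1.) / 2.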
optimizer = torch.optim.Adam(model.parameters())
# quick sanity run: train the PixelCNN for one epoch on the small validation split
fit(1, model, valid_dl, valid_dl, train_step, eval_step, optimizer, nn.MSELoss())
100%|██████████| 1/1 [01:13<00:00, 73.23s/it]
([0.025165459141135216,
  0.0241972878575325,
  0.026837032288312912,
  ...
  0.022784797474741936,
  0.024278679862618446],
 [0.024526611024055622],
 [])