Common layers, blocks and utils.
# Smoke test: MLP maps (batch, 5) inputs to (batch, 10) outputs
# (16 is presumably the hidden width; 3 stacked layers).
model = MLP(5, 10, 16, n_layers=3)
x = torch.randn(4, 5)
out = model(x)
assert tuple(out.shape) == (4, 10)
# Smoke test: stride-2 Conv2dBlock halves spatial dims (rounding up).
bs, h, w = 4, 4, 4
c_in, c_out = 3, 8
conv = Conv2dBlock(c_in, c_out, 3, 2)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
expected = (bs, c_out, (h + 1) // 2, (w + 1) // 2)
assert out.shape == expected
# Smoke test: stride-2 ConvTranspose2dBlock (kernel 4) doubles each
# spatial dimension.
bs, h, w = 4, 10, 10
c_in, c_out = 16, 8
conv = ConvTranspose2dBlock(c_in, c_out, 4, 2)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
assert out.shape == (bs, c_out, 2 * h, 2 * w)
# Smoke test: ResBlock changes channel count but preserves spatial size.
bs, h, w = 4, 24, 24
c_in, c_out = 3, 8
conv = ResBlock(c_in, c_out, 3, 1)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
assert out.shape == (bs, c_out, h, w)
# Smoke test: ChanLayerNorm zero-centers the channel dimension (dim 1).
m = ChanLayerNorm(3)
x = torch.randn(1, 3, 2, 2)
out = m(x)
mu = out.mean(1)
# allclose's relative tolerance is meaningless around an expected value
# of zero, so shift both sides by 1 and compare against ones instead.
assert torch.allclose(mu + 1, torch.ones_like(mu))
# Build a ConvNet with a single input channel; the bare trailing
# expression is a notebook-style cell output that displays the module tree.
model = ConvNet(1)
model
class ResNet(nn.Module):
    """Convolutional encoder: two stride-2 downsampling stages (4x spatial
    reduction overall) followed by two 256-channel residual blocks.

    Args:
        c_in: number of channels in the input tensor.
    """

    def __init__(self, c_in):
        super().__init__()
        stages = [
            Conv2dBlock(c_in, 256, 4, 2),
            nn.Conv2d(256, 256, 4, 2, 1),
            # Identity activation: the residual blocks here are purely linear
            # apart from whatever ResBlock does internally — TODO confirm.
            ResBlock(256, 256, (3, 1), 1, activation=Identity),
            ResBlock(256, 256, (3, 1), 1, activation=Identity),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through the downsample + residual pipeline."""
        return self.net(x)