# Masked Autoencoder for Distribution Estimation: https://arxiv.org/abs/1502.03509
# Hand-build MADE connectivity masks for a toy d-input, k-hidden-unit net.
d, k = 2, 4
# Random degree per hidden unit (0-indexed): unit j may see inputs 0..degrees[j].
degrees = torch.randint(0, d - 1, (k,))
inputs = torch.arange(d)
hidden = torch.arange(k)
# Input->hidden mask, shape (k, d): hidden j connects to input i iff degrees[j] >= i.
input_mask = (degrees[hidden].unsqueeze(-1) >= inputs.unsqueeze(0)).float()
input_mask
# Hidden->output mask, shape (d, k): output i connects to hidden j iff i > degrees[j].
output_mask = (inputs.unsqueeze(-1) > degrees[hidden].unsqueeze(0)).float()
output_mask
# Composed connectivity must be strictly lower triangular => autoregressive.
output_mask @ input_mask
# Build the per-layer masks via the library helper and sanity-check them.
masks = make_masks(2, [4] * 3)
# End-to-end connectivity (last mask applied first); binarize for readability.
torch.linalg.multi_dot(masks[::-1]).bool().float()
for mask in masks:
    print(mask.shape)
# Visualize each layer's mask side by side.
fig, axs = plt.subplots(1, len(masks))
for axis, mask in zip(axs, masks):
    axis.matshow(mask)
plt.show()
# Check that MaskedLinear really blocks masked-out inputs: the demo expects
# masks[0] to cut all connections from input feature 1, so perturbing that
# column must leave the layer's output unchanged.
m = MaskedLinear(d, k)
m.set_mask(masks[0])
m.mask
x = torch.rand(5, d)
out1 = m(x)
x[:, 1] = torch.rand(5)  # touch only the (expected-to-be-masked) feature
out2 = m(x)
assert (out1 == out2).all()
out2
# Smoke-test SimpleMADE: build it and forward a small random batch.
model = SimpleMADE(2, 4, 2, True)
model  # notebook-style echo of the module structure
x = torch.rand(size=(5, 2))
model(x)
# How Tensor.repeat tiles along dim 1: (2, 5) -> (2, 10).
x = torch.arange(0, 10).view(2, 5)
x.repeat(1, 2)
# Smoke-test the full MADE with two outputs per input dimension
# (note: keep the RNG call order — x is drawn before the model is built).
x = torch.rand(size=(5, 2))
model = MADE(2, 4, 3, outs_per_input=2)
model(x)
# Mask for a direct (skip) input->output connection: the rows for output i
# may only see inputs strictly before i, and each input owns
# `outs_per_input` consecutive output rows.
d_in = 4
outs_per_input = 2
strict_lower = torch.ones(d_in, d_in).tril(diagonal=-1)
direct_mask = strict_lower.repeat_interleave(outs_per_input, dim=0)
# Wire the skip-connection mask into a MaskedLinear layer and echo the mask.
direct = MaskedLinear(d_in, d_in * outs_per_input, direct_mask)
direct_mask