# VQ-VAE and its variations.
class EMA(nn.Module):
    """Exponential moving average of a tensor with Adam-style bias correction.

    The raw average starts at zero, so early reads are biased toward zero;
    ``value`` divides by ``1 - gamma**t`` to undo that bias.
    """

    def __init__(self, size: Tuple[int, ...], gamma: float):
        """
        Args:
            size: shape of the averaged tensor (variadic — was annotated as a
                1-tuple ``Tuple[int]``, but it is splatted into ``torch.zeros``).
            gamma: decay factor in (0, 1); larger means smoother/slower.
        """
        super().__init__()
        self.register_buffer("avg", torch.zeros(*size))
        self.gamma = gamma
        # Equals gamma**t after t updates; drives the bias correction in `value`.
        # NOTE(review): `gamma`/`cor` are plain attributes, not buffers, so they
        # are not persisted in state_dict — confirm that is intended.
        self.cor = 1

    def update(self, val):
        """Fold ``val`` into the running average, in place."""
        self.cor *= self.gamma
        self.avg += (val - self.avg) * (1 - self.gamma)

    @property
    def value(self):
        """Bias-corrected average: ``avg / (1 - gamma**t)``."""
        return self.avg / (1. - self.cor)

    def updated_value(self, val):
        """Convenience: fold in ``val`` and return the corrected average."""
        self.update(val)
        return self.value
class VQVAE(nn.Module):
    """Vector-Quantized VAE: encoder -> vector quantizer -> decoder.

    Args:
        encoder: module mapping input ``x`` to continuous latents ``z_e``.
        decoder: module mapping quantized latents ``z_q`` back to input space.
        k: codebook size (number of embedding vectors).
        d: embedding dimension.
        commitment_cost: weight of the commitment term in the VQ loss.
        use_ema: if True, use the EMA-updated quantizer variant.
    """

    def __init__(self, encoder, decoder, k: int, d: int,
                 commitment_cost: float = 0.25, use_ema: bool = False):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.quantize = (VectorQuantizerEMA(k, d, commitment_cost) if use_ema else
                         VectorQuantizer(k, d, commitment_cost))

    def forward(self, x):
        """Return ``(x_hat, vq_loss, code)``: reconstruction, VQ loss, indices."""
        ze = self.encoder(x)
        zq, vq_loss, code = self.quantize(ze)
        x_hat = self.decoder(zq)
        return x_hat, vq_loss, code

    @torch.no_grad()
    def encode(self, x):
        """Map input to discrete code indices (no gradients)."""
        # BUG FIX: was `self.ecoder(x)` — a typo that raised AttributeError
        # on every call; the attribute set in __init__ is `self.encoder`.
        ze = self.encoder(x)
        _, _, code = self.quantize(ze)
        return code

    @torch.no_grad()
    def decode(self, code):
        """Reconstruct an input from discrete code indices (no gradients)."""
        # NOTE(review): assumes `self.quantize.embedding` is a (k, d) weight
        # tensor usable by F.embedding — confirm against the quantizer class.
        zq = F.embedding(code, self.quantize.embedding)
        # Image-like codes come back channels-last (B, H, W, d); move the
        # embedding dim to the channel axis for a convolutional decoder.
        if zq.dim() == 4:
            zq = zq.permute(0, 3, 1, 2).contiguous()
        return self.decoder(zq)