Files

64 lines
1.8 KiB
Python

import torch
from torch import nn
import torch.nn.functional as F
class AutoEncoder(nn.Module):
    """Deterministic fully-connected autoencoder.

    Encodes a flattened input of ``input_size`` features into a ``z_size``
    code through one ReLU hidden layer, then decodes it back to
    ``input_size`` features with a mirrored network.

    Args:
        input_size: Number of features per flattened sample (default 784).
        z_size: Width of the bottleneck code (default 20).
    """

    def __init__(self, input_size=784, z_size=20):
        super().__init__()
        # Remembered so forward() flattens to the configured width instead of
        # a hard-coded 784.
        self.input_size = input_size
        # Hidden width sits halfway between the input and code widths,
        # i.e. floor((input_size + z_size) / 2); 402 for the defaults.
        hidden_size = int((input_size - z_size) / 2 + z_size)
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, z_size),
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, x):
        """Reconstruct ``x``.

        ``x`` is flattened to shape ``(-1, input_size)`` before encoding.

        Returns:
            In training mode, the raw decoder outputs (logits). In eval mode,
            the same outputs passed through a sigmoid, so values lie in [0, 1].
        """
        # reshape (not view) also accepts non-contiguous inputs; it returns a
        # view whenever one is possible.
        x = x.reshape(-1, self.input_size)
        z = self.encoder(x)
        logits = self.decoder(z)
        if self.training:
            # NOTE(review): presumably paired with a logits-based loss such as
            # BCEWithLogitsLoss during training; confirm against the caller.
            return logits
        return torch.sigmoid(logits)
class VAE(nn.Module):
    """Variational autoencoder with a diagonal-Gaussian latent.

    The encoder maps a flattened input of ``input_size`` features to
    ``2 * z_size`` values: a per-dimension mean and log-variance. A latent
    sample is drawn with the reparameterization trick and decoded back to
    ``input_size`` sigmoid-activated outputs.

    Args:
        input_size: Number of features per flattened sample (default 784).
        z_size: Dimension of the latent space (default 20).
    """

    def __init__(self, input_size=784, z_size=20):
        super().__init__()
        # Hidden width sits halfway between the input and latent widths,
        # i.e. floor((input_size + z_size) / 2); 402 for the defaults.
        hidden_size = int((input_size - z_size) / 2 + z_size)
        # Remembered so forward() flattens to the configured width instead of
        # a hard-coded 784.
        self.input_size = input_size
        self.z_size = z_size
        # Final layer emits 2 * z_size values: [mu | log_var].
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, z_size * 2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``.

        Sampling through a fixed-distribution ``eps`` keeps ``z``
        differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # unit gaussian
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to shape ``(-1, input_size)`` before encoding.

        Returns:
            A tuple ``(reconstruction, z, mu, log_var)``. The reconstruction
            has ``input_size`` columns with values in [0, 1]; the other three
            each have ``z_size`` columns.
        """
        # reshape (not view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.input_size)
        variational_params = self.encoder(x)
        mu = variational_params[..., :self.z_size]
        log_var = variational_params[..., self.z_size:]
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), z, mu, log_var