64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class AutoEncoder(nn.Module):
    """Fully connected autoencoder with a single hidden layer per side.

    The encoder maps ``input_size`` features to a ``z_size`` code through one
    ReLU hidden layer; the decoder mirrors it back to ``input_size``.

    Args:
        input_size: Number of features per flattened sample. The default,
            784, equals 28 * 28 (presumably MNIST-sized images - confirm
            against the training script).
        z_size: Width of the bottleneck code.
    """

    def __init__(self, input_size=784, z_size=20):
        super().__init__()
        # Kept so forward() flattens to the configured width rather than a
        # hard-coded constant.
        self.input_size = input_size
        # Hidden width sits halfway between the input and code widths.
        hidden_size = int((input_size - z_size) / 2 + z_size)
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, z_size),
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, x):
        """Encode and decode ``x``.

        ``x`` is flattened to ``(-1, input_size)`` first, so any leading
        layout whose element count is a multiple of ``input_size`` works.

        Returns:
            In training mode, the raw decoder output (no sigmoid applied);
            in eval mode, the sigmoid of that output. Shape is
            ``(N, input_size)`` in both cases.
            NOTE(review): returning raw values while training suggests the
            loss applies the sigmoid itself (e.g. a with-logits BCE) -
            confirm against the caller.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        flat = x.reshape(-1, self.input_size)
        code = self.encoder(flat)
        decoded = self.decoder(code)

        if self.training:
            return decoded
        return torch.sigmoid(decoded)
|
|
|
|
|
|
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder emits ``2 * z_size`` values per sample: the first ``z_size``
    are used as the latent mean and the remaining ``z_size`` as the latent
    log-variance. The decoder maps a sampled latent vector back to
    ``input_size`` features and ends in a sigmoid, so reconstructions lie
    in [0, 1].

    Args:
        input_size: Number of features per flattened sample. The default,
            784, equals 28 * 28 (presumably MNIST-sized images - confirm
            against the training script).
        z_size: Dimensionality of the latent variable.
    """

    def __init__(self, input_size=784, z_size=20):
        super().__init__()
        # Kept so forward() flattens to the configured width rather than a
        # hard-coded constant.
        self.input_size = input_size
        self.z_size = z_size
        # Hidden width sits halfway between the input and latent widths.
        hidden_size = int((input_size - z_size) / 2 + z_size)
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            # Twice the latent width: mean and log-variance side by side.
            nn.Linear(hidden_size, z_size * 2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Args:
            mu: Latent mean.
            log_var: Latent log-variance; ``std`` is ``exp(0.5 * log_var)``.

        Returns:
            A sample with the same shape as ``mu``. Noise is drawn on every
            call, regardless of the module's training flag.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # unit gaussian
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        ``x`` is flattened to ``(-1, input_size)`` first, so any leading
        layout whose element count is a multiple of ``input_size`` works.

        Returns:
            A tuple ``(reconstruction, z, mu, log_var)`` where
            ``reconstruction`` has shape ``(N, input_size)`` and the other
            three have shape ``(N, z_size)``.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        flat = x.reshape(-1, self.input_size)
        variational_params = self.encoder(flat)
        # Split the encoder output into its two halves along the last dim.
        mu = variational_params[..., :self.z_size]
        log_var = variational_params[..., self.z_size:]
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), z, mu, log_var
|