add variational inference training given to core team
This commit is contained in:
parent
6fe9044cdc
commit
8bb99e507f
4
trainings/variational_inference_core_team/.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
.ipynb_checkpoints/
.idea/
__pycache__/
data/
File diff suppressed because one or more lines are too long
BIN
trainings/variational_inference_core_team/img/auto-encoder.png
Normal file
Binary file not shown.
Size: 134 KiB
63
trainings/variational_inference_core_team/models.py
Normal file
@ -0,0 +1,63 @@
import torch
from torch import nn


class AutoEncoder(nn.Module):
    def __init__(self, input_size=784, z_size=20):
        super().__init__()
        # Hidden layer sized halfway between the input and the latent code.
        hidden_size = int((input_size - z_size) / 2 + z_size)
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, z_size)
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )

    def forward(self, x):
        x = x.view(-1, 784)
        z = self.encoder(x)
        x = self.decoder(z)

        if self.training:
            # Raw logits during training; sigmoid is applied only at eval time.
            return x
        else:
            return torch.sigmoid(x)
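The AutoEncoder returns raw logits while self.training is set and only squashes them through a sigmoid at eval time. The training script is not part of this commit, but a plausible reading is that the logits are meant for nn.BCEWithLogitsLoss; a minimal usage sketch under that assumption (the batch below is hypothetical) could look like:

import torch
from torch import nn

model = AutoEncoder()               # defaults: 784 inputs, 20 latent dims
criterion = nn.BCEWithLogitsLoss()  # consumes raw logits; stabler than sigmoid + BCE
x = torch.rand(32, 1, 28, 28)       # hypothetical batch of MNIST-sized images in [0, 1]
logits = model(x)                   # shape (32, 784); no sigmoid in training mode
loss = criterion(logits, x.view(-1, 784))
loss.backward()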
class VAE(nn.Module):
    def __init__(self, input_size=784, z_size=20):
        super().__init__()
        hidden_size = int((input_size - z_size) / 2 + z_size)
        self.z_size = z_size
        # The encoder outputs 2 * z_size values: the mean and the
        # log-variance of the approximate posterior q(z|x).
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, z_size * 2)
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps keeps the
        # sampling step differentiable with respect to mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # unit Gaussian noise
        z = mu + eps * std
        return z

    def forward(self, x):
        x = x.view(-1, 784)
        variational_params = self.encoder(x)
        # Split the encoder output into mean and log-variance of q(z|x).
        mu = variational_params[..., :self.z_size]
        log_var = variational_params[..., self.z_size:]
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), z, mu, log_var
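VAE.forward returns the reconstruction together with z, mu, and log_var, which is exactly what the evidence lower bound (ELBO) needs. The loss itself is not in this commit; a minimal sketch of it (binary cross-entropy reconstruction, since the decoder ends in a Sigmoid, plus the closed-form KL divergence between the diagonal Gaussian q(z|x) and the unit-Gaussian prior; the name vae_loss is hypothetical) could be:

import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, log_var):
    # Reconstruction term: BCE on decoder probabilities vs. pixel targets.
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kld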
@ -0,0 +1,5 @@
torch==1.3.0
torchvision>=0.4
numpy==1.17.2
matplotlib==3.1.1
scipy==1.3.1
File diff suppressed because one or more lines are too long