# Demo: fit a chain of planar normalizing flows to a 2-D ring-shaped
# target density (two modes along z1) by minimizing a flow-based free energy.
import matplotlib.pyplot as plt
|
|
import torch
|
|
from torch import nn
|
|
from torch import distributions as dist
|
|
from flows import Planar
|
|
|
|
|
|
def target_density(z):
    """Unnormalized ring-shaped density exp(-U(z)) over 2-D points.

    ``z`` has shape (..., 2); the last axis holds (z1, z2). The energy U
    combines a ring of radius 4 with two bumps centred at z1 = +/-2, so the
    density concentrates on two arcs of the ring.
    """
    z1 = z[..., 0]
    z2 = z[..., 1]
    radius = (z1 ** 2 + z2 ** 2) ** 0.5
    # Two Gaussian-like bumps along the z1 axis.
    bump_pos = torch.exp(-0.2 * ((z1 - 2) / 0.8) ** 2)
    bump_neg = torch.exp(-0.2 * ((z1 + 2) / 0.8) ** 2)
    # Ring term keeps |z| near 4; the log of the bump mixture carves the modes.
    energy = 0.5 * ((radius - 4) / 0.4) ** 2 - torch.log(bump_pos + bump_neg)
    return torch.exp(-energy)
|
|
|
|
|
|
class Flow(nn.Module):
    """Learnable diagonal-Gaussian base pushed through a chain of planar flows.

    The base distribution N(mu, diag(exp(log_var))) is sampled via the
    reparameterisation trick and transformed by ``n_flows`` Planar layers.
    """

    def __init__(self, dim=2, n_flows=10):
        super().__init__()
        # NOTE(review): the Sequential chain implies each Planar layer
        # propagates a (z, log_det_jacobian) pair — confirm against flows.Planar.
        self.flow = nn.Sequential(*[
            Planar(dim) for _ in range(n_flows)
        ])
        # torch.empty instead of torch.randn: normal_() overwrites the tensor
        # in place anyway, so the original randn draw was a wasted RNG sample.
        self.mu = nn.Parameter(torch.empty(dim).normal_(0, 0.01))
        self.log_var = nn.Parameter(torch.empty(dim).normal_(1, 0.01))

    def forward(self, shape):
        """Sample base noise of ``shape`` and transform it through the flow.

        Returns:
            (z0, zk, ldj, mu, log_var): base sample, flowed sample, summed
            log-det-Jacobian from the flow chain, and the base parameters.
        """
        std = torch.exp(0.5 * self.log_var)
        eps = torch.randn(shape)  # unit gaussian
        z0 = self.mu + eps * std

        zk, ldj = self.flow(z0)
        return z0, zk, ldj, self.mu, self.log_var
|
|
|
|
|
|
def det_loss(mu, log_var, z_0, z_k, ldj, beta):
    """Flow free-energy loss, averaged over the batch.

    A uniform prior on z is assumed, so log p(z) is a constant and is not
    modelled here. The loss is E[log q(z_k)] plus a beta-weighted negative
    log-likelihood of the flowed samples under the target density.
    """
    n = z_0.size(0)

    # q(z_0) is the diagonal-Gaussian base distribution.
    base = dist.Normal(mu, torch.exp(0.5 * log_var))
    # log q(z_k) = log q(z_0) - sum of log-det-Jacobians of the flow steps.
    log_qzk = base.log_prob(z_0).sum() - ldj.sum()

    # P(x|z): small epsilon guards the log against zero density.
    recon_nll = -beta * torch.log(target_density(z_k) + 1e-7).sum()

    return (log_qzk + recon_nll) / n
|
|
|
|
|
|
def train_flow(flow, shape, epochs=1000):
    """Train ``flow`` with Adam (lr=1e-2), printing the loss every 100 epochs.

    ``shape`` is the sample shape drawn from the flow at every step, e.g.
    (batch_size, dim).
    """
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-2)

    for epoch in range(epochs):
        z0, zk, ldj, mu, log_var = flow(shape=shape)
        loss = det_loss(
            mu=mu,
            log_var=log_var,
            z_0=z0,
            z_k=zk,
            ldj=ldj,
            beta=1,
        )

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if epoch % 100 == 0:
            print(loss.item())
|
|
|
|
|
|
if __name__ == '__main__':
    import numpy as np

    # Evaluate the target density on a 50x50 grid (np.linspace default num=50).
    x1 = np.linspace(-7.5, 7.5)
    x2 = np.linspace(-7.5, 7.5)
    x1_s, x2_s = np.meshgrid(x1, x2)
    x_field = np.concatenate([x1_s[..., None], x2_s[..., None]], axis=-1)
    x_field = torch.tensor(x_field, dtype=torch.float)

    plt.figure(figsize=(8, 8))
    plt.title("Target distribution")
    plt.xlabel('$z_1$')
    plt.ylabel('$z_2$')
    plt.contourf(x1_s, x2_s, target_density(x_field))
    plt.show()

    def show_samples(s):
        # Scatter the 2-D samples with low alpha so density shows as shading.
        plt.figure(figsize=(6, 6))
        plt.scatter(s[:, 0], s[:, 1], alpha=0.1)
        plt.show()

    flow = Flow(dim=2, n_flows=10)
    shape = (1000, 2)
    train_flow(flow, shape, epochs=5000)

    # Draw the final samples without building an autograd graph; detach()
    # replaces the legacy .data access before handing the tensor to matplotlib.
    with torch.no_grad():
        z0, zk, ldj, mu, log_var = flow((5000, 2))
    show_samples(zk.detach())