planar flows working

This commit is contained in:
ritchie46
2019-10-12 15:38:45 +02:00
parent eebd9b5e3d
commit d6d3b7f18e
3 changed files with 530 additions and 0 deletions

View File

@@ -0,0 +1,84 @@
import torch
from torch import nn
class Planar(nn.Module):
    """Planar normalizing-flow layer: f(z) = z + u * h(w^T z + b).

    Rezende & Mohamed, "Variational Inference with Normalizing Flows",
    eq. (10)-(12). https://arxiv.org/pdf/1505.05770.pdf
    """
    def __init__(self, size=1, init_sigma=0.01):
        """
        size: dimensionality of z.
        init_sigma: std of the small random initialization of u and w.
        """
        super().__init__()
        self.u = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.w = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.b = nn.Parameter(torch.zeros(1))
    @property
    def normalized_u(self):
        """
        Reparameterized u with w^T u_hat > -1, needed for invertibility.
        See Appendix A.1
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        # m(x) = -1 + softplus(x) maps R -> (-1, inf).
        # nn.functional.softplus is numerically stable; a literal
        # log(1 + exp(x)) overflows to inf for large x.
        def m(x):
            return -1 + nn.functional.softplus(x)
        wtu = torch.matmul(self.w, self.u.t())
        # Appendix A.1 uses w / ||w||^2 (squared norm): then
        # w^T u_hat == m(w^T u) > -1 holds exactly.
        w_div_w2 = self.w / torch.sum(self.w ** 2)
        return self.u + (m(wtu) - wtu) * w_div_w2
    def psi(self, z):
        """
        ψ(z) =h(w^tz+b)w
        See eq(11)
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        return self.h_prime(z @ self.w.t() + self.b) @ self.w
    def h(self, x):
        # smooth elementwise nonlinearity
        return torch.tanh(x)
    def h_prime(self, z):
        # derivative of tanh
        return 1 - torch.tanh(z) ** 2
    def forward(self, z):
        """Apply the flow. Accepts either a tensor z of shape (N, size)
        or a tuple (z, ldj) so layers compose under nn.Sequential.

        Returns (f(z), log|det Jacobian| + accumulated ldj)."""
        if isinstance(z, tuple):
            z, accumulating_ldj = z
        else:
            z, accumulating_ldj = z, 0
        psi = self.psi(z)
        u = self.normalized_u
        # determinant of jacobian: 1 + u^T psi(z), eq. (12)
        det = (1 + psi @ u.t())
        # log |det Jac|; epsilon guards log(0) when det ~ 0
        ldj = torch.log(torch.abs(det) + 1e-6)
        wzb = z @ self.w.t() + self.b
        fz = z + (u * self.h(wzb))
        return fz, ldj + accumulating_ldj
if __name__ == '__main__':
    # Visual sanity check: push uniform samples through a planar flow
    # ten times and scatter-plot the transformed points.
    import matplotlib.pyplot as plt
    z0 = torch.rand((1000, 2))
    with torch.no_grad():
        layer = Planar(size=2)
        zk = z0
        for _step in range(10):
            zk, _ldj = layer.forward(zk)
        plt.scatter(zk[:, 0], zk[:, 1], alpha=0.2)
        plt.show()

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,93 @@
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import distributions as dist
from flows import Planar
def target_density(z):
    """Unnormalized target density exp(-U(z)) for 2-D z (last axis = coords).

    U combines a ring of radius 4 with two bumps on the z1 axis at +/-2,
    giving a two-mode ring-shaped distribution.
    """
    z1 = z[..., 0]
    z2 = z[..., 1]
    radius = (z1 ** 2 + z2 ** 2) ** 0.5
    # Gaussian-shaped bumps centred at z1 = +2 and z1 = -2
    bump_pos = torch.exp(-0.2 * ((z1 - 2) / 0.8) ** 2)
    bump_neg = torch.exp(-0.2 * ((z1 + 2) / 0.8) ** 2)
    # ring potential modulated by the two bumps
    energy = 0.5 * ((radius - 4) / 0.4) ** 2 - torch.log(bump_pos + bump_neg)
    return torch.exp(-energy)
class Flow(nn.Module):
    """Diagonal-Gaussian base distribution followed by a chain of planar flows."""
    def __init__(self, dim=2, n_flows=10):
        super().__init__()
        # Planar layers propagate (z, ldj) tuples through nn.Sequential
        layers = [Planar(dim) for _ in range(n_flows)]
        self.flow = nn.Sequential(*layers)
        # learnable parameters of the base q0 = N(mu, diag(exp(log_var)))
        self.mu = nn.Parameter(torch.randn(dim, ).normal_(0, 0.01))
        self.log_var = nn.Parameter(torch.randn(dim, ).normal_(1, 0.01))
    def forward(self, shape):
        """Sample z0 ~ q0 (reparameterization trick) and push it through
        the flow chain.

        Returns (z0, zk, ldj, mu, log_var).
        """
        std = torch.exp(0.5 * self.log_var)
        noise = torch.randn(shape)  # unit gaussian
        z0 = self.mu + noise * std
        zk, ldj = self.flow(z0)
        return z0, zk, ldj, self.mu, self.log_var
def det_loss(mu, log_var, z_0, z_k, ldj, beta):
    """Negative flow ELBO (up to a constant), averaged over the batch.

    Note that I assume uniform prior here.
    So P(z) is constant and not modelled in this loss function.
    """
    batch_size = z_0.size(0)
    # log q0(z0) under the diagonal-Gaussian base
    base = dist.Normal(mu, torch.exp(0.5 * log_var))
    log_qz0 = base.log_prob(z_0).sum()
    # log qk(zk) = log q0(z0) - sum(log |det Jac|)
    log_qzk = log_qz0 - ldj.sum()
    # reconstruction term -log P(x|z), tempered by beta; epsilon guards log(0)
    nll = -torch.log(target_density(z_k) + 1e-7).sum() * beta
    return (log_qzk + nll) / batch_size
def train_flow(flow, shape, epochs=1000):
    """Fit `flow` to the target density with Adam (lr=1e-2).

    Draws a batch of `shape` samples per epoch and prints the loss
    every 100 epochs.
    """
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-2)
    for epoch in range(epochs):
        z0, zk, ldj, mu, log_var = flow(shape=shape)
        loss = det_loss(
            mu=mu, log_var=log_var, z_0=z0, z_k=zk, ldj=ldj, beta=1)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if epoch % 100 == 0:
            print(loss.item())
if __name__ == '__main__':
    import numpy as np
    # Render the target density on a [-7.5, 7.5]^2 grid.
    axis = np.linspace(-7.5, 7.5)
    grid_x, grid_y = np.meshgrid(axis, axis)
    grid = np.stack([grid_x, grid_y], axis=-1)
    grid = torch.tensor(grid, dtype=torch.float)
    plt.figure(figsize=(8, 8))
    plt.title("Target distribution")
    plt.xlabel('$z_1$')
    plt.ylabel('$z_2$')
    plt.contourf(grid_x, grid_y, target_density(grid))
    plt.show()
    def show_samples(samples):
        # Scatter-plot a batch of 2-D samples.
        plt.figure(figsize=(6, 6))
        plt.scatter(samples[:, 0], samples[:, 1], alpha=0.1)
        plt.show()
    # Train, then draw a fresh batch from the fitted flow.
    flow = Flow(dim=2, n_flows=10)
    sample_shape = (1000, 2)
    train_flow(flow, sample_shape, epochs=5000)
    z0, zk, ldj, mu, log_var = flow((5000, 2))
    show_samples(zk.data)