planar flows working
This commit is contained in:
84
bayesian/normalizing_flows/flows.py
Normal file
84
bayesian/normalizing_flows/flows.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Planar(nn.Module):
    """A single planar normalizing-flow transform.

    f(z) = z + u * h(w^T z + b), with h = tanh.

    Rezende & Mohamed, "Variational Inference with Normalizing Flows"
    https://arxiv.org/pdf/1505.05770.pdf
    """

    def __init__(self, size=1, init_sigma=0.01):
        """
        :param size: dimensionality of the latent variable z.
        :param init_sigma: std of the parameter init; a small value keeps
            the initial transform close to the identity.
        """
        super().__init__()
        self.u = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.w = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.b = nn.Parameter(torch.zeros(1))

    @property
    def normalized_u(self):
        """
        u adjusted so the flow is invertible (enforces w^T u >= -1).

        See Appendix A.1
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        # m(x) = -1 + softplus(x); nn.functional.softplus is numerically
        # stable, unlike log(1 + exp(x)) which overflows for large x.
        def m(x):
            return -1 + nn.functional.softplus(x)

        wtu = torch.matmul(self.w, self.u.t())
        # BUG FIX: the paper requires w / ||w||^2 (the variable name
        # already said so); dividing by ||w|| alone breaks the guarantee
        # that w^T u_hat = m(w^T u) >= -1.
        w_div_w2 = self.w / torch.norm(self.w) ** 2
        return self.u + (m(wtu) - wtu) * w_div_w2

    def psi(self, z):
        """
        ψ(z) =h′(w^tz+b)w

        See eq(11)
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        return self.h_prime(z @ self.w.t() + self.b) @ self.w

    def h(self, x):
        # Elementwise nonlinearity of the flow.
        return torch.tanh(x)

    def h_prime(self, z):
        # Derivative of tanh: 1 - tanh(z)^2.
        return 1 - torch.tanh(z) ** 2

    def forward(self, z):
        """Apply the flow.

        :param z: tensor of shape (batch, size), or a tuple
            (z, accumulated_ldj) so flows can be chained.
        :return: (f(z), log|det Jacobian| + accumulated_ldj)
        """
        if isinstance(z, tuple):
            z, accumulating_ldj = z
        else:
            z, accumulating_ldj = z, 0
        psi = self.psi(z)

        u = self.normalized_u

        # determinant of jacobian: 1 + psi(z) u^T  (eq. 12 in the paper)
        det = (1 + psi @ u.t())

        # log |det Jac|; epsilon guards against log(0) when det vanishes
        ldj = torch.log(torch.abs(det) + 1e-6)

        wzb = z @ self.w.t() + self.b

        fz = z + (u * self.h(wzb))

        return fz, ldj + accumulating_ldj
|
||||
|
||||
|
||||
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # A batch of samples from a uniform base distribution on [0, 1)^2.
    z0 = torch.rand((1000, 2))

    with torch.no_grad():
        flow = Planar(size=2)

        # Push the samples through the same (randomly initialized)
        # planar flow ten times and look at the warped point cloud.
        zk = z0
        for _ in range(10):
            zk, ldj = flow(zk)

        plt.scatter(zk[:, 0], zk[:, 1], alpha=0.2)
        plt.show()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user