import torch
from torch import nn


class Planar(nn.Module):
    """
    Single planar flow layer: f(z) = z + u * h(w^T z + b).

    Rezende et al., Variational Inference with Normalizing Flows
    https://arxiv.org/pdf/1505.05770.pdf
    """

    def __init__(self, size=1, init_sigma=0.01):
        super().__init__()
        self.u = nn.Parameter(torch.empty(1, size).normal_(0, init_sigma))
        self.w = nn.Parameter(torch.empty(1, size).normal_(0, init_sigma))
        self.b = nn.Parameter(torch.zeros(1))

    @property
    def normalized_u(self):
        """
        Reparameterize u so that w^T u >= -1, which is needed for the
        invertibility of the planar transform.

        See Appendix A.1,
        Rezende et al., Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        # m(x) = -1 + softplus(x) = -1 + log(1 + exp(x))
        def m(x):
            return -1 + nn.functional.softplus(x)

        wtu = torch.matmul(self.w, self.u.t())
        w_div_w2 = self.w / torch.norm(self.w) ** 2  # w / ||w||^2
        return self.u + (m(wtu) - wtu) * w_div_w2

    def psi(self, z):
        """
        psi(z) = h'(w^T z + b) w

        See eq. (11),
        Rezende et al., Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        return self.h_prime(z @ self.w.t() + self.b) @ self.w

    def h(self, x):
        return torch.tanh(x)

    def h_prime(self, z):
        return 1 - torch.tanh(z) ** 2

    def forward(self, z):
        # Accept either a plain batch of samples or a (z, log_det_jacobian)
        # tuple, so that layers can be chained and the log-determinants
        # accumulate across the flow.
        if isinstance(z, tuple):
            z, accumulating_ldj = z
        else:
            z, accumulating_ldj = z, 0

        psi = self.psi(z)
        u = self.normalized_u

        # determinant of the Jacobian: det J = 1 + u^T psi(z), eq. (12)
        det = 1 + psi @ u.t()
        # log |det J|; the small epsilon keeps the log finite when det -> 0
        ldj = torch.log(torch.abs(det) + 1e-6)

        wzb = z @ self.w.t() + self.b
        fz = z + u * self.h(wzb)
        return fz, ldj + accumulating_ldj


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    z0 = torch.rand((1000, 2))

    with torch.no_grad():
        pf = Planar(size=2)
        zk = z0
        for i in range(10):
            zk, ldj = pf.forward(zk)
        plt.scatter(zk[:, 0], zk[:, 1], alpha=0.2)
        plt.show()
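

# A minimal composition sketch, not part of the original file: it assumes one
# wants to stack several independent Planar layers. Because forward() returns
# and also accepts a (z, log_det_jacobian) tuple, nn.Sequential can thread the
# tuple through the layers and the per-layer log|det J| terms sum up
# (cf. eq. (13) of Rezende et al.). The helper name stacked_planar_flow is
# hypothetical.
def stacked_planar_flow(n_layers=10, size=2):
    return nn.Sequential(*[Planar(size=size) for _ in range(n_layers)])


# Example usage (also an assumption): transform a batch of samples and read
# off the accumulated log-determinant, which would enter a flow-based
# density or ELBO term.
#
#     flow = stacked_planar_flow(n_layers=10, size=2)
#     zk, sum_ldj = flow((torch.rand(1000, 2), torch.zeros(1000, 1)))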