# PGGAN building blocks (Karras et al., "Progressive Growing of GANs", 2018).
# Converted from notebook cells; only the imports this section needs are kept.
import torch
import torch.nn as nn
import torch.nn.functional as F


class PixelNormLayer(nn.Module):
    """Pixelwise feature-vector normalization.

    Rescales every spatial position's channel vector to (approximately) unit
    RMS: x / sqrt(mean(x**2, over channels) + eps).
    """

    def __init__(self, eps=1e-8):
        super(PixelNormLayer, self).__init__()
        self.eps = eps

    def forward(self, x):
        # BUG FIX: the original hard-coded 1e-8 here, silently ignoring the
        # eps value passed to __init__.
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(eps = %s)' % (self.eps)


class GInput(nn.Module):
    """Generator stem: latent (N, 512, 1, 1) -> feature map (N, 512, 4, 4)."""

    def __init__(self, alpha=0.2):
        super(GInput, self).__init__()
        self.layer = nn.Sequential(
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=1, padding=0),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.ConvTranspose2d(512, 512, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
        )

    def forward(self, x):
        return self.layer(x)


class UpsampleG(nn.Module):
    """Generator growth block: doubles the spatial resolution."""

    def __init__(self, alpha=0.2, ch_in=512, ch_out=512):
        super(UpsampleG, self).__init__()
        self.layer = nn.Sequential(
            nn.ConvTranspose2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            # BUG FIX: this conv consumes the previous conv's output, so its
            # in-channels must be ch_out (the original passed ch_in, which
            # breaks as soon as ch_in != ch_out).
            nn.ConvTranspose2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
        )

    def forward(self, x):
        return self.layer(x)


class Generator(nn.Module):
    """Progressively grown generator.

    Starts at 4x4 resolution; each upsample() call appends a growth block
    (doubling resolution) and rebuilds the to-RGB head. During fade-in,
    forward(x, alpha) blends the upsampled previous-stage features with the
    new stage's output before converting to RGB.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.layer_in = GInput()
        self.layers = [self.layer_in]
        self.net = nn.Sequential(*self.layers)
        self.rgb = None
        self.det_to_rgb()

    def det_to_rgb(self):
        # Rebuild the to-RGB head so it matches the current top block's width.
        c = self.net[-1].layer[3].out_channels
        self.rgb = nn.Sequential(
            nn.Conv2d(c, 3, kernel_size=1, stride=1, padding=0),
            nn.Tanh()
        )

    def upsample(self, ch_in=512, ch_out=512):
        """Grow the network by one resolution step (4x4 -> 8x8 -> ...)."""
        self.layers.append(UpsampleG(ch_in=ch_in, ch_out=ch_out))
        self.net = nn.Sequential(*self.layers)
        self.det_to_rgb()
        # Keep the freshly (re)built modules on the same device as the rest.
        if next(self.parameters()).is_cuda:
            self.to_cuda()
        else:
            self.to_cpu()

    def to_cuda(self):
        self.cuda()
        self.net.cuda()
        self.rgb.cuda()

    def to_cpu(self):
        self.cpu()
        self.net.cpu()
        self.rgb.cpu()

    def forward(self, x, alpha=None):
        if alpha is not None and alpha < 0.999:
            # Fade-in: blend nearest-neighbour upsampling of the previous
            # stage's features with the newly added block's output.
            x = self.net[:-1](x)
            # F.upsample is deprecated in favour of F.interpolate.
            x_left = F.interpolate(x, scale_factor=2, mode='nearest')
            x_right = self.net[-1](x)
            # Python-scalar arithmetic broadcasts on any device, so the
            # original's manual alpha/one .cuda() shuffling is unnecessary.
            x = (1.0 - alpha) * x_left + alpha * x_right
        else:
            x = self.net(x)
        # BUG FIX: the original printed x.shape and returned None.
        return self.rgb(x)


class UpsampleD(nn.Module):
    """Discriminator block: two convs followed by 2x average-pool downsampling."""

    def __init__(self, alpha=0.2, ch_in=512, ch_out=512):
        super(UpsampleD, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            # BUG FIX: in-channels must be ch_out here (see UpsampleG).
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.AvgPool2d(2)
        )

    def forward(self, x):
        return self.layer(x)


class DOutput(nn.Module):
    """Discriminator head: 4x4 features (+1 stddev channel) -> per-sample logit."""

    def __init__(self, alpha=0.2):
        super(DOutput, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(513, 512, kernel_size=3, stride=1, padding=1),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
            nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
            PixelNormLayer(),
            nn.LeakyReLU(alpha),
        )
        self.out = nn.Linear(512, 1)

    def forward(self, x):
        # Simplified minibatch-stddev feature: append one constant channel
        # holding std(x). ones_like keeps device/dtype and — BUG FIX — follows
        # x's batch/spatial shape, where the original hard-coded (1, 1, 4, 4)
        # and therefore only worked for batch size 1 on 4x4 input.
        std = x.std()
        ones = torch.ones_like(x[:, :1])
        x = torch.cat((x, ones * std), dim=1)
        x = self.layer(x)
        # BUG FIX: flatten per sample instead of view(512) so batches work;
        # squeeze keeps the original (N,) output shape for N == 1.
        return self.out(x.view(x.size(0), 512)).squeeze(-1)


class Discriminator(nn.Module):
    """Progressively grown discriminator (mirror image of Generator)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer_out = DOutput()
        self.layers = [self.layer_out]
        self.net = nn.Sequential(*self.layers)
        self.rgb = None
        self.rgb_previous_size = None
        self.det_from_rgb()

    def det_from_rgb(self):
        # Rebuild the from-RGB stem; keep the previous one around for fade-in.
        c = self.net[0].layer[0].in_channels
        c = 512 if c == 513 else c  # DOutput's 513 includes the stddev channel
        self.rgb_previous_size = self.rgb
        self.rgb = nn.Conv2d(3, c, kernel_size=1, stride=1, padding=0, bias=False)

    def upsample(self, ch_in=512, ch_out=512):
        """Grow the network by one resolution step (prepends a block)."""
        self.layers.insert(0, UpsampleD(ch_in=ch_in, ch_out=ch_out))
        self.net = nn.Sequential(*self.layers)
        self.det_from_rgb()
        if next(self.parameters()).is_cuda:
            self.to_cuda()
        else:
            self.to_cpu()

    def to_cuda(self):
        self.cuda()
        self.net.cuda()
        self.rgb.cuda()

    def to_cpu(self):
        self.cpu()
        self.net.cpu()
        self.rgb.cpu()

    def forward(self, x, alpha=None):
        if alpha is not None and alpha < 0.999:
            # Fade-in: downscale + previous from-RGB on the left, new top
            # block on the right, then blend and run the remaining stack.
            x_left = F.avg_pool2d(x, 2)
            x_left = self.rgb_previous_size(x_left)
            x_right = self.rgb(x)
            x_right = self.net[0](x_right)
            x = (1.0 - alpha) * x_left + alpha * x_right
            x = self.net[1:](x)
        else:
            x = self.rgb(x)
            x = self.net(x)

        if self.training:
            return x  # raw logits for binary_cross_entropy_with_logits
        return torch.sigmoid(x)  # F.sigmoid is deprecated


if __name__ == "__main__":
    # Smoke tests (were loose notebook-cell statements running at import time).
    g = Generator()
    g.upsample()
    print(g(torch.ones(1, 512, 1, 1), 1).shape)
    d = Discriminator()
    d.upsample()
    d.upsample()
    print(d(torch.rand(1, 3, 16, 16), 1))
+ "a.upsample()\n", + "a(torch.rand(1, 3, 16, 16), 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def det_loss(gen, input_z, dis, input_x):\n", + " y_real = dis(input_x)\n", + " generated = gen(input_z)\n", + " y_fake = dis(generated)\n", + " \n", + " ones = torch.ones_like(y_real)\n", + " zeros = torch.zeros_like(y_fake)\n", + " if next(gen.parameters()).is_cuda:\n", + " ones = ones.cuda()\n", + " zeros = zeros.cuda()\n", + " \n", + " loss_real = F.binary_cross_entropy_with_logits(y_real, ones)\n", + " loss_fake = F.binary_cross_entropy_with_logits(y_fake, zeros)\n", + " \n", + " loss_dis = loss_real + loss_fake\n", + " loss_gen = F.binary_cross_entropy_with_logits(y_fake, ones)\n", + " \n", + " return loss_dis, loss_gen, generated\n", + " \n", + "det_loss(gen, torch.ones((1, 100)), dis, torch.ones(1, 3, 64, 64))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def show_img(x, real=True):\n", + " plt.figure(figsize=(6, 6))\n", + " if isinstance(x, torch.Tensor):\n", + " if next(gen.parameters()).is_cuda:\n", + " x = x.cpu()\n", + " x = x.data.numpy()\n", + " \n", + " x = np.transpose(np.squeeze(x), [1, 2, 0]) \n", + " if not real:\n", + " x = np.array((x + 1) / 2 * 255, int)\n", + "\n", + " plt.imshow(x)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "class ImageFolderEX(datasets.ImageFolder):\n", + " def __getitem__(self, index):\n", + " def get_img(index):\n", + " path, label = self.imgs[index]\n", + " try:\n", + " img = self.loader(os.path.join(self.root, path))\n", + " except:\n", + " img = get_img(index + 1)\n", + " return img\n", + " img = get_img(index)\n", + " return self.transform(img)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "trans = transforms.Compose([\n", + " 
# --- Data pipeline and training setup (converted from notebook cells) ---
import os

import torch
import torch.utils.data
from torchvision import transforms
from tensorboardX import SummaryWriter

trans = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomResizedCrop(64, scale=(0.4, 0.8), ratio=(1, 1)),
    # transforms.Resize((64, 64), interpolation=2),
    transforms.ToTensor(),
])
# NOTE(review): ToTensor yields images in [0, 1] while the generator's Tanh
# emits [-1, 1]; consider a Normalize step so real and fake ranges match —
# confirm against the discriminator's training behaviour.

# Quick sanity check that the loader produces a batch.
data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=1,
                                   shuffle=True, drop_last=True, num_workers=0)
x = next(iter(data))
# show_img(x[0], False)


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize a checkpoint dict (epoch/model/optimizer state) to disk."""
    torch.save(state, filename)


# BUG FIX: the original called .cuda() unconditionally and crashed on
# CPU-only machines; move to GPU only when one is available.
use_cuda = torch.cuda.is_available()
dis = Discriminator()
gen = Generator()
if use_cuda:
    dis.to_cuda()
    gen.to_cuda()
writer = SummaryWriter(log_dir='tb/7')

# Resume from checkpoints when present (the original crashed if the files
# were absent and left total_step undefined).
# Checkpoint keys: 'epoch', 'state_dict', 'optimizer', 'total_step'.
total_step = 0
if os.path.exists('gen.pth'):
    state = torch.load('gen.pth')
    gen.load_state_dict(state['state_dict'])
if os.path.exists('dis.pth'):
    state = torch.load('dis.pth')
    dis.load_state_dict(state['state_dict'])
    total_step = state['total_step']

lr = 0.0002
beta_1 = 0.5  # low beta_1 is the usual choice for GAN Adam optimizers
optimizer_gen = torch.optim.Adam(gen.parameters(), lr, betas=(beta_1, 0.999))
optimizer_dis = torch.optim.Adam(dis.parameters(), lr, betas=(beta_1, 0.999))
# --- Training loop (converted from notebook cells) ---
epochs = 20
batch_size = 256

data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=batch_size,
                                   shuffle=True, drop_last=True, num_workers=0)

for epoch in range(epochs):

    c = 0
    for x in iter(data):
        c += 1
        dis.zero_grad()

        # BUG FIX: the generator stem (GInput) expects a (N, 512, 1, 1)
        # latent; the original sampled (N, 100), which ConvTranspose2d(512,
        # ...) cannot consume.
        z = torch.tensor(np.random.uniform(-1, 1, (batch_size, 512, 1, 1)),
                         dtype=torch.float32)

        if next(gen.parameters()).is_cuda:
            x = x.cuda()
            z = z.cuda()

        # --- discriminator step ---
        y_real = dis(x)
        generated = gen(z)
        # detach: no need to backprop into the generator during the D step
        # (the original did, wasting a backward pass through G).
        y_fake = dis(generated.detach())

        # *_like already matches device/dtype; no manual .cuda() needed.
        ones = torch.ones_like(y_real)
        zeros = torch.zeros_like(y_fake)

        loss_real = F.binary_cross_entropy_with_logits(y_real, ones)
        loss_fake = F.binary_cross_entropy_with_logits(y_fake, zeros)
        loss_dis = loss_real + loss_fake

        loss_dis.backward()
        optimizer_dis.step()

        # --- generator step ---
        gen.zero_grad()
        generated = gen(z)
        y_fake = dis(generated)
        loss_gen = F.binary_cross_entropy_with_logits(y_fake, ones)
        loss_gen.backward()
        optimizer_gen.step()

        global_step = total_step + epoch * len(data) + c

        # Log every step (the original guarded this with the always-true
        # `c % 1 == 0`).
        writer.add_scalar('loss_dis', loss_dis.item(), global_step)
        writer.add_scalar('loss_gen', loss_gen.item(), global_step)

        if c % 4 == 0:
            print(loss_dis.item(), loss_gen.item(), 'step', global_step)
            # Map the Tanh output from [-1, 1] to [0, 1] for TensorBoard.
            writer.add_image('img', (generated[0] + 1) / 2, global_step)
    print('finished epoch', epoch)


save_checkpoint({
    'epoch': epoch + 1,
    'state_dict': dis.state_dict(),
    'optimizer': optimizer_dis.state_dict(),
    'total_step': global_step,
}, 'dis.pth')

save_checkpoint({
    'epoch': epoch + 1,
    'state_dict': gen.state_dict(),
    'optimizer': optimizer_gen.state_dict(),
    'total_step': global_step,
}, 'gen.pth')