init pggan

This commit is contained in:
ritchie46
2018-07-06 15:37:21 +02:00
parent 4802ed19b8
commit 2496661d02

View File

@@ -0,0 +1,547 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.utils.data\n",
"import torch.nn.functional as F\n",
"import os\n",
"from torchvision import datasets, transforms\n",
"from tensorboardX import SummaryWriter\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 279,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 3, 8, 8])\n"
]
}
],
"source": [
"class PixelNormLayer(nn.Module):\n",
" \"\"\"\n",
" Pixelwise feature vector normalization.\n",
" \"\"\"\n",
" def __init__(self, eps=1e-8):\n",
" super(PixelNormLayer, self).__init__()\n",
" self.eps = eps\n",
" \n",
" def forward(self, x):\n",
" return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)\n",
"\n",
" def __repr__(self):\n",
" return self.__class__.__name__ + '(eps = %s)' % (self.eps)\n",
" \n",
"\n",
" \n",
"class GInput(nn.Module):\n",
" def __init__(self, alpha=0.2):\n",
" super(GInput, self).__init__()\n",
" self.layer = nn.Sequential(\n",
" nn.ConvTranspose2d(512, 512, kernel_size=4, stride=1, padding=0),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(512, 512, kernel_size=3, stride=1, padding=1),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" )\n",
" \n",
" def forward(self, x):\n",
" return self.layer(x)\n",
" \n",
" \n",
"class UpsampleG(nn.Module):\n",
" def __init__(self, alpha=0.2, ch_in=512, ch_out=512):\n",
" super(UpsampleG, self).__init__()\n",
" self.layer = nn.Sequential(\n",
" nn.ConvTranspose2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" )\n",
" \n",
" def forward(self, x):\n",
" return self.layer(x)\n",
" \n",
" \n",
"class Generator(nn.Module):\n",
" def __init__(self):\n",
" super(Generator, self).__init__()\n",
" self.layer_in = GInput()\n",
" self.layers = [self.layer_in]\n",
" self.net = nn.Sequential(*self.layers)\n",
" self.rgb = None\n",
" self.det_to_rgb()\n",
" \n",
" def det_to_rgb(self):\n",
" c = self.net[-1].layer[3].out_channels\n",
" self.rgb = nn.Sequential(\n",
" nn.Conv2d(c, 3, kernel_size=1, stride=1, padding=0),\n",
" nn.Tanh()\n",
" )\n",
" \n",
" def upsample(self, ch_in=512, ch_out=512):\n",
" self.layers.append(UpsampleG(ch_in=ch_in, ch_out=ch_out))\n",
" self.net = nn.Sequential(*self.layers)\n",
" self.det_to_rgb()\n",
" \n",
" if next(self.parameters()).is_cuda:\n",
" self.to_cuda()\n",
" else:\n",
" self.to_cpu()\n",
" \n",
" def to_cuda(self):\n",
" self.cuda()\n",
" self.net.cuda()\n",
" self.rgb.cuda()\n",
" \n",
" def to_cpu(self):\n",
" self.cpu()\n",
" self.net.cpu()\n",
" self.rgb.cpu()\n",
" \n",
" def forward(self, x, alpha=None):\n",
" if alpha is not None and alpha < 0.999:\n",
" x = self.net[:-1](x)\n",
" x_left = F.upsample(x, scale_factor=2, mode='nearest')\n",
" x_right = self.net[-1](x)\n",
" \n",
" alpha = torch.tensor(alpha)\n",
" one = torch.ones(1)\n",
" if x.is_cuda:\n",
" alpha = alpha.cuda()\n",
" one = one.cuda()\n",
" \n",
" x = (one - alpha) * x_left + alpha * x_right\n",
" \n",
" else:\n",
" x = self.net(x) \n",
" x = self.rgb(x)\n",
" print(x.shape)\n",
" \n",
"a = Generator()\n",
"a.upsample()\n",
"\n",
"a(torch.ones(1, 512, 1, 1), 1)"
]
},
{
"cell_type": "code",
"execution_count": 328,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0.3352])"
]
},
"execution_count": 328,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class UpsampleD(nn.Module):\n",
" def __init__(self, alpha=0.2, ch_in=512, ch_out=512):\n",
" super(UpsampleD, self).__init__()\n",
" self.layer = nn.Sequential(\n",
" nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" nn.AvgPool2d(2)\n",
" )\n",
" \n",
" def forward(self, x):\n",
" return self.layer(x)\n",
"\n",
"class DOutput(nn.Module):\n",
" def __init__(self, alpha=0.2):\n",
" super(DOutput, self).__init__()\n",
" self.layer = nn.Sequential(\n",
" nn.Conv2d(513, 512, kernel_size=3, stride=1, padding=1),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=0),\n",
" PixelNormLayer(),\n",
" nn.LeakyReLU(alpha),\n",
" )\n",
" self.out = nn.Linear(512, 1)\n",
" \n",
" def forward(self, x):\n",
" std = x.std()\n",
" ones = torch.ones((1, 1, 4, 4))\n",
" \n",
" if x.is_cuda:\n",
" ones = ones.cuda()\n",
" \n",
" x = torch.cat((x, (ones * std).float()), dim=1)\n",
" \n",
" x = self.layer(x)\n",
" return self.out(x.view(512))\n",
"\n",
"class Discriminator(nn.Module):\n",
" def __init__(self):\n",
" super(Discriminator, self).__init__()\n",
" self.layer_out = DOutput()\n",
" self.layers = [self.layer_out]\n",
" self.net = nn.Sequential(*self.layers)\n",
" self.rgb = None\n",
" self.rgb_previous_size = None\n",
" self.det_from_rgb()\n",
" \n",
" def det_from_rgb(self):\n",
" c = self.net[0].layer[0].in_channels\n",
" c = 512 if c == 513 else c\n",
" \n",
" self.rgb_previous_size = self.rgb\n",
" self.rgb = nn.Conv2d(3, c, kernel_size=1, stride=1, padding=0, bias=False)\n",
" \n",
" def upsample(self, ch_in=512, ch_out=512):\n",
" self.layers.insert(0, UpsampleD(ch_in=ch_in, ch_out=ch_out))\n",
" self.net = nn.Sequential(*self.layers)\n",
" self.det_from_rgb()\n",
" \n",
" if next(self.parameters()).is_cuda:\n",
" self.to_cuda()\n",
" else:\n",
" self.to_cpu()\n",
" \n",
" def to_cuda(self):\n",
" self.cuda()\n",
" self.net.cuda()\n",
" self.rgb.cuda()\n",
" \n",
" def to_cpu(self):\n",
" self.cpu()\n",
" self.net.cpu()\n",
" self.rgb.cpu()\n",
" \n",
" def forward(self, x, alpha=None):\n",
" if alpha is not None and alpha < 0.999:\n",
" x_left = F.avg_pool2d(x, 2)\n",
" x_left = self.rgb_previous_size(x_left)\n",
" x_right = self.rgb(x)\n",
" x_right = self.net[0](x_right)\n",
" \n",
" alpha = torch.tensor(alpha)\n",
" one = torch.ones(1)\n",
" if x.is_cuda:\n",
" alpha = alpha.cuda()\n",
" one = one.cuda()\n",
" \n",
" x = (one - alpha) * x_left + alpha * x_right\n",
" x = self.net[1:](x)\n",
" else:\n",
" x = self.rgb(x)\n",
" x = self.net(x)\n",
" \n",
" if self.training:\n",
" return x\n",
" \n",
" return F.sigmoid(x)\n",
" \n",
"a = Discriminator()\n",
"a.upsample()\n",
"a.upsample()\n",
"a(torch.rand(1, 3, 16, 16), 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def det_loss(gen, input_z, dis, input_x):\n",
" y_real = dis(input_x)\n",
" generated = gen(input_z)\n",
" y_fake = dis(generated)\n",
" \n",
" ones = torch.ones_like(y_real)\n",
" zeros = torch.zeros_like(y_fake)\n",
" if next(gen.parameters()).is_cuda:\n",
" ones = ones.cuda()\n",
" zeros = zeros.cuda()\n",
" \n",
" loss_real = F.binary_cross_entropy_with_logits(y_real, ones)\n",
" loss_fake = F.binary_cross_entropy_with_logits(y_fake, zeros)\n",
" \n",
" loss_dis = loss_real + loss_fake\n",
" loss_gen = F.binary_cross_entropy_with_logits(y_fake, ones)\n",
" \n",
" return loss_dis, loss_gen, generated\n",
" \n",
"det_loss(gen, torch.ones((1, 100)), dis, torch.ones(1, 3, 64, 64))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def show_img(x, real=True):\n",
" plt.figure(figsize=(6, 6))\n",
" if isinstance(x, torch.Tensor):\n",
" if next(gen.parameters()).is_cuda:\n",
" x = x.cpu()\n",
" x = x.data.numpy()\n",
" \n",
" x = np.transpose(np.squeeze(x), [1, 2, 0]) \n",
" if not real:\n",
" x = np.array((x + 1) / 2 * 255, int)\n",
"\n",
" plt.imshow(x)\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class ImageFolderEX(datasets.ImageFolder):\n",
" def __getitem__(self, index):\n",
" def get_img(index):\n",
" path, label = self.imgs[index]\n",
" try:\n",
" img = self.loader(os.path.join(self.root, path))\n",
" except:\n",
" img = get_img(index + 1)\n",
" return img\n",
" img = get_img(index)\n",
" return self.transform(img)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"trans = transforms.Compose([\n",
" transforms.RandomHorizontalFlip(0.5),\n",
" transforms.RandomResizedCrop(64, scale=(0.4, 0.8), ratio=(1, 1)),\n",
"# transforms.Resize((64, 64), interpolation=2),\n",
" transforms.ToTensor(),\n",
"])\n",
"\n",
"data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=1, shuffle=True, drop_last=True, num_workers=0)\n",
"x = next(iter(data))\n",
"# show_img(x[0], False)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"def save_checkpoint(state, filename='checkpoint.pth.tar'):\n",
" torch.save(state, filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"dis = Discriminator().cuda()\n",
"gen = Generator().cuda()\n",
"writer = SummaryWriter(log_dir='tb/7')\n",
"\n",
"state = torch.load('gen.pth')\n",
"gen.load_state_dict(state['state_dict'])\n",
"state = torch.load('dis.pth')\n",
"dis.load_state_dict(state['state_dict'])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"lr = 0.0002\n",
"beta_1 = 0.5\n",
"optimizer_gen = torch.optim.Adam(gen.parameters(), lr, betas=(beta_1, 0.999))\n",
"optimizer_dis = torch.optim.Adam(dis.parameters(), lr, betas=(beta_1, 0.999))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['epoch', 'state_dict', 'optimizer', 'total_step'])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state.keys()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"total_step = state['total_step']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"epochs = 20\n",
"batch_size = 256\n",
"\n",
"data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)\n",
"\n",
"for epoch in range(epochs):\n",
" \n",
" c = 0\n",
" for x in iter(data):\n",
" c += 1\n",
" dis.zero_grad()\n",
" \n",
" z = torch.tensor(np.random.uniform(-1, 1, (batch_size, 100)), dtype=torch.float32)\n",
" \n",
" if next(gen.parameters()).is_cuda:\n",
" x = x.cuda()\n",
" z = z.cuda()\n",
" \n",
" y_real = dis(x)\n",
" generated = gen(z)\n",
" y_fake = dis(generated)\n",
" \n",
" ones = torch.ones_like(y_real)\n",
" zeros = torch.zeros_like(y_fake)\n",
" if next(gen.parameters()).is_cuda:\n",
" ones = ones.cuda()\n",
" zeros = zeros.cuda()\n",
" \n",
" loss_real = F.binary_cross_entropy_with_logits(y_real, ones)\n",
" loss_fake = F.binary_cross_entropy_with_logits(y_fake, zeros)\n",
" loss_dis = loss_real + loss_fake\n",
" \n",
" loss_dis.backward()\n",
" optimizer_dis.step()\n",
" \n",
" gen.zero_grad()\n",
" generated = gen(z)\n",
" y_fake = dis(generated)\n",
" loss_gen = F.binary_cross_entropy_with_logits(y_fake, ones)\n",
" loss_gen.backward()\n",
" optimizer_gen.step()\n",
" \n",
" global_step = total_step + epoch * len(data) + c\n",
" \n",
" if c % 1 == 0:\n",
" writer.add_scalar('loss_dis', loss_dis.item(), global_step)\n",
" writer.add_scalar('loss_gen', loss_gen.item(), global_step)\n",
" \n",
" if c % 4 == 0:\n",
" print(loss_dis.item(), loss_gen.item(), 'step', global_step)\n",
" writer.add_image('img', generated[0], global_step)\n",
" print('finished epoch', epoch)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"save_checkpoint({\n",
" 'epoch': epoch + 1,\n",
" 'state_dict': dis.state_dict(),\n",
" 'optimizer' : optimizer_dis.state_dict(),\n",
" 'total_step': global_step,\n",
" }, 'dis.pth')\n",
"\n",
"save_checkpoint({\n",
" 'epoch': epoch + 1,\n",
" 'state_dict': gen.state_dict(),\n",
" 'optimizer' : optimizer_gen.state_dict(),\n",
" 'total_step': global_step,\n",
" }, 'gen.pth')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}