This commit is contained in:
ritchie46 2018-07-17 12:03:41 +02:00
parent 2496661d02
commit b7a14df1c2

View File

@ -0,0 +1,448 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.utils.data\n",
"import torch.nn.functional as F\n",
"import os\n",
"from torchvision import datasets, transforms\n",
"from tensorboardX import SummaryWriter\n",
"from PIL import Image\n",
"import logging"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 3, 256, 256])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class Generator(nn.Module):\n",
" def __init__(self, input_size=200, alpha=0.2):\n",
" super(Generator, self).__init__() \n",
" kernel_size = 4\n",
" padding = 1\n",
" stride = 2\n",
" \n",
" self.input = nn.Linear(input_size, 4 * 4 * 1024)\n",
" self.net = nn.Sequential(\n",
" nn.BatchNorm2d(1024),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(1024, 512, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(512, 512, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(512, 512, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(512, 256, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(256),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(256, 128, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(128),\n",
" nn.LeakyReLU(alpha),\n",
" nn.ConvTranspose2d(128, 3, kernel_size, stride, padding),\n",
" nn.Tanh()\n",
" )\n",
" \n",
" def forward(self, z):\n",
" x = self.input(z)\n",
" return self.net(x.view(-1, 1024, 4, 4))\n",
" \n",
"gen = Generator()\n",
"gen(torch.ones((1, 200))).shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.3871]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class Discriminator(nn.Module):\n",
" def __init__(self, alpha=0.2):\n",
" super(Discriminator, self).__init__()\n",
"\n",
" kernel_size = 4\n",
" padding = 1\n",
" stride = 2\n",
" \n",
" self.net = nn.Sequential(\n",
" nn.Conv2d(3, 128, kernel_size, stride, padding),\n",
" nn.LeakyReLU(alpha),\n",
" nn.Conv2d(128, 256, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(256),\n",
" nn.LeakyReLU(alpha),\n",
" nn.Conv2d(256, 512, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(alpha),\n",
" nn.Conv2d(512, 512, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(alpha),\n",
" nn.Conv2d(512, 512, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(512),\n",
" nn.LeakyReLU(alpha),\n",
" nn.Conv2d(512, 1024, kernel_size, stride, padding),\n",
" nn.BatchNorm2d(1024),\n",
" nn.LeakyReLU(alpha),\n",
" )\n",
" self.output = nn.Linear(4 * 4 * 1024, 1)\n",
" \n",
" \n",
" def forward(self, x):\n",
" x = self.net(x)\n",
" x = torch.reshape(x, (-1, 4 * 4 * 1024))\n",
" x = self.output(x)\n",
" \n",
" if self.training:\n",
" return x\n",
" \n",
" return F.sigmoid(x)\n",
" \n",
"dis = Discriminator()\n",
"dis(torch.ones(1, 3, 256, 256))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class ImageFolderEX(datasets.ImageFolder):\n",
" def __getitem__(self, index):\n",
" def get_img(index):\n",
" path, label = self.imgs[index]\n",
" try:\n",
" img = self.loader(os.path.join(self.root, path))\n",
" except:\n",
" img = get_img(index + 1)\n",
" return img\n",
" img = get_img(index)\n",
" return self.transform(img) * 2 - 1"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Found 0 files in subfolders of: .\nSupported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-8-dc3c12e3e15a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m ])\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mImageFolderEX\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdrop_last\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/miniconda3/lib/python3.6/site-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, transform, target_transform, loader)\u001b[0m\n\u001b[1;32m 176\u001b[0m super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,\n\u001b[1;32m 177\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 178\u001b[0;31m target_transform=target_transform)\n\u001b[0m\u001b[1;32m 179\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/miniconda3/lib/python3.6/site-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, loader, extensions, transform, target_transform)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m raise(RuntimeError(\"Found 0 files in subfolders of: \" + root + \"\\n\"\n\u001b[0;32m---> 79\u001b[0;31m \"Supported extensions are: \" + \",\".join(extensions)))\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Found 0 files in subfolders of: .\nSupported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif"
]
}
],
"source": [
"trans = transforms.Compose([\n",
" transforms.Resize((256, 256), interpolation=2),\n",
" transforms.ToTensor(),\n",
"])\n",
"\n",
"data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=1, shuffle=True, drop_last=True, num_workers=0)\n",
"x = next(iter(data))\n",
"print(x.min(), x.max())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def train_dis(dis, gen, x):\n",
" z = torch.tensor(np.random.normal(0, 1, (batch_size, 200)), dtype=torch.float32)\n",
"\n",
" if next(gen.parameters()).is_cuda:\n",
" x = x.cuda()\n",
" z = z.cuda()\n",
"\n",
" # train discriminator\n",
" dis.zero_grad()\n",
" y_real_pred = dis(x)\n",
" \n",
" idx = np.random.uniform(0, 1, y_real_pred.shape)\n",
" idx = np.argwhere(idx < 0.03)\n",
" \n",
" # labels\n",
" ones = np.ones(y_real_pred.shape) + np.random.uniform(-0.1, 0.1)\n",
" ones[idx] = 0\n",
" \n",
" zeros = np.zeros(y_real_pred.shape) + np.random.uniform(0, 0.2)\n",
" zeros[idx] = 1\n",
" ones = torch.from_numpy(ones).float()\n",
" zeros = torch.from_numpy(zeros).float()\n",
"\n",
" if next(gen.parameters()).is_cuda:\n",
" ones = ones.cuda()\n",
" zeros = zeros.cuda()\n",
"\n",
" loss_real = F.binary_cross_entropy_with_logits(y_real_pred, ones)\n",
"\n",
" generated = gen(z)\n",
" y_fake_pred = dis(generated)\n",
"\n",
" loss_fake = F.binary_cross_entropy_with_logits(y_fake_pred, zeros)\n",
" loss = loss_fake + loss_real\n",
" loss.backward()\n",
" optimizer_dis.step()\n",
" return loss\n",
"\n",
" \n",
"def train_gen(gen, batch_size):\n",
" z = torch.tensor(np.random.normal(0, 1, (batch_size, 200)), dtype=torch.float32)\n",
" \n",
" if next(gen.parameters()).is_cuda:\n",
" z = z.cuda()\n",
" \n",
" gen.zero_grad()\n",
" generated = gen(z)\n",
" y_fake = dis(generated)\n",
"\n",
" ones = torch.ones_like(y_fake)\n",
" if next(gen.parameters()).is_cuda:\n",
" ones = ones.cuda()\n",
"\n",
" loss = F.binary_cross_entropy_with_logits(y_fake, ones)\n",
" loss.backward()\n",
" optimizer_gen.step()\n",
" return loss, generated"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def show_img(x, real=True, title=None, figsize=(6, 6)):\n",
" plt.figure(figsize=figsize)\n",
" if title:\n",
" plt.title(title)\n",
" if isinstance(x, torch.Tensor):\n",
" if next(gen.parameters()).is_cuda:\n",
" x = x.cpu()\n",
" x = x.data.numpy()\n",
" \n",
" x = np.transpose(np.squeeze(x), [1, 2, 0]) \n",
" if not real:\n",
" x = np.array((x + 1) / 2 * 255, int)\n",
" plt.imshow(x)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"logger = logging.getLogger()\n",
"logger.addHandler(logging.FileHandler(\"256.log\"))\n",
"logger.setLevel(logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def save_checkpoint(state, filename='checkpoint.pth.tar'):\n",
" torch.save(state, filename)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"dis = Discriminator().cuda()\n",
"gen = Generator().cuda()\n",
"gen = Generator()\n",
"\n",
"lr = 0.0002\n",
"beta_1 = 0.5\n",
"beta_2 = 0.999\n",
"optimizer_gen = torch.optim.Adam(gen.parameters(), lr, betas=(beta_1, beta_2))\n",
"optimizer_dis = torch.optim.Adam(dis.parameters(), lr, betas=(beta_1, beta_2))\n",
"\n",
"epochs = 20\n",
"batch_size = 64\n",
"data = torch.utils.data.DataLoader(ImageFolderEX('.', trans), batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)\n",
"\n",
"\n",
"from_checkpoint = True\n",
"if from_checkpoint:\n",
" state = torch.load('gen5.pth')\n",
" gen.load_state_dict(state['state_dict'])\n",
" state = torch.load('backup/dis1.pth')\n",
" dis.load_state_dict(state['state_dict'])\n",
" total_step = state['total_step']\n",
"else:\n",
" total_step = 0\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def save():\n",
" save_checkpoint({\n",
" 'epoch': epoch,\n",
" 'state_dict': dis.state_dict(),\n",
" 'optimizer' : optimizer_dis.state_dict(),\n",
" 'total_step': global_step,\n",
" }, f'backup/dis{epoch}.pth')\n",
"\n",
" save_checkpoint({\n",
" 'epoch': epoch,\n",
" 'state_dict': gen.state_dict(),\n",
" 'optimizer' : optimizer_gen.state_dict(),\n",
" 'total_step': global_step,\n",
" }, f'backup/gen{epoch}.pth')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ritchie46/miniconda3/lib/python3.6/site-packages/PIL/Image.py:2514: DecompressionBombWarning: Image size (95799284 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
" DecompressionBombWarning)\n",
"/home/ritchie46/miniconda3/lib/python3.6/site-packages/PIL/Image.py:2514: DecompressionBombWarning: Image size (145486286 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
" DecompressionBombWarning)\n",
"/home/ritchie46/miniconda3/lib/python3.6/site-packages/PIL/Image.py:2514: DecompressionBombWarning: Image size (94435468 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
" DecompressionBombWarning)\n",
"/home/ritchie46/miniconda3/lib/python3.6/site-packages/PIL/Image.py:2514: DecompressionBombWarning: Image size (107327830 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
" DecompressionBombWarning)\n"
]
}
],
"source": [
"n = len(data)\n",
"for epoch in range(0, epochs):\n",
" c = 0\n",
" n = len(data) \n",
" writer = SummaryWriter(log_dir=f'tb/epoch_{epoch}')\n",
"\n",
" for x in iter(data): \n",
" c += 1\n",
"\n",
" loss_dis = train_dis(dis, gen, x)\n",
" loss_gen, generated = train_gen(gen, batch_size)\n",
" \n",
" global_step = epoch * n + c\n",
" if c % 1 == 0:\n",
" writer.add_scalar('loss_dis', loss_dis.item(), total_step + global_step)\n",
" writer.add_scalar('loss_gen', loss_gen.item(), total_step + global_step)\n",
" if c % 4 == 0:\n",
" msg = f'loss: {loss_dis.item()}, \\t {loss_gen.item()} \\t epoch: {epoch}, \\t {c}/{n}'\n",
" logging.info(msg)\n",
" \n",
" if c % 2 == 0:\n",
" writer.add_image('img', torch.sigmoid(generated[:5]), total_step + global_step)\n",
" \n",
" if c > (n // 2):\n",
" save()\n",
" \n",
" print('finished epoch', epoch)\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}