Study 2: Experiencing AI Image Generation with PGGAN
(Outputs the results of training on landscape images at resolutions beyond DCGAN's practical limit of 64×64 pixels)
・Background and objectives
・Approach
・Principle
・Visualization method
・Visualization results
・Discussion
【Approach】
PGGAN stands for Progressive Growing of GANs, a training methodology for generative adversarial networks. Its defining idea is to grow the generator and the discriminator progressively: training starts at a low resolution, and new layers are added step by step so that the networks model increasingly fine detail as training proceeds. This both speeds up training and improves the quality, stability, and variation of the generated images, and it makes it possible to generate images as large as 1024×1024 pixels.
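The core mechanism behind this progressive growth is a fade-in: while a newly added high-resolution block is being phased in, its RGB output is alpha-blended with the upsampled output of the previous block, with alpha ramping from 0 to 1. A minimal sketch of this blend follows (the function name fade_in and the tensor names are illustrative; the actual logic appears inside Generator.forward in the code below):

import torch
import torch.nn.functional as F

def fade_in(rgb_low: torch.Tensor, rgb_high: torch.Tensor, alpha: float) -> torch.Tensor:
    # Upsample the previous block's RGB output to the new resolution,
    # then blend; alpha grows from 0 to 1 as the new block is phased in.
    rgb_up = F.interpolate(rgb_low, size=rgb_high.shape[2:], mode='nearest')
    return (1.0 - alpha) * rgb_up + alpha * rgb_high

In the training loop below, alpha corresponds to the fractional part of the variable res.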
【Code】Outputs 256×256-pixel images
Reference: https://qiita.com/kosakae256/items/cf3293861c14956cce18
import os
import random

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.nn import functional as F

# Tune GPU memory allocation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
torch.cuda.empty_cache()


# Pixel normalization module
class PixelNorm(nn.Module):
    def forward(self, x):
        eps = 1e-8
        return x / ((torch.mean(x**2, dim=1, keepdim=True) + eps) ** 0.5)


# Equalized learning rate
class WeightScale(nn.Module):
    def forward(self, x, gain=2):
        c = ((x.shape[1] * x.shape[2] * x.shape[3]) / 2) ** 0.5
        return x / c


# Minibatch standard deviation: lets the discriminator see batch diversity
class MiniBatchStd(nn.Module):
    def forward(self, x):
        std = torch.std(x, dim=0, keepdim=True)
        mean = torch.mean(std, dim=(1, 2, 3), keepdim=True)
        n, c, h, w = x.shape
        mean = torch.ones(n, 1, h, w, dtype=x.dtype, device=x.device) * mean
        return torch.cat((x, mean), dim=1)


# Convolution wrapped into a module (weight scaling + reflection padding)
class Conv2d(nn.Module):
    def __init__(self, inch, outch, kernel_size, padding=0):
        super().__init__()
        self.layers = nn.Sequential(
            WeightScale(),
            nn.ReflectionPad2d(padding),
            nn.Conv2d(inch, outch, kernel_size, padding=0),
        )
        nn.init.kaiming_normal_(self.layers[2].weight)  # He initialization

    def forward(self, x):
        return self.layers(x)


class ResBlock(nn.Module):
    def __init__(self, inch, outch, kernel_size, padding=0):
        super().__init__()
        self.conv1 = Conv2d(inch, outch, 3, padding=1)
        self.relu1 = nn.LeakyReLU(0.2, inplace=False)
        self.pixnorm1 = PixelNorm()
        self.conv2 = Conv2d(outch, outch, 3, padding=1)
        self.relu2 = nn.LeakyReLU(0.2, inplace=False)
        self.pixnorm2 = PixelNorm()
        self.relu3 = nn.LeakyReLU(0.2, inplace=False)
        self.shortcut = nn.Conv2d(inch, outch, kernel_size=(1, 1), padding=0)

    def forward(self, x):
        h = self.conv1(x)
        h = self.relu1(h)
        h = self.pixnorm1(h)
        h = self.conv2(h)
        h = self.relu2(h)
        h = self.pixnorm2(h)
        x = self.shortcut(x)
        y = self.relu3(h + x)
        return y


# Stacked conv block for the Generator
class ConvModuleG(nn.Module):
    '''
    Args:
        out_size: (int), Ex.: 16 (resolution)
        inch: (int), Ex.: 256
        outch: (int), Ex.: 128
    '''
    def __init__(self, out_size, inch, outch, first=False):
        super().__init__()
        if first:
            layers = [
                Conv2d(inch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
                Conv2d(outch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
            ]
        else:
            layers = [
                nn.Upsample((out_size, out_size), mode='nearest'),
                Conv2d(inch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
                Conv2d(outch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
            ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class ConvModuleD(nn.Module):
    '''
    Args:
        out_size: (int), Ex.: 16 (resolution)
        inch: (int), Ex.: 256
        outch: (int), Ex.: 128
    '''
    def __init__(self, out_size, inch, outch, final=False):
        super().__init__()
        if final:
            layers = [
                MiniBatchStd(),  # final block only
                Conv2d(inch + 1, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
                Conv2d(outch, outch, 4, padding=0),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
                nn.Conv2d(outch, 1, 1, padding=0),
            ]
        else:
            layers = [
                Conv2d(inch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
                Conv2d(outch, outch, 3, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
                PixelNorm(),
                nn.AdaptiveAvgPool2d((out_size, out_size)),
            ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # conv modules & toRGBs
        scale = 1
        inchs = np.array([512 / 16, 512, 512, 512, 512, 256, 128], dtype=np.uint32) * scale  # input channels for each added block
        outchs = np.array([512, 512, 512, 512, 256, 128, 64], dtype=np.uint32) * scale       # output channels for each added block
        sizes = np.array([4, 8, 16, 32, 64, 128, 256], dtype=np.uint32)
        firsts = np.array([True, False, False, False, False, False, False], dtype=bool)
        blocks, toRGBs = [], []
        for s, inch, outch, first in zip(sizes, inchs, outchs, firsts):
            blocks.append(ConvModuleG(int(s), int(inch), int(outch), first))
            toRGBs.append(nn.Conv2d(int(outch), 3, 1, padding=0))
        self.blocks = nn.ModuleList(blocks)
        self.toRGBs = nn.ModuleList(toRGBs)

    def forward(self, x, res, eps=1e-7):
        # to image
        n, c = x.shape
        x = x.reshape(n, c // 16, 4, 4)

        # for the highest resolution
        res = min(res, len(self.blocks))

        # get integer by floor
        nlayer = max(int(res - eps), 0)
        for i in range(nlayer):
            x = self.blocks[i](x)

        # high resolution
        x_big = self.blocks[nlayer](x)
        dst_big = self.toRGBs[nlayer](x_big)

        if nlayer == 0:
            x = dst_big
        else:
            # Fade-in to smooth the transition when a new block is added
            # low resolution
            x_sml = F.interpolate(x, x_big.shape[2:4], mode='nearest')
            dst_sml = self.toRGBs[nlayer - 1](x_sml)
            alpha = res - int(res - eps)
            x = (1 - alpha) * dst_sml + alpha * dst_big

        return torch.sigmoid(x)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.minbatch_std = MiniBatchStd()

        # conv modules & fromRGBs
        scale = 1
        inchs = np.array([512, 512, 512, 512, 256, 128, 64], dtype=np.uint32) * scale
        outchs = np.array([512, 512, 512, 512, 512, 256, 128], dtype=np.uint32) * scale
        sizes = np.array([1, 4, 8, 16, 32, 64, 128], dtype=np.uint32)
        finals = np.array([True, False, False, False, False, False, False], dtype=bool)
        blocks, fromRGBs = [], []
        for s, inch, outch, final in zip(sizes, inchs, outchs, finals):
            fromRGBs.append(nn.Conv2d(3, int(inch), 1, padding=0))
            blocks.append(ConvModuleD(int(s), int(inch), int(outch), final=final))
        self.fromRGBs = nn.ModuleList(fromRGBs)
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, res):
        # for the highest resolution
        res = min(res, len(self.blocks))

        # get integer by floor
        eps = 1e-8
        n = max(int(res - eps), 0)

        # high resolution
        x_big = self.fromRGBs[n](x)
        x_big = self.blocks[n](x_big)

        if n == 0:
            x = x_big
        else:
            # low resolution
            x_sml = F.adaptive_avg_pool2d(x, x_big.shape[2:4])
            x_sml = self.fromRGBs[n - 1](x_sml)
            alpha = res - int(res - eps)
            x = (1 - alpha) * x_sml + alpha * x_big

        for i in range(n):
            x = self.blocks[n - 1 - i](x)

        return x


def gradient_penalty(netD, real, fake, res, batch_size, gamma=1):
    device = real.device
    alpha = torch.rand(batch_size, 1, 1, 1, requires_grad=True).to(device)
    x = alpha * real + (1 - alpha) * fake
    d_ = netD.forward(x, res)
    g = torch.autograd.grad(outputs=d_, inputs=x,
                            grad_outputs=torch.ones(d_.shape).to(device),
                            create_graph=True, retain_graph=True, only_inputs=True)[0]
    g = g.reshape(batch_size, -1)
    return ((g.norm(2, dim=1) / gamma - 1.0) ** 2).mean()


# Trained on 6,759 landscape images
if __name__ == '__main__':
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'
    print(device)

    # Gparam = torch.load('C:\\Users\\Hayate\\HayateLab\\PGGAN\\6-netG.pth')
    netG = Generator().to(device)
    # netG.load_state_dict(Gparam, strict=False)
    netD = Discriminator().to(device)
    # netD.load_state_dict(torch.load('C:\\Users\\Hayate\\HayateLab\\PGGAN\\6-netD.pth'))
    # Gparam = torch.load('C:\\Users\\Hayate\\HayateLab\\PGGAN\\6-netG_mavg.pth')
    netG_mavg = Generator().to(device)  # moving average of the generator weights
    # netG_mavg.load_state_dict(Gparam, strict=False)

    lr = 0.001
    optG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.0, 0.99))
    optD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(0.0, 0.99))
    criterion = torch.nn.BCELoss()  # unused: the WGAN-GP losses below are used instead

    batch_size = 4

    # dataset
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((256, 256))])
    trainset = torchvision.datasets.ImageFolder(root=os.getcwd() + '\\PGGAN\\',
                                                transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                               shuffle=True, drop_last=True)

    # training
    res_steps = [10000, 30000, 50000, 75000, 100000, 150000, 20000]  # step budget per resolution [4,8,16,32,64,128,256]
    losses = []
    j = 0          # steps taken at the current resolution
    res_i = 0      # current resolution stage
    res_index = 0  # index into res_steps
    nepoch = 1370

    # constant random inputs
    r = random.randint(0, 99999999999)
    torch.manual_seed(256)
    z0 = torch.randn(16, 512).to(device)  # 16 fixed 512-dim noise vectors for progress monitoring
    z0 = torch.clamp(z0, -1., 1.)
    torch.manual_seed(r)

    beta_gp = 10.0
    beta_drift = 0.001
    # lr_decay = 0.87           # decay on layer switch (unused)
    # attenuation_rate = 0.99   # lr decay rate (unused)
    # attenuation_timing = 200  # decay every this many steps if the monitored loss exceeds a threshold (unused)

    for iepoch in range(nepoch):
        for i, data in enumerate(train_loader):
            x, y = data
            x = x.to(device)
            # res grows from res_i toward res_i+1; its fractional part drives the fade-in
            res = (j / res_steps[res_index]) * 1.25 + res_i
            res = min(res, res_i + 1)

            ### train generator ###
            z = torch.randn(batch_size, 512).to(x.device)
            x_ = netG.forward(z, res)
            del z
            d_ = netD.forward(x_, res)  # fake
            lossG = -d_.mean()  # WGAN-GP
            del d_

            optG.zero_grad()
            lossG.backward()
            optG.step()

            # update netG_mavg by moving average
            momentum = 0.995  # remain momentum
            alpha = min(1.0 - (1 / (j + 1)), momentum)
            for p_mavg, p in zip(netG_mavg.parameters(), netG.parameters()):
                p_mavg.data = alpha * p_mavg.data + (1.0 - alpha) * p.data

            ### train discriminator ###
            z = torch.randn(x.shape[0], 512).to(x.device)
            x_ = netG.forward(z, res)
            del z
            x = F.adaptive_avg_pool2d(x, x_.shape[2:4])
            d = netD.forward(x, res)    # real
            d_ = netD.forward(x_, res)  # fake
            loss_real = -1 * d.mean()
            loss_fake = d_.mean()
            loss_gp = gradient_penalty(netD, x.data, x_.data, res, x.shape[0])
            loss_drift = (d ** 2).mean()
            del d_
            del d
            lossD = loss_real + loss_fake + beta_gp * loss_gp + beta_drift * loss_drift

            optD.zero_grad()
            lossD.backward()
            optD.step()

            print('ep: %02d %04d %04d lossG=%.10f lossD=%.10f'
                  % (iepoch, i, j, lossG.item(), lossD.item()))
            losses.append([lossG.item(), lossD.item()])
            j += 1

            # save a 4x8 sample grid each step (fixed z0 on top, fresh z below)
            netG_mavg.eval()
            z = torch.randn(16, 512).to(x.device)
            z = torch.clamp(z, -1., 1.)
            x_0 = netG_mavg.forward(z0, res)
            x_ = netG_mavg.forward(z, res)
            dst = torch.cat((x_0, x_), dim=0)
            del z, x_0, x_
            dst = F.interpolate(dst, (256, 256), mode='nearest')
            dst = dst.to('cpu').detach().numpy()
            n, c, h, w = dst.shape
            dst = dst.reshape(4, 8, c, h, w)
            dst = dst.transpose(0, 3, 1, 4, 2)
            dst = dst.reshape(4 * h, 8 * w, 3)
            dst = np.clip(dst * 255., 0, 255).astype(np.uint8)
            dst = cv2.cvtColor(dst, cv2.COLOR_RGB2BGR)  # RGB tensor layout -> OpenCV's BGR
            cv2.imwrite('image.jpg', dst)
            netG_mavg.train()

            # condition for switching to the next resolution
            if res_steps[res_index] == j:
                PATH = os.getcwd()  # prepared for checkpoint saving (the save calls are not shown in the source)
                j = 0
                res_index += 1
                res_i += 1
This concludes the implementation.
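Once training has produced a checkpoint, images can be sampled from the moving-average generator on its own. Below is a minimal inference sketch, assuming the Generator class defined above is available and that the weights were saved under the name 6-netG_mavg.pth (a hypothetical filename modeled on the commented-out paths in the script; substitute your own checkpoint):

import cv2
import numpy as np
import torch

netG = Generator()  # the Generator class from the training script above
netG.load_state_dict(torch.load('6-netG_mavg.pth', map_location='cpu'))
netG.eval()

with torch.no_grad():
    z = torch.clamp(torch.randn(1, 512), -1., 1.)  # same latent distribution as in training
    img = netG(z, res=7)  # res=7 runs all 7 blocks, yielding a 256x256 image

img = img[0].permute(1, 2, 0).numpy()  # CHW -> HWC; values are in [0,1] (sigmoid output)
img = np.clip(img * 255., 0, 255).astype(np.uint8)
cv2.imwrite('sample.jpg', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

With res=7 the fade-in coefficient is exactly 1, so only the highest-resolution toRGB output is used; smaller values of res reproduce the intermediate resolutions seen during training.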