MINEを使ってinfoGANを実装した
はじめに
Mutual information neural estimation(MINE)を使ってinfoGANを実装したのでメモ.MINEに関してのメモはこちら.
設定
オリジナルのinfoGANと同様にMNISTで離散変数1つ,連続変数2つでgeneratorを学習した.infoGANの特徴である相互情報量最大化の部分をMINEに置き換えた.MINEに関しては雰囲気を知りたかったのでだいぶ簡略化した実装にした. 具体的にはMINEのオリジナルの論文ではKLDが無限に大きくなってしまい勾配のバランスが崩れるため,adaptive gradient clippingを使っていたが,今回はKLDではなくDIM(DIMのメモはこちら)で使われていたJensen-Shannonダイバージェンスバージョンを使ってclippingの実装は省いた.また,statistics networkとgeneratorは別々に最適化するところを,これもDIMと同様同時に最適化するようにした.ただし,DIMではgeneratorとstatistics networkの一部を共有していたが完全に別々にした.
モデルはgeneratorとdiscriminator,3つの変数それぞれに対して相互情報量を計算するためのstatistics networkを用意.目的関数は以下のように定義した.
相互情報量は平均をとった.こうしないと相互情報量のロスに引っ張られて生成画像があんまり綺麗にならなかった.
モデルの構成はその辺に落ちてたinfoGANの実装のモデルを真似た.statistics networkに関してはとりあえず画像をベクトル化して潜在変数とくっつけたものを入力とするMLPとして実装した.後一応,discriminatorにはspectral normalizationを付けてる(なくても多分動くんじゃないかな).
離散変数はonehotベクトルで連続変数とノイズは-1~1の範囲の一様分布からサンプリングした.
実験結果
100エポック学習した結果.100エポックもいらなかった.
離散変数と連続変数の片方を-1から1で動かした場合.
離散変数と連続変数のもう一方を-1から1で動かした場合.
思った以上にうまく動いてびっくり.ただ,3と5の区別が微妙な感じなのと文字の太さはあんまり表現できてない.もう少しちゃんと実装すればもっとうまくいくのかどうか. とりあえず現状の実装だと相互情報量がサチってたからこれ以上は無理な気がする.
コード
適当に転がってたGANのコードを改変して作ったのでいろんなところがハードコードされているけどとりあえず動くはず. 以下メインファイル.
import torch import torch.optim as optim import torch.nn.functional as F import numpy as np import os import torchvision.utils as vutils from torchvision.datasets import MNIST from torchvision.transforms import ToTensor from torch.utils.data import DataLoader from model import generator, disciminator, disc_statistics, cont_statistics if torch.cuda.is_available(): device = "cuda" else: device = "cpu" BATCH_SIZE = 100 NUM_WORKERS = 8 RANGE = 1 G = generator().to(device) D = disciminator().to(device) S_disc = disc_statistics().to(device) S_cont1 = cont_statistics().to(device) S_cont2 = cont_statistics().to(device) optimG = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.999)) optimD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999)) optimS = optim.Adam([{"params":S_disc.parameters()}, {"params":S_cont1.parameters()}, {"params":S_cont2.parameters()}], lr=1e-3, betas=(0.5, 0.999)) train_data = MNIST("./data", train=True, download=True, transform=ToTensor()) loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=NUM_WORKERS) label = torch.zeros(BATCH_SIZE).to(device).float() real_label = 1 fake_label = 0 c = torch.linspace(-1, 1, 10).repeat(10).reshape(-1, 1) c1 = torch.cat([c, torch.zeros_like(c)], 1).float() * RANGE c2 = torch.cat([torch.zeros_like(c), c], 1).float() * RANGE idx = torch.from_numpy(np.arange(10).repeat(10)) one_hot = torch.zeros((100, 10)).float() one_hot[range(100), idx] = 1 fix_z = torch.Tensor(100, 62).uniform_(-1, 1) fix_noise1 = torch.cat([fix_z, c1, one_hot], 1)[...,None,None].to(device) fix_noise2 = torch.cat([fix_z, c2, one_hot], 1)[...,None,None].to(device) for epoch in range(100): for i, data in enumerate(loader): # discriminator optimD.zero_grad() ## real real, _ = data real = real.to(device) label.fill_(real_label) real_prob = D(real).squeeze() real_loss = F.binary_cross_entropy_with_logits(real_prob, label) real_loss.backward() ## fake label.fill_(fake_label) ### get noise idx = torch.randint(0, 10, (BATCH_SIZE,)).long() disc_c = torch.eye(10)[idx][...,None,None].float().to(device) cont_c = torch.zeros(BATCH_SIZE, 2, 1, 1).uniform_(-1, 1).float().to(device) * RANGE z = torch.zeros(BATCH_SIZE, 62, 1, 1).uniform_(-1, 1).float().to(device) noise = torch.cat([z, cont_c, disc_c], 1).to(device).float() ### generate fake = G(noise) fake_prob = D(fake.detach()).squeeze() fake_loss = F.binary_cross_entropy_with_logits(fake_prob, label) fake_loss.backward() loss_D = real_loss + fake_loss optimD.step() # generator optimG.zero_grad() optimS.zero_grad() label.fill_(real_label) ## adversarial loss inv_fake_prob = D(fake).squeeze() inv_fake_loss = F.binary_cross_entropy_with_logits(inv_fake_prob, label) ## MINE ### c ~ P(C) idx = torch.randint(0, 10, (100,)).long() disc_c_bar = torch.eye(10)[idx].float().to(device) cont_c_bar = torch.zeros(100, 2, 1, 1).uniform_(-1, 1).float().to(device) * RANGE ### discrete variable joint_disc = S_disc(torch.cat([fake.reshape(100, -1), disc_c.reshape(100, -1)], 1)) marginal_disc = S_disc(torch.cat([fake.reshape(100, -1), disc_c_bar.reshape(100, -1)], 1)) ### continuout variable joint_cont1 = S_cont1(torch.cat([fake.reshape(100, -1), cont_c[:,0].reshape(100, -1)], 1)) joint_cont2 = S_cont2(torch.cat([fake.reshape(100, -1), cont_c[:,1].reshape(100, -1)], 1)) marginal_cont1 = S_cont1(torch.cat([fake.reshape(100, -1), cont_c_bar[:,0].reshape(100, -1)], 1)) marginal_cont2 = S_cont2(torch.cat([fake.reshape(100, -1), cont_c_bar[:,1].reshape(100, -1)], 1)) ### calc mutual information mi_disc = F.softplus(-joint_disc).mean() + F.softplus(marginal_disc).mean() mi_cont1 = F.softplus(-joint_cont1).mean() + F.softplus(marginal_cont1).mean() mi_cont2 = F.softplus(-joint_cont2).mean() + F.softplus(marginal_cont2).mean() mi = (mi_disc + mi_cont1 + mi_cont2)/3 loss = mi + inv_fake_loss loss.backward() optimG.step() optimS.step() print("epoch [{}/{}], iter [{}/{}], D : {}, G : {}, S : {}".format( epoch, 100, i, len(loader), loss_D.item(), inv_fake_loss.item(), mi.item() )) with torch.no_grad(): fake1 = G(fix_noise1) fake2 = G(fix_noise2) if not os.path.exists("results_1"): os.mkdir("results_1") vutils.save_image(fake1.detach(), "results_1/{}epoch_fake1.png".format(epoch), normalize=True, nrow=10) vutils.save_image(fake2.detach(), "results_1/{}epoch_fake2.png".format(epoch), normalize=True, nrow=10)
以下モデルファイル.
import torch.nn as nn class generator(nn.Module): def __init__(self): super().__init__() self.main = nn.Sequential( nn.ConvTranspose2d(74, 1024, 1, 1, bias=False), nn.BatchNorm2d(1024), nn.ReLU(), nn.ConvTranspose2d(1024, 128, 7, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(), nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(), nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False), nn.Sigmoid() ) for m in self.modules(): if isinstance(m, nn.ConvTranspose2d): nn.init.normal_(m.weight, 0, 0.02) elif isinstance(m, nn.BatchNorm2d): nn.init.normal_(m.weight, 1, 0.02) nn.init.constant_(m.bias, 0) def forward(self, x): return self.main(x) class disciminator(nn.Module): def __init__(self): super().__init__() self.main = nn.Sequential( nn.Conv2d(1, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.1), nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.1), nn.Conv2d(128, 1024, 7, bias=False), nn.BatchNorm2d(1024), nn.LeakyReLU(0.1), nn.Conv2d(1024, 1, 1, bias=False) ) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, 0, 0.02) nn.utils.spectral_norm(m, n_power_iterations=2) elif isinstance(m, nn.BatchNorm2d): nn.init.normal_(m.weight, 1, 0.02) nn.init.constant_(m.bias, 0) def forward(self, x): return self.main(x) class cont_statistics(nn.Module): def __init__(self): super().__init__() self.main = nn.Sequential( nn.Linear(28**2 + 1, 1024, bias=False), nn.BatchNorm1d(1024), nn.LeakyReLU(0.1), nn.Linear(1024, 1024, bias=False), nn.BatchNorm1d(1024), nn.LeakyReLU(0.1), nn.Linear(1024, 1, bias=False), ) for m in self.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.02) elif isinstance(m, nn.BatchNorm1d): nn.init.normal_(m.weight, 1, 0.02) nn.init.constant_(m.bias, 0) def forward(self, x): return self.main(x) class disc_statistics(nn.Module): def __init__(self): super().__init__() self.main = nn.Sequential( nn.Linear(28**2 + 10, 1024, bias=False), nn.BatchNorm1d(1024), nn.LeakyReLU(0.1), nn.Linear(1024, 1024, bias=False), nn.BatchNorm1d(1024), nn.LeakyReLU(0.1), nn.Linear(1024, 1, bias=False), ) for m in self.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.02) elif isinstance(m, nn.BatchNorm1d): nn.init.normal_(m.weight, 1, 0.02) nn.init.constant_(m.bias, 0) def forward(self, x): return self.main(x)
まとめ
MINEを論文通りに実装しようとするとちょっと面倒かなという印象で逃げてしまった.というかclippingの実装気になるからちゃんとした実装上がって欲しい.
実装は何も考えずにネットワークひとつ用意するだけなのでめちゃくちゃ簡単(厳密にやるともう少し複雑だけどこれでも十分動く).相互情報量の面だけみればmin-max gameのような不安定要素も入らないのでMINEとても良い.