機械学習とかコンピュータビジョンとか

CVやMLに関する勉強のメモ書き。

MINEを使ってinfoGANを実装した

はじめに

Mutual information neural estimation(MINE)を使ってinfoGANを実装したのでメモ.MINEに関してのメモはこちら

設定

オリジナルのinfoGANと同様にMNISTで離散変数1つ,連続変数2つでgeneratorを学習した.infoGANの特徴である相互情報量最大化の部分をMINEに置き換えた.MINEに関しては雰囲気を知りたかったのでだいぶ簡略化した実装にした. 具体的にはMINEのオリジナルの論文ではKLDが無限に大きくなってしまい勾配のバランスが崩れるため,adaptive gradient clippingを使っていたが,今回はKLDではなくDIM(DIMのメモはこちら)で使われていたJensen-Shannonダイバージェンスバージョンを使ってclippingの実装は省いた.また,statistics networkとgeneratorは別々に最適化するところを,これもDIMと同様同時に最適化するようにした.ただし,DIMではgeneratorとstatistics networkの一部を共有していたが完全に別々にした.

モデルはgeneratorとdiscriminator,3つの変数それぞれに対して相互情報量を計算するためのstatistics networkを用意.目的関数は以下のように定義した.

\displaystyle
\max_{G, S_1, S_2, S_3}\min_D\mathbb{E}_x[\log D(X)]+\mathbb{E}_z[\log(1-D(G(z))]+\frac{1}{3}\left(\mathcal{I}_{S_1}(G(X),c_1)+\mathcal{I}_{S_2}(G(X),c_2)+\mathcal{I}_{S_3}(G(X),c_3)\right)\\ \displaystyle
\mathcal{I}_{S}(X,c)=\mathbb{E}_{P(X,c)}[-softplus(-S(X,c))]-\mathbb{E}_{P(X)P(c)}[softplus(S(X,c'))]

相互情報量は平均をとった.こうしないと相互情報量のロスに引っ張られて生成画像があんまり綺麗にならなかった.

モデルの構成はその辺に落ちてたinfoGANの実装のモデルを真似た.statistics networkに関してはとりあえず画像をベクトル化して潜在変数とくっつけたものを入力とするMLPとして実装した.後一応,discriminatorにはspectral normalizationを付けてる(なくても多分動くんじゃないかな).

離散変数はonehotベクトルで連続変数とノイズは-1~1の範囲の一様分布からサンプリングした.

実験結果

100エポック学習した結果.100エポックもいらなかった.

離散変数と連続変数の片方を-1から1で動かした場合.

f:id:peluigi:20180829120208p:plain

離散変数と連続変数のもう一方を-1から1で動かした場合.

f:id:peluigi:20180829120216p:plain

思った以上にうまく動いてびっくり.ただ,3と5の区別が微妙な感じなのと文字の太さはあんまり表現できてない.もう少しちゃんと実装すればもっとうまくいくのかどうか. とりあえず現状の実装だと相互情報量がサチってたからこれ以上は無理な気がする.

コード

適当に転がってたGANのコードを改変して作ったのでいろんなところがハードコードされているけどとりあえず動くはず. 以下メインファイル.

import torch
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import os

import torchvision.utils as vutils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

from model import generator, disciminator, disc_statistics, cont_statistics

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

BATCH_SIZE = 100
NUM_WORKERS = 8
RANGE = 1

G = generator().to(device)
D = disciminator().to(device)
S_disc = disc_statistics().to(device)
S_cont1 = cont_statistics().to(device)
S_cont2 = cont_statistics().to(device)

optimG = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.999))
optimD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimS = optim.Adam([{"params":S_disc.parameters()}, {"params":S_cont1.parameters()}, {"params":S_cont2.parameters()}], lr=1e-3, betas=(0.5, 0.999))

train_data = MNIST("./data", train=True, download=True, transform=ToTensor())
loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=NUM_WORKERS)

label = torch.zeros(BATCH_SIZE).to(device).float()
real_label = 1
fake_label = 0

c = torch.linspace(-1, 1, 10).repeat(10).reshape(-1, 1)
c1 = torch.cat([c, torch.zeros_like(c)], 1).float() * RANGE
c2 = torch.cat([torch.zeros_like(c), c], 1).float() * RANGE

idx = torch.from_numpy(np.arange(10).repeat(10))
one_hot = torch.zeros((100, 10)).float()
one_hot[range(100), idx] = 1
fix_z = torch.Tensor(100, 62).uniform_(-1, 1)

fix_noise1 = torch.cat([fix_z, c1, one_hot], 1)[...,None,None].to(device)
fix_noise2 = torch.cat([fix_z, c2, one_hot], 1)[...,None,None].to(device)

for epoch in range(100):
    for i, data in enumerate(loader):
        # discriminator
        optimD.zero_grad()
        ## real
        real, _ = data
        real = real.to(device)
        label.fill_(real_label)
        real_prob = D(real).squeeze()
        real_loss = F.binary_cross_entropy_with_logits(real_prob, label)
        real_loss.backward()
        ## fake
        label.fill_(fake_label)
        ### get noise
        idx = torch.randint(0, 10, (BATCH_SIZE,)).long()
        disc_c = torch.eye(10)[idx][...,None,None].float().to(device)
        cont_c = torch.zeros(BATCH_SIZE, 2, 1, 1).uniform_(-1, 1).float().to(device) * RANGE
        z = torch.zeros(BATCH_SIZE, 62, 1, 1).uniform_(-1, 1).float().to(device)
        noise = torch.cat([z, cont_c, disc_c], 1).to(device).float()
        ### generate
        fake = G(noise)

        fake_prob = D(fake.detach()).squeeze()
        fake_loss = F.binary_cross_entropy_with_logits(fake_prob, label)
        fake_loss.backward()
        loss_D = real_loss + fake_loss
        optimD.step()

        # generator
        optimG.zero_grad()
        optimS.zero_grad()
        label.fill_(real_label)
        ## adversarial loss
        inv_fake_prob = D(fake).squeeze()
        inv_fake_loss = F.binary_cross_entropy_with_logits(inv_fake_prob, label)
        ## MINE
        ### c ~ P(C)
        idx = torch.randint(0, 10, (100,)).long()
        disc_c_bar = torch.eye(10)[idx].float().to(device)
        cont_c_bar = torch.zeros(100, 2, 1, 1).uniform_(-1, 1).float().to(device) * RANGE
        ### discrete variable
        joint_disc = S_disc(torch.cat([fake.reshape(100, -1), disc_c.reshape(100, -1)], 1))
        marginal_disc = S_disc(torch.cat([fake.reshape(100, -1), disc_c_bar.reshape(100, -1)], 1))
        ### continuout variable
        joint_cont1 = S_cont1(torch.cat([fake.reshape(100, -1), cont_c[:,0].reshape(100, -1)], 1))
        joint_cont2 = S_cont2(torch.cat([fake.reshape(100, -1), cont_c[:,1].reshape(100, -1)], 1))
        marginal_cont1 = S_cont1(torch.cat([fake.reshape(100, -1), cont_c_bar[:,0].reshape(100, -1)], 1))
        marginal_cont2 = S_cont2(torch.cat([fake.reshape(100, -1), cont_c_bar[:,1].reshape(100, -1)], 1))
        ### calc mutual information
        mi_disc = F.softplus(-joint_disc).mean() + F.softplus(marginal_disc).mean()
        mi_cont1 = F.softplus(-joint_cont1).mean() + F.softplus(marginal_cont1).mean()
        mi_cont2 = F.softplus(-joint_cont2).mean() + F.softplus(marginal_cont2).mean()
        mi = (mi_disc + mi_cont1 + mi_cont2)/3

        loss = mi + inv_fake_loss
        loss.backward()
        optimG.step()
        optimS.step()
        print("epoch [{}/{}], iter [{}/{}], D : {}, G : {}, S : {}".format(
            epoch, 100, i, len(loader), loss_D.item(), inv_fake_loss.item(), mi.item()
        ))
    with torch.no_grad():
        fake1 = G(fix_noise1)
        fake2 = G(fix_noise2)
        if not os.path.exists("results_1"):
            os.mkdir("results_1")
        vutils.save_image(fake1.detach(),
                    "results_1/{}epoch_fake1.png".format(epoch),
                    normalize=True, nrow=10)
        vutils.save_image(fake2.detach(),
                    "results_1/{}epoch_fake2.png".format(epoch),
                    normalize=True, nrow=10)

以下モデルファイル.

import torch.nn as nn

class generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(74, 1024, 1, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 128, 7, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Sigmoid()
        )

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.main(x)

class disciminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, 1024, 7, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1),
            nn.Conv2d(1024, 1, 1, bias=False)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.02)
                nn.utils.spectral_norm(m, n_power_iterations=2)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.main(x)

class cont_statistics(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(28**2 + 1, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 1, bias=False),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.02)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.normal_(m.weight, 1, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.main(x)


class disc_statistics(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(28**2 + 10, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 1, bias=False),
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.02)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.normal_(m.weight, 1, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.main(x)

まとめ

MINEを論文通りに実装しようとするとちょっと面倒かなという印象で逃げてしまった.というかclippingの実装気になるからちゃんとした実装上がって欲しい.

実装は何も考えずにネットワークひとつ用意するだけなのでめちゃくちゃ簡単(厳密にやるともう少し複雑だけどこれでも十分動く).相互情報量の面だけみればmin-max gameのような不安定要素も入らないのでMINEとても良い.