dskjal
広告
広告

GAN の学習

カテゴリ:deeplearning

画像生成より真贋判定の方がタスクが簡単なので、Discriminator のネットワークはシンプルで浅いものを使うのが普通。Discriminator が強すぎると Generator の学習が進まなくなる。

Patch-based adversal loss を利用する Image-to-Image Translation with Conditional Adversarial Networks が基礎的な論文で、畳み込みを使う DCGAN でネットワークを作る。

tips

モード崩壊(Mode Collapse)

モード崩壊(Mode Collapse)とは Generator が同じ画像ばかり出力すること。以下の対策方法がある。

学習戦略・最適化の工夫
ロス関数の変更
モデル構造の工夫
正則化や多様性の強化
そのほか

導入順

  1. WGAN-GP ・ LSGAN
  2. Minibatch Discrimination ・ Feature Matching
  3. InfoGAN ・ Mode-seeking loss
  4. Unrolled GAN ・ PacGAN

LSGAN

JS Divergenceの最小化に基づくGANでは、勾配消失が起きやすい。LSGAN はGANの損失関数を最小二乗誤差にすることで、この問題の解決する。WGAN より実装が単純なので、まずは LSGAN から試すのがよい。

LSGANはなぜ学習が安定するのか?

LSGAN(Least Square Generative Adversarial Networks)を試してみた

Wasserstein GAN

Wasserstein 距離に基づいて loss を計算する。モード崩壊が軽減され、学習の安定化と品質の改善が期待できる。通常 Discriminator は入力が本物である確率を計算するが、wgan では Discriminator(wgan では Critic と呼ぶ)は Wasserstein 距離の近似値を計算する。Critic は入力された画像がどれだけ現実離れしているかを数値化する。

詳細な解説は今さら聞けないGAN(4) WGAN を参照。通常は Grgdient penalty を採用する。

Improved Training of Wasserstein GANs

Weight Clip を使用する古いコード例
import torch
import torch.nn as nn
import torch.optim as optim


# 損失関数の定義 (Weight Clipping を使用)
def wgan_loss(critic, generator, real_images, fake_images, clip_value):
    """
    WGAN の Critic と Generator の損失を計算する関数。

    Args:
        critic: Critic モデル。
        generator: Generator モデル。
        real_images: 本物の画像データ (Tensor)。
        fake_images: Generator が生成した偽の画像データ (Tensor)。
        clip_value: Weight Clipping に使用する値。

    Returns:
        tuple: (Critic loss, Generator loss) のタプル。
    """

    # Critic の損失計算
    critic_real_output = critic(real_images)
    critic_fake_output = critic(fake_images.detach())  # detach() で勾配計算を停止させる

    critic_loss = torch.mean(critic_real_output - critic_fake_output)  # Wasserstein distance の近似値を最大化

    # Generator の損失計算
    generator_output = critic(fake_images)
    generator_loss = torch.mean(-generator_output) # Critic の評価を下げる (Wasserstein distance を最小化)

    return critic_loss, generator_loss



# 簡易的な Critic と Generator の定義 (例)
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256), # 例えば MNIST サイズの場合
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the image
        return self.model(x)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256), # ノイズ次元は例として100
            nn.ReLU(),
            nn.Linear(256, 784)  # MNIST サイズの場合
        )

    def forward(self, z):
        z = z.view(z.size(0), -1) # Flatten the noise vector
        return self.model(z).view(-1, 28, 28) # Reshape to image size



# ハイパーパラメータの設定 (例)
latent_dim = 100
batch_size = 64
learning_rate_critic = 0.0001
learning_rate_generator = 0.0001
clip_value = 0.01  # Weight Clipping の値

# モデルとオプティマイザの作成
critic = Critic()
generator = Generator()

optimizer_critic = optim.Adam(critic.parameters(), lr=learning_rate_critic)
optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate_generator)




# ダミーデータ生成 (テスト用)
real_data = torch.randn(batch_size, 1, 28, 28)  # MNIST サイズのダミーデータ
noise = torch.randn(batch_size, latent_dim)



# 損失計算と学習ループの一部
for epoch in range(10):  # 例として10エポック
    optimizer_critic.zero_grad()
    optimizer_generator.zero_grad()

    fake_data = generator(noise)

    critic_loss, generator_loss = wgan_loss(critic, generator, real_data, fake_data, clip_value)


    # Critic の学習 (勾配の計算と適用)
    critic_loss.backward()

    # Weight Clipping を適用 (重要!)
    for p in critic.parameters():
        torch.nn.utils.clip_grad_value_(p, -clip_value, clip_value) # 重要な制約

    optimizer_critic.step()

    # Generator の学習 (勾配の計算と適用)
    generator_loss.backward()
    optimizer_generator.step()

    print(f"Epoch {epoch+1}: Critic Loss = {critic_loss:.4f}, Generator Loss = {generator_loss:.4f}")
Gradient penalty を使用するコード例
  • lambda_gp: Gradient Penalty の係数。この値を調整することで、Lipschitz 制約への適合度合いを制御する。 大きすぎる値は学習を不安定にする可能性がある
  • 線形補間: real_images と fake_images の間で線形補間を行い、Critic が入力データ空間全体で滑らかであるかを評価する
  • 学習率: Gradient Penalty を使用した場合、従来の WGAN よりも高い学習率を使用する
import torch
import torch.nn as nn
import torch.optim as optim

# 損失関数の定義 (Gradient Penalty を使用)
def wgan_loss_gp(critic, generator, real_images, fake_images, device, lambda_gp=10):
    """
    WGAN と Gradient Penalty の損失を計算する関数。

    Args:
        critic: Critic モデル。
        generator: Generator モデル。
        real_images: 本物の画像データ (Tensor)。
        fake_images: Generator が生成した偽の画像データ (Tensor)。
        device: 計算デバイス (CPU または GPU)。
        lambda_gp: Gradient Penalty の係数。

    Returns:
        tuple: (Critic loss, Generator loss) のタプル。
    """

    # Critic の損失計算
    critic_real_output = critic(real_images)
    critic_fake_output = critic(fake_images.detach())  # detach() で勾配計算を停止させる
    critic_loss = torch.mean(critic_real_output - critic_fake_output)

    # Gradient Penalty の計算
    with torch.no_grad(): # 重要な部分: 勾配の計算をしない
        duals = real_images + (torch.rand_like(real_images, device=device) - 1) * (fake_images - real_images).sum(dim=(1,2,3), keepdim=True) # 線形補間
        critic_dual_output = critic(duals)

    gradient_penalty = torch.mean((critic_dual_output - torch.zeros_like(critic_dual_output)).pow(2)) * lambda_gp
    critic_loss += gradient_penalty


    # Generator の損失計算
    generator_output = critic(fake_images)
    generator_loss = torch.mean(-generator_output)

    return critic_loss, generator_loss



# 簡易的な Critic と Generator の定義 (例)
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256), # 例えば MNIST サイズの場合
            nn.LeakyReLU(0.2),  # ReLU の代わりに Leaky ReLU を使用するのが一般的
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the image
        return self.model(x)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256), # ノイズ次元は例として100
            nn.ReLU(),
            nn.Linear(256, 784)  # MNIST サイズの場合
        )

    def forward(self, z):
        z = z.view(z.size(0), -1) # Flatten the noise vector
        return self.model(z).view(-1, 28, 28) # Reshape to image size



# ハイパーパラメータの設定 (例)
latent_dim = 100
batch_size = 64
learning_rate_critic = 0.0001
learning_rate_generator = 0.0001
lambda_gp = 10  # Gradient Penalty の係数

# モデルとオプティマイザの作成
critic = Critic()
generator = Generator()

optimizer_critic = optim.Adam(critic.parameters(), lr=learning_rate_critic)
optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate_generator)



# ダミーデータ生成 (テスト用)
real_data = torch.randn(batch_size, 1, 28, 28).cuda()  # MNIST サイズのダミーデータ
noise = torch.randn(batch_size, latent_dim).cuda() # GPU を使用する場合

# 計算デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic.to(device)
generator.to(device)


# 損失計算と学習ループの一部
for epoch in range(10):  # 例として10エポック
    optimizer_critic.zero_grad()
    optimizer_generator.zero_grad()

    fake_data = generator(noise)

    critic_loss, generator_loss = wgan_loss_gp(critic, generator, real_data, fake_data, device, lambda_gp)


    # Critic の学習 (勾配の計算と適用)
    critic_loss.backward()
    optimizer_critic.step()

    # Generator の学習 (勾配の計算と適用)
    generator_loss.backward()
    optimizer_generator.step()

    print(f"Epoch {epoch+1}: Critic Loss = {critic_loss:.4f}, Generator Loss = {generator_loss:.4f}")

StyleGAN

現在の画像生成 GAN の主流は StyleGAN で、StyleGAN によって学習の安定化と生成画像の操作が可能になった

StyleGAN2 Analyzing and Improving the Image Quality of StyleGAN

StyleGAN3 Alias-Free Generative Adversarial Networks

Real-ESRGAN

LPIPS(VGG)loss と GAN loss と Content Loss とを損失関数として採用している。

Content Loss

\[ \begin{split} \normalsize L_{1} &= E_{x_{i}}||G(x_{i})-y||_{1} \\ {G(x_{i})} &{: 生成画像} \\ {y} &{: 本物画像} \end{split} \]

pytorch-CycleGAN-and-pix2pix

安定化

Spectral Normalization (スペクトル正規化)

Discriminatorの各層の重みの Lipschitz 定数を制約する(要は Normalize する)ことで学習が安定する。Spectrally Normalized Generative Adversarial Networks (SN-GAN)

Label Smoothing (ラベルスムージング)

Discriminatorのターゲットラベルを0または1のようなバイナリ値ではなく、0.9や0.1のような滑らかな値(小数値)に置き換える。Discriminatorが訓練データに対して過度に自信を持つことを防ぎ、過学習を抑制。これにより、Generatorに与えられる勾配信号がより適切になり、敵対的サンプルの生成リスクを低減する効果も期待できる。What is the intuition behind the Label Smoothing in GANs? - AI Stack Exchange

one_labels = torch.full((BATCH_SIZE,1),fill_value=0.9).to(device)
zero_labels = torch.full((BATCH_SIZE,1),fill_value=0.1).to(device)

One-sided Label Smoothing (片側ラベルスムージング)

従来のラベルスムージングとは異なり、Discriminatorにおいて「本物」の画像に対するラベルのみをスムージングし、「偽物」の画像に対するラベルはそのまま(通常0)にする。Label Smoothing でも Discriminator が強い場合に使う。Label smoothing should be one-sided? · Issue #10 · soumith/ganhacks - GitHubApplication of Smoothing Labels to Alleviate Overconfident of the GAN's Discriminator - IIETA

one_labels = torch.full((BATCH_SIZE,1),fill_value=0.9).to(device)
zero_labels = torch.zeros((BATCH_SIZE, 1)).to(device)

コード例

GAN の構築自体は容易だ。ResNet や VGG の最終層を取り換えて使えば、Discriminator として使える。VAE の学習ならば、Generator は VAE。GAN のシンプルな実装は【Pytorch】MNISTのGAN(敵対的生成ネットワーク)を実装する等を参照。ResNet を Discriminator に使うコードは以下のようになる。

import torch
import torch.nn as nn
import torchvision.models as models

# 訓練済みResNet-18をロード
resnet_model = models.resnet18(pretrained=True)

# 既存の最終層の入力特徴量サイズを取得
# ResNet-18の場合、fc.in_featuresは 512
num_ftrs = resnet_model.fc.in_features

# 新しい最終層を定義
# 本物/偽物の2クラス分類のため、出力は1
resnet_model.fc = nn.Linear(num_ftrs, 1)

# Sigmoidアクティベーションは、通常GANのDiscriminatorの出力では直接適用せず、
# BCEWithLogitsLossなど、ロジット値を直接受け取る損失関数と組み合わせます。
# もし0-1の確率として明示的に出力したい場合は、以下のように追加します。
# resnet_model.fc = nn.Sequential(
#     nn.Linear(num_ftrs, 1),
#     nn.Sigmoid()
# )
訓練コード
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- 1. Discriminator (ResNet) の定義 ---
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # 訓練済みResNet-18をロード
        resnet = models.resnet18(pretrained=True)
        # 最終層をGANの Discriminator 用に変更
        num_ftrs = resnet.fc.in_features
        resnet.fc = nn.Linear(num_ftrs, 1) # 出力は1(ロジット)
        self.model = resnet

    def forward(self, x):
        return self.model(x)

# --- 2. Generator のダミー定義 (簡単のため) ---
# 実際にはもっと複雑なネットワークになります
class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, img_size):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256 * (img_size // 16) ** 2), # 適当な中間層
            nn.LeakyReLU(0.2, inplace=True),
            nn.Unflatten(1, (256, img_size // 16, img_size // 16)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),
            nn.Tanh() # 画像出力のためTanh
        )

    def forward(self, x):
        return self.main(x)


# --- 3. ハイパーパラメータとデバイス設定 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

latent_dim = 100 # Generatorへの入力ノイズの次元
image_size = 64 # 生成/入力画像のサイズ
batch_size = 64
num_epochs = 50
lr = 0.0002

# --- 4. データローダー (例としてCIFAR-10) ---
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# 訓練データセットの準備
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# --- 5. モデル、オプティマイザ、損失関数の初期化 ---
discriminator = Discriminator().to(device)
generator = Generator(latent_dim, 3, image_size).to(device) # CIFAR-10は3チャンネル

# オプティマイザ
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

# 損失関数 (Sigmoidと交差エントロピーを組み合わせたもの)
criterion = nn.BCEWithLogitsLoss()

# 真偽ラベル
real_label = 1.0
fake_label = 0.0

# --- 6. 訓練ループ ---
print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(train_loader, 0):
        # --- Discriminatorの訓練 ---
        discriminator.zero_grad()
        real_images = data[0].to(device)
        b_size = real_images.size(0)

        # 1. 本物画像を Discriminator に入力
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = discriminator(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = torch.sigmoid(output).mean().item() # Sigmoidで0-1に変換して表示

        # 2. 偽物画像を Generator で生成し、Discriminator に入力
        noise = torch.randn(b_size, latent_dim, device=device)
        fake_images = generator(noise)
        label.fill_(fake_label)
        output = discriminator(fake_images.detach()).view(-1) # detach()でGeneratorへの勾配計算を停止
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = torch.sigmoid(output).mean().item()

        errD = errD_real + errD_fake
        optimizer_D.step()

        # --- Generatorの訓練 ---
        generator.zero_grad()
        label.fill_(real_label) # Generatorは偽物画像を本物とDに誤認させたい
        output = discriminator(fake_images).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = torch.sigmoid(output).mean().item()
        optimizer_G.step()

        if i % 100 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(train_loader)}] '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

    # エポックごとにモデルを保存するなどの処理を追加することもできます
    # if (epoch + 1) % 10 == 0:
    #     torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')
    #     torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')

print("Training finished!")

広告
広告

カテゴリ