GAN の学習
画像生成より真贋判定の方がタスクが簡単なので、Discriminator のネットワークはシンプルで浅いものを使うのが普通。Discriminator が強すぎると Generator の学習が進まなくなる。
Patch-based adversal loss を利用する Image-to-Image Translation with Conditional Adversarial Networks が基礎的な論文で、畳み込みを使う DCGAN でネットワークを作る。
tips
- GAN はバッチサイズを大きくしてはいけない。勾配が平均化されて生成される画像の多様性がなくなる。大きくても 32 程度にしておくのがよい。
- 学習率は 2e-4 以下。Generator より Discriminator の学習率を低くすることもある
- DCGAN は Adam の β1 を 0.5 にすることを推奨している
- Generator の入力に dropout や noise を追加する
- 潜在ベクトル z に情報を持たせる(例:カテゴリや style)ことでモード崩壊しにくくなる
- 定期的にランダムな潜在ベクトルで多様性を可視化して監視する
モード崩壊(Mode Collapse)
モード崩壊(Mode Collapse)とは Generator が同じ画像ばかり出力すること。以下の対策方法がある。
学習戦略・最適化の工夫
- Minibatch Discrimination。複数のサンプル間の類似度を導入して、生成器が同じようなサンプルを出すことを抑制
- One-sided label smoothing。本物のラベル(1.0)を 0.9 にするなどして、識別器を強くしすぎない
- ノイズの追加(Input noise / instance noise)。入力画像にランダムノイズを加えて、識別器の判別を難しくしすぎない
- 識別器を強くしすぎない。学習率の調整、早めの停止など。Discriminator が強すぎると Generator が collapse しやすい。
ロス関数の変更
- WGAN。JS divergence の代わりに Earth Mover 距離を使うことで、学習が安定
- WGAN-GP。WGAN の改良。勾配クリッピングの代わりに勾配ノルムのペナルティ。
- LSGAN。二乗誤差を使うことで、識別器が飽和しにくくなる
- Relativistic GAN。「偽物が本物よりもリアルかどうか」を評価する相対的な損失関数を用いる
モデル構造の工夫
- モード追跡型の Generator(e.g. VAE-GAN, InfoGAN)。
- InfoGAN: 潜在変数に意味を持たせることで多様性を向上
- VAE-GAN: 潜在空間の連続性・多様性を保証
- 複数の Generator を使う(e.g. MGAN, MAD-GAN)。Generator を複数用意し、それぞれ異なるモードを学習させる
- Progressive Growing(PGGAN)。低解像度から徐々に高解像度にしていくことで、学習の安定性向上
正則化や多様性の強化
- Mode-seeking Regularization(Mode-seeking Loss)。潜在変数の違いに応じて出力が大きく変わるようにする正則化項
- Feature Matching。Discriminator の中間特徴を一致させるような loss を導入し、モード崩壊を抑制
- Diversity Sensitive Loss。潜在ベクトルの距離と出力の違いを関連付けるような損失項を導入
そのほか
- Unrolled GAN。Discriminator のパラメータ更新を数ステップ先まで「展開(unroll)」して Generator を学習
- PacGAN。識別器に複数の生成画像(例:2個ずつ)を同時に入力することで、多様性がないことを見抜けるようにする
- Mini-batch Standard Deviation Layer(StyleGANなどで使用)。Generator に統計量(標準偏差)を追加して多様性を強化
- スペクトル正規化
- Mixture of GANs(Discriminator や Generator を複数使う)
- Diversity Promoting Regularization
導入順
- WGAN-GP ・ LSGAN
- Minibatch Discrimination ・ Feature Matching
- InfoGAN ・ Mode-seeking loss
- Unrolled GAN ・ PacGAN
LSGAN
JS Divergenceの最小化に基づくGANでは、勾配消失が起きやすい。LSGAN はGANの損失関数を最小二乗誤差にすることで、この問題の解決する。WGAN より実装が単純なので、まずは LSGAN から試すのがよい。
LSGAN(Least Square Generative Adversarial Networks)を試してみた
Wasserstein GAN
Wasserstein 距離に基づいて loss を計算する。モード崩壊が軽減され、学習の安定化と品質の改善が期待できる。通常 Discriminator は入力が本物である確率を計算するが、wgan では Discriminator(wgan では Critic と呼ぶ)は Wasserstein 距離の近似値を計算する。Critic は入力された画像がどれだけ現実離れしているかを数値化する。
詳細な解説は今さら聞けないGAN(4) WGAN を参照。通常は Grgdient penalty を採用する。
Improved Training of Wasserstein GANs
Weight Clip を使用する古いコード例
import torch
import torch.nn as nn
import torch.optim as optim
# 損失関数の定義 (Weight Clipping を使用)
def wgan_loss(critic, generator, real_images, fake_images, clip_value):
"""
WGAN の Critic と Generator の損失を計算する関数。
Args:
critic: Critic モデル。
generator: Generator モデル。
real_images: 本物の画像データ (Tensor)。
fake_images: Generator が生成した偽の画像データ (Tensor)。
clip_value: Weight Clipping に使用する値。
Returns:
tuple: (Critic loss, Generator loss) のタプル。
"""
# Critic の損失計算
critic_real_output = critic(real_images)
critic_fake_output = critic(fake_images.detach()) # detach() で勾配計算を停止させる
critic_loss = torch.mean(critic_real_output - critic_fake_output) # Wasserstein distance の近似値を最大化
# Generator の損失計算
generator_output = critic(fake_images)
generator_loss = torch.mean(-generator_output) # Critic の評価を下げる (Wasserstein distance を最小化)
return critic_loss, generator_loss
# 簡易的な Critic と Generator の定義 (例)
class Critic(nn.Module):
def __init__(self):
super(Critic, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 256), # 例えば MNIST サイズの場合
nn.ReLU(),
nn.Linear(256, 1)
)
def forward(self, x):
x = x.view(x.size(0), -1) # Flatten the image
return self.model(x)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256), # ノイズ次元は例として100
nn.ReLU(),
nn.Linear(256, 784) # MNIST サイズの場合
)
def forward(self, z):
z = z.view(z.size(0), -1) # Flatten the noise vector
return self.model(z).view(-1, 28, 28) # Reshape to image size
# ハイパーパラメータの設定 (例)
latent_dim = 100
batch_size = 64
learning_rate_critic = 0.0001
learning_rate_generator = 0.0001
clip_value = 0.01 # Weight Clipping の値
# モデルとオプティマイザの作成
critic = Critic()
generator = Generator()
optimizer_critic = optim.Adam(critic.parameters(), lr=learning_rate_critic)
optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate_generator)
# ダミーデータ生成 (テスト用)
real_data = torch.randn(batch_size, 1, 28, 28) # MNIST サイズのダミーデータ
noise = torch.randn(batch_size, latent_dim)
# 損失計算と学習ループの一部
for epoch in range(10): # 例として10エポック
optimizer_critic.zero_grad()
optimizer_generator.zero_grad()
fake_data = generator(noise)
critic_loss, generator_loss = wgan_loss(critic, generator, real_data, fake_data, clip_value)
# Critic の学習 (勾配の計算と適用)
critic_loss.backward()
# Weight Clipping を適用 (重要!)
for p in critic.parameters():
torch.nn.utils.clip_grad_value_(p, -clip_value, clip_value) # 重要な制約
optimizer_critic.step()
# Generator の学習 (勾配の計算と適用)
generator_loss.backward()
optimizer_generator.step()
print(f"Epoch {epoch+1}: Critic Loss = {critic_loss:.4f}, Generator Loss = {generator_loss:.4f}")
Gradient penalty を使用するコード例
- lambda_gp: Gradient Penalty の係数。この値を調整することで、Lipschitz 制約への適合度合いを制御する。 大きすぎる値は学習を不安定にする可能性がある
- 線形補間: real_images と fake_images の間で線形補間を行い、Critic が入力データ空間全体で滑らかであるかを評価する
- 学習率: Gradient Penalty を使用した場合、従来の WGAN よりも高い学習率を使用する
import torch
import torch.nn as nn
import torch.optim as optim
# 損失関数の定義 (Gradient Penalty を使用)
def wgan_loss_gp(critic, generator, real_images, fake_images, device, lambda_gp=10):
"""
WGAN と Gradient Penalty の損失を計算する関数。
Args:
critic: Critic モデル。
generator: Generator モデル。
real_images: 本物の画像データ (Tensor)。
fake_images: Generator が生成した偽の画像データ (Tensor)。
device: 計算デバイス (CPU または GPU)。
lambda_gp: Gradient Penalty の係数。
Returns:
tuple: (Critic loss, Generator loss) のタプル。
"""
# Critic の損失計算
critic_real_output = critic(real_images)
critic_fake_output = critic(fake_images.detach()) # detach() で勾配計算を停止させる
critic_loss = torch.mean(critic_real_output - critic_fake_output)
# Gradient Penalty の計算
with torch.no_grad(): # 重要な部分: 勾配の計算をしない
duals = real_images + (torch.rand_like(real_images, device=device) - 1) * (fake_images - real_images).sum(dim=(1,2,3), keepdim=True) # 線形補間
critic_dual_output = critic(duals)
gradient_penalty = torch.mean((critic_dual_output - torch.zeros_like(critic_dual_output)).pow(2)) * lambda_gp
critic_loss += gradient_penalty
# Generator の損失計算
generator_output = critic(fake_images)
generator_loss = torch.mean(-generator_output)
return critic_loss, generator_loss
# 簡易的な Critic と Generator の定義 (例)
class Critic(nn.Module):
def __init__(self):
super(Critic, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 256), # 例えば MNIST サイズの場合
nn.LeakyReLU(0.2), # ReLU の代わりに Leaky ReLU を使用するのが一般的
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
def forward(self, x):
x = x.view(x.size(0), -1) # Flatten the image
return self.model(x)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(100, 256), # ノイズ次元は例として100
nn.ReLU(),
nn.Linear(256, 784) # MNIST サイズの場合
)
def forward(self, z):
z = z.view(z.size(0), -1) # Flatten the noise vector
return self.model(z).view(-1, 28, 28) # Reshape to image size
# ハイパーパラメータの設定 (例)
latent_dim = 100
batch_size = 64
learning_rate_critic = 0.0001
learning_rate_generator = 0.0001
lambda_gp = 10 # Gradient Penalty の係数
# モデルとオプティマイザの作成
critic = Critic()
generator = Generator()
optimizer_critic = optim.Adam(critic.parameters(), lr=learning_rate_critic)
optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate_generator)
# ダミーデータ生成 (テスト用)
real_data = torch.randn(batch_size, 1, 28, 28).cuda() # MNIST サイズのダミーデータ
noise = torch.randn(batch_size, latent_dim).cuda() # GPU を使用する場合
# 計算デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic.to(device)
generator.to(device)
# 損失計算と学習ループの一部
for epoch in range(10): # 例として10エポック
optimizer_critic.zero_grad()
optimizer_generator.zero_grad()
fake_data = generator(noise)
critic_loss, generator_loss = wgan_loss_gp(critic, generator, real_data, fake_data, device, lambda_gp)
# Critic の学習 (勾配の計算と適用)
critic_loss.backward()
optimizer_critic.step()
# Generator の学習 (勾配の計算と適用)
generator_loss.backward()
optimizer_generator.step()
print(f"Epoch {epoch+1}: Critic Loss = {critic_loss:.4f}, Generator Loss = {generator_loss:.4f}")
StyleGAN
現在の画像生成 GAN の主流は StyleGAN で、StyleGAN によって学習の安定化と生成画像の操作が可能になった。
StyleGAN2 Analyzing and Improving the Image Quality of StyleGAN
StyleGAN3 Alias-Free Generative Adversarial Networks
Real-ESRGAN
LPIPS(VGG)loss と GAN loss と Content Loss とを損失関数として採用している。
Content Loss
\[ \begin{split} \normalsize L_{1} &= E_{x_{i}}||G(x_{i})-y||_{1} \\ {G(x_{i})} &{: 生成画像} \\ {y} &{: 本物画像} \end{split} \]pytorch-CycleGAN-and-pix2pix
安定化
Spectral Normalization (スペクトル正規化)
Discriminatorの各層の重みの Lipschitz 定数を制約する(要は Normalize する)ことで学習が安定する。Spectrally Normalized Generative Adversarial Networks (SN-GAN)。
Label Smoothing (ラベルスムージング)
Discriminatorのターゲットラベルを0または1のようなバイナリ値ではなく、0.9や0.1のような滑らかな値(小数値)に置き換える。Discriminatorが訓練データに対して過度に自信を持つことを防ぎ、過学習を抑制。これにより、Generatorに与えられる勾配信号がより適切になり、敵対的サンプルの生成リスクを低減する効果も期待できる。What is the intuition behind the Label Smoothing in GANs? - AI Stack Exchange
one_labels = torch.full((BATCH_SIZE,1),fill_value=0.9).to(device)
zero_labels = torch.full((BATCH_SIZE,1),fill_value=0.1).to(device)
One-sided Label Smoothing (片側ラベルスムージング)
従来のラベルスムージングとは異なり、Discriminatorにおいて「本物」の画像に対するラベルのみをスムージングし、「偽物」の画像に対するラベルはそのまま(通常0)にする。Label Smoothing でも Discriminator が強い場合に使う。Label smoothing should be one-sided? · Issue #10 · soumith/ganhacks - GitHub。Application of Smoothing Labels to Alleviate Overconfident of the GAN's Discriminator - IIETA
one_labels = torch.full((BATCH_SIZE,1),fill_value=0.9).to(device)
zero_labels = torch.zeros((BATCH_SIZE, 1)).to(device)
コード例
GAN の構築自体は容易だ。ResNet や VGG の最終層を取り換えて使えば、Discriminator として使える。VAE の学習ならば、Generator は VAE。GAN のシンプルな実装は【Pytorch】MNISTのGAN(敵対的生成ネットワーク)を実装する等を参照。ResNet を Discriminator に使うコードは以下のようになる。
import torch
import torch.nn as nn
import torchvision.models as models
# 訓練済みResNet-18をロード
resnet_model = models.resnet18(pretrained=True)
# 既存の最終層の入力特徴量サイズを取得
# ResNet-18の場合、fc.in_featuresは 512
num_ftrs = resnet_model.fc.in_features
# 新しい最終層を定義
# 本物/偽物の2クラス分類のため、出力は1
resnet_model.fc = nn.Linear(num_ftrs, 1)
# Sigmoidアクティベーションは、通常GANのDiscriminatorの出力では直接適用せず、
# BCEWithLogitsLossなど、ロジット値を直接受け取る損失関数と組み合わせます。
# もし0-1の確率として明示的に出力したい場合は、以下のように追加します。
# resnet_model.fc = nn.Sequential(
# nn.Linear(num_ftrs, 1),
# nn.Sigmoid()
# )
訓練コード
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# --- 1. Discriminator (ResNet) の定義 ---
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# 訓練済みResNet-18をロード
resnet = models.resnet18(pretrained=True)
# 最終層をGANの Discriminator 用に変更
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 1) # 出力は1(ロジット)
self.model = resnet
def forward(self, x):
return self.model(x)
# --- 2. Generator のダミー定義 (簡単のため) ---
# 実際にはもっと複雑なネットワークになります
class Generator(nn.Module):
def __init__(self, latent_dim, img_channels, img_size):
super(Generator, self).__init__()
self.img_size = img_size
self.main = nn.Sequential(
nn.Linear(latent_dim, 256 * (img_size // 16) ** 2), # 適当な中間層
nn.LeakyReLU(0.2, inplace=True),
nn.Unflatten(1, (256, img_size // 16, img_size // 16)),
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.ConvTranspose2d(64, img_channels, 4, 2, 1),
nn.Tanh() # 画像出力のためTanh
)
def forward(self, x):
return self.main(x)
# --- 3. ハイパーパラメータとデバイス設定 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
latent_dim = 100 # Generatorへの入力ノイズの次元
image_size = 64 # 生成/入力画像のサイズ
batch_size = 64
num_epochs = 50
lr = 0.0002
# --- 4. データローダー (例としてCIFAR-10) ---
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# 訓練データセットの準備
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# --- 5. モデル、オプティマイザ、損失関数の初期化 ---
discriminator = Discriminator().to(device)
generator = Generator(latent_dim, 3, image_size).to(device) # CIFAR-10は3チャンネル
# オプティマイザ
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
# 損失関数 (Sigmoidと交差エントロピーを組み合わせたもの)
criterion = nn.BCEWithLogitsLoss()
# 真偽ラベル
real_label = 1.0
fake_label = 0.0
# --- 6. 訓練ループ ---
print("Starting Training Loop...")
for epoch in range(num_epochs):
for i, data in enumerate(train_loader, 0):
# --- Discriminatorの訓練 ---
discriminator.zero_grad()
real_images = data[0].to(device)
b_size = real_images.size(0)
# 1. 本物画像を Discriminator に入力
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
output = discriminator(real_images).view(-1)
errD_real = criterion(output, label)
errD_real.backward()
D_x = torch.sigmoid(output).mean().item() # Sigmoidで0-1に変換して表示
# 2. 偽物画像を Generator で生成し、Discriminator に入力
noise = torch.randn(b_size, latent_dim, device=device)
fake_images = generator(noise)
label.fill_(fake_label)
output = discriminator(fake_images.detach()).view(-1) # detach()でGeneratorへの勾配計算を停止
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = torch.sigmoid(output).mean().item()
errD = errD_real + errD_fake
optimizer_D.step()
# --- Generatorの訓練 ---
generator.zero_grad()
label.fill_(real_label) # Generatorは偽物画像を本物とDに誤認させたい
output = discriminator(fake_images).view(-1)
errG = criterion(output, label)
errG.backward()
D_G_z2 = torch.sigmoid(output).mean().item()
optimizer_G.step()
if i % 100 == 0:
print(f'[{epoch}/{num_epochs}][{i}/{len(train_loader)}] '
f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
# エポックごとにモデルを保存するなどの処理を追加することもできます
# if (epoch + 1) % 10 == 0:
# torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')
# torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')
print("Training finished!")