dskjal
広告
広告

NVIDIA の AMP が遅くなる理由

カテゴリ:deeplearning

この記事の要約

赤で囲った部分が AMP で高速化される<br>出典:TENSOR CORE DL PERFORMANCE GUIDE. Michael Andersch et al. p. 26. https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf

赤で囲った部分が AMP で高速化される
出典:TENSOR CORE DL PERFORMANCE GUIDE. Michael Andersch et al. p. 26. https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf

NVIDIA の GPU の Tensor コアはカタログスペック上は、FP32 比で 10~30 倍の TFLOPS が記載されているが、実際は2倍程度しか速くならない。Tensor コアの実効 TFLOPS は、最適化してもカタログスペックの 30 %が限界だ。以下の要素がその原因だ。

メモリアクセスがボトルネックになっている

対策

AMPが強制的にFP32へ昇格させる演算

LayerNorm/Softmax/Sum/各種 loss 関数/sin, cos, tan, log などは Tensor コアが使われず、追加のカーネル呼び出しが発生するために遅い。

対策

Tensor Core が利用されない形状やメモリレイアウト

対策
畳み込み

カーネル起動回数増加とスケジューリング遅延が重なり、演算密度(Arithmetic Intensity)が理論前提を下回る

profiler で計測して、「kernel 数 >> 層数」の場合はカーネル起動過多。

対策

DataLoader 待ち

DataLoader の pin_memory や num_workers を設定しないと遅い。

プロファイラ

from torch.profiler import profile, ProfilerActivity, record_function

with profile(activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for step, data in enumerate(train_loader):
        with record_funtion('inference'):
            y = model(data)
        loss = criterion(y, data)
        with record_function('backward'):
            loss.backward()
        with record_function('step'):
            optimizer.step()

print(prof.key_averages().table(sort_by = "cuda_time", row_limit = 10))

PyTorch Profiler With TensorBoard

AMP の使い方

torch.cuda.amp ではなく torch.amp を使う。

from torch.amp import autocast, GradScaler

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler('cuda')
# 学習ループ
for data, target in dataloader:
    optimizer.zero_grad()

    with autocast('cuda'):
        # autocastコンテキスト内で計算
        output = model(data)
        loss = nn.functional.cross_entropy(output, target)

    # スケーラーで勾配をスケーリング
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

参考文献

Data Movement Is All You Need: A Case Study on Optimizing Transformers

NVIDIA Deep Learning Performance

NVIDIA Deep Learning Performance Optimizing Performance

Tips for Optimizing GPU Performance Using Tensor Cores

TENSOR CORE DL PERFORMANCE GUIDE

Pytorch Performance Tuning Guide

Tempo: Accelerating Transformer-Based Model Training through Memory Footprint Reduction

Attention Mechanism for Transformer Models with CUDA

Dynamic Stashing Quantization for Efficient Transformer Training

Transformer Engine ではじめる FP8 Training (導入編)

Transformer Engine Performance Optimizations

MIXED PRECISION TRAINING

Fusion

Fusion は複数の演算処理を統合して単一の高効率なカーネルとして実行する。以下の効果がある。

Fusion戦略の分類

種類対象演算効果実装
VerticalLinear→GELU→Addカーネル起動削減torch.compile, TorchScript
Horizontal並列Linear演算メモリ帯域効率向上torch.compile
AttentionQKV投影→Attention→Outputメモリアクセス最適化Flash Attention, Transformer Engine
LayerNormAdd→LayerNorm→Linear中間結果削減Transformer Engine, カスタム実装
EpilogueGEMM→Bias→Activation レジスタ再利用torch.compile, Triton

実装

1. TorchScript + JIT Compilation(torch.compile の使用を推奨)

コードサンプル
import torch

# Fusionが適用される例
@torch.jit.script
def fused_attention_ffn(x, weight_q, weight_k, weight_v, ffn_weight):
    # Attention計算
    q = torch.matmul(x, weight_q)
    k = torch.matmul(x, weight_k)
    v = torch.matmul(x, weight_v)
    
    # Scaled Dot-Product Attention
    scale = 1.0 / (q.size(-1) ** 0.5)
    scores = torch.matmul(q * scale, k.transpose(-2, -1))
    attn_weights = torch.softmax(scores, dim=-1)
    attn_output = torch.matmul(attn_weights, v)
    
    # FFN with GELU fusion
    ffn_output = torch.nn.functional.gelu(torch.matmul(attn_output, ffn_weight))
    return ffn_output

# JITコンパイル
compiled_model = torch.jit.script(fused_attention_ffn)

torch.compile は以下の最適化を自動適用する。

参考文献

Optimizing CUDA Recurrent Neural Networks with TorchScript

Optimizing models using the PyTorch JIT

2. torch.compile(PyTorch 2.0+)

triton が必要になる。Windows の場合は triton-windows からインストール。

よくあるミスは Simple Workarounds を参照。

torch.compile を使う最大の利点はメモリリークを指摘してくれることだ。forward の中で torch.arrange, torch.cat, torch.Tensor([...]), tensor.clone(), torch.randn_like, torch.rand を呼び出したり、キャッシュ変数に self.register_buffer(..., persistent=False) を適用していなかったりすると警告(will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy)がでる。この警告は optimizer.zero_grad(set_to_none=True) にすると再現できる。

Windows にインストール

バージョンを調べる

python -m pip index versions triton-windows

インストール

python -m pip install triton-windows==3.3.1.post19

MSVC も必要になる。詳細は How to use torch.compile on Windows を参照。vcvars64.bat は以下のディレクトリの可能性がある。

C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvars64.bat
サンプルコード
import torch
model = torch.compile(model, mode='max-autotune')  # 最大Fusion適用
動的バッチサイズやシーケンス長に対応
model = torch.compile(
    model, 
    mode='max-autotune',
    dynamic=True,  # 動的形状サポート
    fullgraph=True  # グラフ全体の最適化
)

torch._dynamo.decorators.mark_static_address の使い方は decorators.py を参照

パラメータのマッピング

"L['self'].param_groups[0]['params'][0].grad" のようなメッセージは分かりにくいので、モデルロード後にマッピングを取得する。

param_list = list(model.named_parameters())

for i, param in enumerate(optimizer.param_groups[0]['params']):
    for name, p in param_list:
        if param is p:
            print(f"[{i}] {name}")
cl error
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Compiler: cl is not found.

上記のエラーは環境変数に以下のパスを追加する。追加したらシェルや VS Code を再起動して、「cl」でパスが通っているか確認する。

PATH

C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\bin\Hostx64\x64\

INCLUDE

C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\
C:\Program Files (x86)\Windows Kits\10\Include\10.0.26100.0\ucrt\
C:\Program Files (x86)\Windows Kits\10\Include\10.0.26100.0\shared\

LIB

C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\lib\
C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\lib\x64\store\
C:\Program Files (x86)\Windows Kits\10\Lib\10.0.26100.0\um\x64\
C:\Program Files (x86)\Windows Kits\10\Lib\10.0.26100.0\ucrt\x64\
C:\Users\[ユーザー名]\AppData\Local\Programs\Python\Python[バージョン]\libs\

未解決の外部シンボルは venv/lib/site-packages/torch/_inductor/cpp_builder.py の 787 行目に以下を追加。

libraries.append("kernel32")
libraries.append("runtimeobject")
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x83 in position 135: invalid start byte

シェルのロケールが cp932 になっていることが原因。venv/lib/site-packages/torch/_inductor/cpp_builder.py の 63 行目を以下のように編集する。

import locale
SUBPROCESS_DECODE_ARGS = (locale.getpreferredencoding(),) if _IS_WINDOWS else ()

そのほか
torch.set_float32_matmul_precision('high')
import os
os.environ["TRITON_LOG_LEVEL"] = "0"

os.environ["TRITON_LOG_LEVEL"] = "0" は以下の警告を非表示にする。他の警告も非表示にされる恐れがあるので、デバッグ終了後に設定する。以下の警告は SM 数が 80 以下の GPU で実行していると出る。

W0802 04:00:15.720576 21708 Lib\site-packages\torch\_inductor\utils.py:1137] [1/0] Not enough SMs to use max_autotune_gemm mode
参考文献

Accelerating PyTorch Models: Inside torch.compile’s Kernel Optimization

Accelerated PyTorch inference with torch.compile on AWS Graviton processors

3. Flash Attention統合

torch.nn.functional.scaled_dot_product_attention を使う。

torch.nn.MultiheadAttention も参照。

コードサンプル
import torch
import torch.nn.functional as F

# Flash Attention使用例
def flash_attention_block(q, k, v, causal=True):
    # PyTorch 2.0以降ではScaled Dot Product Attentionが自動的にFlash Attentionを利用
    return F.scaled_dot_product_attention(
        q, k, v, 
        is_causal=causal,
        enable_gqa=False
    )

class OptimizedTransformer(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention_proj = torch.nn.Linear(config.d_model, 3 * config.d_model)
        self.output_proj = torch.nn.Linear(config.d_model, config.d_model)
        
    @torch.compile
    def forward(self, x):
        B, L, D = x.shape
        qkv = self.attention_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        
        # Flash Attentionが自動適用される
        attn_out = flash_attention_block(q, k, v)
        return self.output_proj(attn_out)
参考文献

FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention

4. Transformer Engine(NVIDIA)

コードサンプル
import transformer_engine.pytorch as te

class FusedTransformerLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_attention_heads):
        super().__init__()
        self.attention = te.MultiheadAttention(
            hidden_size, 
            num_attention_heads,
            fuse_qkv_params=True  # QKV投影統合
        )
        self.mlp = te.LayerNormMLP(
            hidden_size,
            4 * hidden_size,
            bias=True,
            activation="gelu",
            normalization="LayerNorm"  # LayerNorm + MLP統合
        )
    
    def forward(self, x):
        # 内部で最適化されたFusionカーネル使用
        attn_out = self.attention(x)
        return self.mlp(attn_out)

# FP8混合精度Fusion適用
with te.fp8_autocast(enabled=True):
    model = FusedTransformerLayer(768, 12)
    output = model(input_tensor)

5. カスタムTritonカーネル

コードサンプル
import triton
import triton.language as tl

@triton.jit
def fused_attention_kernel(
    Q, K, V, Out, 
    stride_qm, stride_qh, stride_qk,
    stride_km, stride_kh, stride_kk,
    stride_vm, stride_vh, stride_vk,
    stride_om, stride_oh, stride_ok,
    M, N, H, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Fused Attention + LayerNorm + FFN カーネル"""
    # プログラムID取得
    pid_m = tl.program_id(0)
    pid_h = tl.program_id(1)
    
    # ブロック範囲計算
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    
    # Query, Key, Value読み込み
    q_ptrs = Q + pid_h * stride_qh + offs_m[:, None] * stride_qm + offs_n[None, :] * stride_qk
    k_ptrs = K + pid_h * stride_kh + offs_n[:, None] * stride_km + offs_n[None, :] * stride_kk
    v_ptrs = V + pid_h * stride_vh + offs_n[:, None] * stride_vm + offs_n[None, :] * stride_vk
    
    q = tl.load(q_ptrs, mask=offs_m[:, None] < M)
    k = tl.load(k_ptrs, mask=offs_n[:, None] < N)
    v = tl.load(v_ptrs, mask=offs_n[:, None] < N)
    
    # Attention計算(QK^T)
    qk = tl.dot(q, k)
    qk *= 1.0 / tl.sqrt(K.to(tl.float32))
    
    # Softmax適用(数値安定性考慮)
    m_i = tl.max(qk, 1)
    qk = qk - m_i[:, None]
    p = tl.exp(qk)
    l_i = tl.sum(p, 1)
    p_norm = p / l_i[:, None]
    
    # AttentionV計算
    out = tl.dot(p_norm.to(v.dtype), v)
    
    # 結果保存
    out_ptrs = Out + pid_h * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_ok
    tl.store(out_ptrs, out, mask=offs_m[:, None] < M)

@torch.compile
def fused_transformer_attention(q, k, v):
    """torch.compileと統合されたカスタムTritonカーネル"""
    B, H, M, K = q.shape
    output = torch.empty_like(q)
    
    grid = (triton.cdiv(M, 64), H, B)
    fused_attention_kernel[grid](
        q, k, v, output,
        q.stride(2), q.stride(1), q.stride(3),
        k.stride(2), k.stride(1), k.stride(3),
        v.stride(2), v.stride(1), v.stride(3),
        output.stride(2), output.stride(1), output.stride(3),
        M, M, H, K,
        BLOCK_M=64, BLOCK_N=64
    )
    return output
参考文献

Using User-Defined Triton Kernels with torch.compile

Trition Fused Softmax

torch.compile 適用時の学習失敗

考えられる原因

1. コンパイルによる数値精度の変化

torch.compile は内部的にグラフ最適化や異なるバックエンド (例: Inductor) を使用するため、浮動小数点演算の順序や表現が微妙に変わる。これにより、特に勾配計算や正規化層 (Batch Normalization, Layer Normalization など) において、数値的な不安定性が生じ、NaN (Not a Number) や無限大 (Infinity) が発生しやすくなる。

特に、小さな値の割り算、指数関数、対数関数などが含まれる場合に顕著。

2. torch.compile のバグまたは非互換性

3. 動的な挙動は適切にコンパイルできない

モデル内に動的な制御フロー(例: if 文、for ループでテンソルの形状が変わるなど)が含まれている場合、torch.compile がその動的な挙動を正しくキャプチャできない、または静的なグラフとしてコンパイルに失敗することがある。

4. オプティマイザのステートの破損

一部のオプティマイザ (例: AdamW の beta パラメータなど) は、モデルの勾配に基づいて内部ステートを更新する。torch.compile が勾配計算の挙動を微妙に変えることで、オプティマイザの内部ステートが適切に更新されず、学習が収束しなくなることがある。

5. データ型ミスマッチまたは自動混合精度 (AMP) との相互作用

AMP と torch.compile は相性が悪い。

解決策

mode パラメータの変更

backend パラメータの変更

torch.autograd.set_detect_anomaly(True)

torch.autograd.set_detect_anomaly(True) を有効にすると、NaN や無限大が発生した場合にスタックトレースを表示し、問題の箇所を特定しやすくなる。

torch.autograd.set_detect_anomaly(True)
compiled_model = torch.compile(model)
# ... 学習ループ

勾配のクリッピング

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 例

学習率を下げる

モデルによっては 1e-7 ぐらいまで下げる必要がある。

正規化層の凍結

特に推論時に問題が発生する場合、BatchNorm層などを model.eval() に設定したままにするか、学習中でも凍結することを検討する。torch.compile は eval() モードの BatchNorm の動作を最適化する際に、学習モードと異なる挙動をすることがある。

torch.compile の適用範囲の絞り込み

特定のブロックに torch.compile を適用して問題個所の絞り込み。

torch.compile のキャッシュのクリア

学習スクリプトの実行前に、キャッシュディレクトリ(通常は ~/.cache/torch/inductor など)を削除してみるか、TORCH_COMPILE_DEBUG=1 環境変数を設定してデバッグモードで実行する。

PyTorch のバージョンアップまたはダウングレード

安定板を使う。

モデルアーキテクチャの見直し

標準的なコンポーネントを使い、独自実装をやめる。


広告
広告

カテゴリ