NVIDIA の AMP が遅くなる理由
この記事の要約
- 理論TFLOPS=実効スループットではない。Transformerはデータ移動支配であり、AMP だけでは高速化限界がある
- Tensor Coreを活かすには「8の倍数」「大行列」「Fusion済み」の三要件を同時に満たす必要がある
- LayerNorm/Softmax など AMP が FP32 に昇格する演算を半精度対応へ置換すると更に 5–15 %のFLOPS向上が見込める
- FlashAttention v2・Transformer Engine・Triton生成カーネルなどを取り入れる

赤で囲った部分が AMP で高速化される
出典:TENSOR CORE DL PERFORMANCE GUIDE. Michael Andersch et al. p. 26. https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf
NVIDIA の GPU の Tensor コアはカタログスペック上は、FP32 比で 10~30 倍の TFLOPS が記載されているが、実際は2倍程度しか速くならない。Tensor コアの実効 TFLOPS は、最適化してもカタログスペックの 30 %が限界だ。以下の要素がその原因だ。
メモリアクセスがボトルネックになっている
- Transformer は行列積以外に大規模なテンソル再配列・リダクションを含み、学習全体ではメモリバウンド
- Multi-Head Attention の小行列問題:Head毎の(Q,K,V)はサイズが小さく、計算よりロードが支配的
- Rooflineモデルで見ると多くの層が I<𝐼_opt 領域に入り、FLOPS ではなく HBM 帯域で頭打ちとなる
- BERT Encoder層でもGPU演算ピークの約30%しか活用できていない
対策
- FlashAttention の block-wise kernel でスレッド再利用&データ再利用率を上げる
AMPが強制的にFP32へ昇格させる演算
LayerNorm/Softmax/Sum/各種 loss 関数/sin, cos, tan, log などは Tensor コアが使われず、追加のカーネル呼び出しが発生するために遅い。
対策
- LayerNorm に 半精度実装(SLaNC等) や tanh を使う
- torch.compile を使う
- Transformer Engine を使う
Tensor Core が利用されない形状やメモリレイアウト
- FP16/BF16 GEMM(General Matrix Multiply=行列乗算)は各次元が8の倍数で Tensor Core が有効
- バッチサイズ、シーケンス長、隠れ次元が8または16を外れるとCU Coreフォールバックで4–8×低速
対策
畳み込み
- 畳み込みでは (B, C, H, W) ではなく (B, H, W, C) を使う(2倍速い)
- CNN の FLOPS の計算方法
- 入力幅や高さではなく、バッチサイズ(N)* 出力高さ(P)* 出力幅(Q)の大きさが性能を左右する
- 畳み込みでは H, W, C を変更するのは難しいのでバッチサイズでハードウェア性能の調整をする
- フィルターサイズも大きい方がパフォーマンスが出やすい
- 出力チャネル数は 64 以上が推奨で、256 以上数値を増やしても FLOPS は横ばい
カーネル起動回数増加とスケジューリング遅延が重なり、演算密度(Arithmetic Intensity)が理論前提を下回る
- 小サイズ GEMM や頻繁なカーネル呼び出しで SM 占有率が低下し、Tensor Core が待機状態になる
- Layer-wise pipeline 非最適・同期ポイント過多でも FLOPS ロスを招く
profiler で計測して、「kernel 数 >> 層数」の場合はカーネル起動過多。
対策
- TorchScript/Fuser
- Tritonカーネル
- PyTorch 2.1 SDPA
- torch.compile を使う
DataLoader 待ち
DataLoader の pin_memory や num_workers を設定しないと遅い。
プロファイラ
from torch.profiler import profile, ProfilerActivity, record_function
with profile(activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
for step, data in enumerate(train_loader):
with record_funtion('inference'):
y = model(data)
loss = criterion(y, data)
with record_function('backward'):
loss.backward()
with record_function('step'):
optimizer.step()
print(prof.key_averages().table(sort_by = "cuda_time", row_limit = 10))
PyTorch Profiler With TensorBoard
AMP の使い方
torch.cuda.amp ではなく torch.amp を使う。
from torch.amp import autocast, GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler('cuda')
# 学習ループ
for data, target in dataloader:
optimizer.zero_grad()
with autocast('cuda'):
# autocastコンテキスト内で計算
output = model(data)
loss = nn.functional.cross_entropy(output, target)
# スケーラーで勾配をスケーリング
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
参考文献
Data Movement Is All You Need: A Case Study on Optimizing Transformers
NVIDIA Deep Learning Performance
NVIDIA Deep Learning Performance Optimizing Performance
- Linear/Fully-Connected Layers User's Guide
- Convolutional Layers User's Guide
- Recurrent Layers User's Guide
- Memory-Limited Layers User's Guide
Tips for Optimizing GPU Performance Using Tensor Cores
TENSOR CORE DL PERFORMANCE GUIDE
Pytorch Performance Tuning Guide
Tempo: Accelerating Transformer-Based Model Training through Memory Footprint Reduction
Attention Mechanism for Transformer Models with CUDA
Dynamic Stashing Quantization for Efficient Transformer Training
Transformer Engine ではじめる FP8 Training (導入編)
Transformer Engine Performance Optimizations
Fusion
Fusion は複数の演算処理を統合して単一の高効率なカーネルとして実行する。以下の効果がある。
- メモリアクセス回数の削減
- カーネル起動オーバーヘッドの削減
- レジスタやSRAMの再利用による帯域幅効率の向上
Fusion戦略の分類
種類 | 対象演算 | 効果 | 実装 |
---|---|---|---|
Vertical | Linear→GELU→Add | カーネル起動削減 | torch.compile, TorchScript |
Horizontal | 並列Linear演算 | メモリ帯域効率向上 | torch.compile |
Attention | QKV投影→Attention→Output | メモリアクセス最適化 | Flash Attention, Transformer Engine |
LayerNorm | Add→LayerNorm→Linear | 中間結果削減 | Transformer Engine, カスタム実装 |
Epilogue | GEMM→Bias→Activation | レジスタ再利用 | torch.compile, Triton |
実装
1. TorchScript + JIT Compilation(torch.compile の使用を推奨)
コードサンプル
import torch
# Fusionが適用される例
@torch.jit.script
def fused_attention_ffn(x, weight_q, weight_k, weight_v, ffn_weight):
# Attention計算
q = torch.matmul(x, weight_q)
k = torch.matmul(x, weight_k)
v = torch.matmul(x, weight_v)
# Scaled Dot-Product Attention
scale = 1.0 / (q.size(-1) ** 0.5)
scores = torch.matmul(q * scale, k.transpose(-2, -1))
attn_weights = torch.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
# FFN with GELU fusion
ffn_output = torch.nn.functional.gelu(torch.matmul(attn_output, ffn_weight))
return ffn_output
# JITコンパイル
compiled_model = torch.jit.script(fused_attention_ffn)
torch.compile は以下の最適化を自動適用する。
- Vertical Fusion: Linear→GELU→Dropoutの連鎖統合
- Horizontal Fusion: 複数の並列演算の統合
- Epilogue Fusion: GEMMの後続演算統合
参考文献
Optimizing CUDA Recurrent Neural Networks with TorchScript
Optimizing models using the PyTorch JIT
2. torch.compile(PyTorch 2.0+)
triton が必要になる。Windows の場合は triton-windows からインストール。
よくあるミスは Simple Workarounds を参照。
torch.compile を使う最大の利点はメモリリークを指摘してくれることだ。forward の中で torch.arrange, torch.cat, torch.Tensor([...]), tensor.clone(), torch.randn_like, torch.rand を呼び出したり、キャッシュ変数に self.register_buffer(..., persistent=False) を適用していなかったりすると警告(will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy)がでる。この警告は optimizer.zero_grad(set_to_none=True) にすると再現できる。
Windows にインストール
バージョンを調べる
python -m pip index versions triton-windows
インストール
python -m pip install triton-windows==3.3.1.post19
MSVC も必要になる。詳細は How to use torch.compile on Windows を参照。vcvars64.bat は以下のディレクトリの可能性がある。
C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvars64.bat
サンプルコード
import torch
model = torch.compile(model, mode='max-autotune') # 最大Fusion適用
動的バッチサイズやシーケンス長に対応
model = torch.compile(
model,
mode='max-autotune',
dynamic=True, # 動的形状サポート
fullgraph=True # グラフ全体の最適化
)
torch._dynamo.decorators.mark_static_address の使い方は decorators.py を参照。
パラメータのマッピング
"L['self'].param_groups[0]['params'][0].grad" のようなメッセージは分かりにくいので、モデルロード後にマッピングを取得する。
param_list = list(model.named_parameters())
for i, param in enumerate(optimizer.param_groups[0]['params']):
for name, p in param_list:
if param is p:
print(f"[{i}] {name}")
cl error
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: RuntimeError: Compiler: cl is not found.
上記のエラーは環境変数に以下のパスを追加する。追加したらシェルや VS Code を再起動して、「cl」でパスが通っているか確認する。
PATH
C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\bin\Hostx64\x64\
INCLUDE
C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\ C:\Program Files (x86)\Windows Kits\10\Include\10.0.26100.0\ucrt\ C:\Program Files (x86)\Windows Kits\10\Include\10.0.26100.0\shared\
LIB
C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\lib\ C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\lib\x64\store\ C:\Program Files (x86)\Windows Kits\10\Lib\10.0.26100.0\um\x64\ C:\Program Files (x86)\Windows Kits\10\Lib\10.0.26100.0\ucrt\x64\ C:\Users\[ユーザー名]\AppData\Local\Programs\Python\Python[バージョン]\libs\
未解決の外部シンボルは venv/lib/site-packages/torch/_inductor/cpp_builder.py の 787 行目に以下を追加。
libraries.append("kernel32") libraries.append("runtimeobject")
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x83 in position 135: invalid start byte
シェルのロケールが cp932 になっていることが原因。venv/lib/site-packages/torch/_inductor/cpp_builder.py の 63 行目を以下のように編集する。
import locale
SUBPROCESS_DECODE_ARGS = (locale.getpreferredencoding(),) if _IS_WINDOWS else ()
そのほか
torch.set_float32_matmul_precision('high')
import os
os.environ["TRITON_LOG_LEVEL"] = "0"
os.environ["TRITON_LOG_LEVEL"] = "0" は以下の警告を非表示にする。他の警告も非表示にされる恐れがあるので、デバッグ終了後に設定する。以下の警告は SM 数が 80 以下の GPU で実行していると出る。
W0802 04:00:15.720576 21708 Lib\site-packages\torch\_inductor\utils.py:1137] [1/0] Not enough SMs to use max_autotune_gemm mode
参考文献
Accelerating PyTorch Models: Inside torch.compile’s Kernel Optimization
Accelerated PyTorch inference with torch.compile on AWS Graviton processors
3. Flash Attention統合
torch.nn.functional.scaled_dot_product_attention を使う。
torch.nn.MultiheadAttention も参照。
コードサンプル
import torch
import torch.nn.functional as F
# Flash Attention使用例
def flash_attention_block(q, k, v, causal=True):
# PyTorch 2.0以降ではScaled Dot Product Attentionが自動的にFlash Attentionを利用
return F.scaled_dot_product_attention(
q, k, v,
is_causal=causal,
enable_gqa=False
)
class OptimizedTransformer(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.attention_proj = torch.nn.Linear(config.d_model, 3 * config.d_model)
self.output_proj = torch.nn.Linear(config.d_model, config.d_model)
@torch.compile
def forward(self, x):
B, L, D = x.shape
qkv = self.attention_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
# Flash Attentionが自動適用される
attn_out = flash_attention_block(q, k, v)
return self.output_proj(attn_out)
参考文献
FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention
4. Transformer Engine(NVIDIA)
コードサンプル
import transformer_engine.pytorch as te
class FusedTransformerLayer(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads):
super().__init__()
self.attention = te.MultiheadAttention(
hidden_size,
num_attention_heads,
fuse_qkv_params=True # QKV投影統合
)
self.mlp = te.LayerNormMLP(
hidden_size,
4 * hidden_size,
bias=True,
activation="gelu",
normalization="LayerNorm" # LayerNorm + MLP統合
)
def forward(self, x):
# 内部で最適化されたFusionカーネル使用
attn_out = self.attention(x)
return self.mlp(attn_out)
# FP8混合精度Fusion適用
with te.fp8_autocast(enabled=True):
model = FusedTransformerLayer(768, 12)
output = model(input_tensor)
5. カスタムTritonカーネル
コードサンプル
import triton
import triton.language as tl
@triton.jit
def fused_attention_kernel(
Q, K, V, Out,
stride_qm, stride_qh, stride_qk,
stride_km, stride_kh, stride_kk,
stride_vm, stride_vh, stride_vk,
stride_om, stride_oh, stride_ok,
M, N, H, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
"""Fused Attention + LayerNorm + FFN カーネル"""
# プログラムID取得
pid_m = tl.program_id(0)
pid_h = tl.program_id(1)
# ブロック範囲計算
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# Query, Key, Value読み込み
q_ptrs = Q + pid_h * stride_qh + offs_m[:, None] * stride_qm + offs_n[None, :] * stride_qk
k_ptrs = K + pid_h * stride_kh + offs_n[:, None] * stride_km + offs_n[None, :] * stride_kk
v_ptrs = V + pid_h * stride_vh + offs_n[:, None] * stride_vm + offs_n[None, :] * stride_vk
q = tl.load(q_ptrs, mask=offs_m[:, None] < M)
k = tl.load(k_ptrs, mask=offs_n[:, None] < N)
v = tl.load(v_ptrs, mask=offs_n[:, None] < N)
# Attention計算(QK^T)
qk = tl.dot(q, k)
qk *= 1.0 / tl.sqrt(K.to(tl.float32))
# Softmax適用(数値安定性考慮)
m_i = tl.max(qk, 1)
qk = qk - m_i[:, None]
p = tl.exp(qk)
l_i = tl.sum(p, 1)
p_norm = p / l_i[:, None]
# AttentionV計算
out = tl.dot(p_norm.to(v.dtype), v)
# 結果保存
out_ptrs = Out + pid_h * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_ok
tl.store(out_ptrs, out, mask=offs_m[:, None] < M)
@torch.compile
def fused_transformer_attention(q, k, v):
"""torch.compileと統合されたカスタムTritonカーネル"""
B, H, M, K = q.shape
output = torch.empty_like(q)
grid = (triton.cdiv(M, 64), H, B)
fused_attention_kernel[grid](
q, k, v, output,
q.stride(2), q.stride(1), q.stride(3),
k.stride(2), k.stride(1), k.stride(3),
v.stride(2), v.stride(1), v.stride(3),
output.stride(2), output.stride(1), output.stride(3),
M, M, H, K,
BLOCK_M=64, BLOCK_N=64
)
return output
参考文献
Using User-Defined Triton Kernels with torch.compile
torch.compile 適用時の学習失敗
考えられる原因
1. コンパイルによる数値精度の変化
torch.compile は内部的にグラフ最適化や異なるバックエンド (例: Inductor) を使用するため、浮動小数点演算の順序や表現が微妙に変わる。これにより、特に勾配計算や正規化層 (Batch Normalization, Layer Normalization など) において、数値的な不安定性が生じ、NaN (Not a Number) や無限大 (Infinity) が発生しやすくなる。
特に、小さな値の割り算、指数関数、対数関数などが含まれる場合に顕著。
2. torch.compile のバグまたは非互換性
3. 動的な挙動は適切にコンパイルできない
モデル内に動的な制御フロー(例: if 文、for ループでテンソルの形状が変わるなど)が含まれている場合、torch.compile がその動的な挙動を正しくキャプチャできない、または静的なグラフとしてコンパイルに失敗することがある。
4. オプティマイザのステートの破損
一部のオプティマイザ (例: AdamW の beta パラメータなど) は、モデルの勾配に基づいて内部ステートを更新する。torch.compile が勾配計算の挙動を微妙に変えることで、オプティマイザの内部ステートが適切に更新されず、学習が収束しなくなることがある。
5. データ型ミスマッチまたは自動混合精度 (AMP) との相互作用
AMP と torch.compile は相性が悪い。
解決策
mode パラメータの変更
- reduce-overhead:最も安全
- default:2番目に安全
- max-autotune:速いがコンパイル時間が長く不安定
backend パラメータの変更
- torch.compile(model, backend="inductor") (デフォルト): 最も一般的で高性能なバックエンド
- torch.compile(model, backend="aot_eager") または backend="eager": コンパイルの恩恵はなくなるが、デバッグが容易になる
torch.autograd.set_detect_anomaly(True)
torch.autograd.set_detect_anomaly(True) を有効にすると、NaN や無限大が発生した場合にスタックトレースを表示し、問題の箇所を特定しやすくなる。
torch.autograd.set_detect_anomaly(True)
compiled_model = torch.compile(model)
# ... 学習ループ
勾配のクリッピング
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 例
学習率を下げる
モデルによっては 1e-7 ぐらいまで下げる必要がある。
正規化層の凍結
特に推論時に問題が発生する場合、BatchNorm層などを model.eval() に設定したままにするか、学習中でも凍結することを検討する。torch.compile は eval() モードの BatchNorm の動作を最適化する際に、学習モードと異なる挙動をすることがある。
torch.compile の適用範囲の絞り込み
特定のブロックに torch.compile を適用して問題個所の絞り込み。
torch.compile のキャッシュのクリア
学習スクリプトの実行前に、キャッシュディレクトリ(通常は ~/.cache/torch/inductor など)を削除してみるか、TORCH_COMPILE_DEBUG=1 環境変数を設定してデバッグモードで実行する。
PyTorch のバージョンアップまたはダウングレード
安定板を使う。
モデルアーキテクチャの見直し
標準的なコンポーネントを使い、独自実装をやめる。