自己写
训练并行的原因#
单个 GPU 计算有上限,通过并行处理多 GPU 协同工作,提升训练效率。
训练并行维度#
数据并行(DP)#
对于 mp_size = 1 ,batch = 64 ,dp_size = 4
相当于把一个 batch 分为 4 块,每一块处理 16 个样本,每一块上都有完整的模型参数。
更新参数,是将 4 块的平均 Loss 求和,然后再进行参数更新(实际还是一个 batch 的 Loss 求平均 * 学习率)
在每算完一个 batch 都需要进行模型参数同步
模型并行(MP)#
对于 mp_size = 4 ,batch = 64 ,dp_size = 1
相当于每一块上需要对完整 batch 的数据进行计算,每一块只有四分之一的模型参数,分别更新每一块的权重,在前向计算和反向计算时也需要通信 ****
MP 中没有全局参数一致性的需求,每个 rank 的参数本来就不同,所以不需要同步更新
MP&DP#
在实际应用过程中 MP 和 DP 是可以一起使用的,但是还要遵守上面的规则。
GPT 写
训练并行总结:DP / MP / DP+MP#
1. 为什么需要训练并行?#
单个 GPU 的显存与计算能力有限,难以承载大型深度学习模型。
通过使用多个 GPU,可以:
-
分摊计算负载
-
扩展可训练模型规模
-
提升训练速度
深度学习训练的并行一般分为两类:
-
数据并行(Data Parallel, DP)
-
模型并行(Model Parallel, MP)
它们也可以组合为混合并行(DP + MP)。
2. 数据并行(DP)#
假设:
-
dp_size = 4
-
mp_size = 1
-
batch = 64
数据并行会将一个 batch 划分为多个子 batch:
-
GPU0:16 个样本
-
GPU1:16 个样本
-
GPU2:16 个样本
-
GPU3:16 个样本
每个 GPU 都持有完整的模型参数。
每个 GPU 独立计算本地样本的梯度:
g0 = ∇θ L0
g1 = ∇θ L1
g2 = ∇θ L2
g3 = ∇θ L3
然后通过 Allreduce (sum) 同步梯度:
g_global = g0 + g1 + g2 + g3
最终会除以 dp_size 得到平均梯度:
g_avg = g_global / 4
每个 GPU 使用相同的梯度更新参数:
θ ← θ - lr * g_avg
所以 DP 的本质是:
-
数据切分
-
梯度同步
-
参数保持一致
3. 模型并行(MP)#
假设:
-
mp_size = 4
-
dp_size = 1
-
batch = 64
模型并行将模型参数按某维度切分到不同 GPU。
例如全连接层的权重:
W = [ W0 | W1 | W2 | W3 ]
每个 GPU 会看到完整 batch 的输入 X(64 × d_in)。
每个 GPU 只执行自己的矩阵乘法:
GPU0: Y0 = X @ W0
GPU1: Y1 = X @ W1
GPU2: Y2 = X @ W2
GPU3: Y3 = X @ W3
完整输出为拼接:
Y = [Y0 | Y1 | Y2 | Y3]
因此前向传播必须通过 Allgather 拼接输出。
反向传播#
每个 GPU 计算自己负责的梯度:
dW_i = X^T @ dY_i
但输入梯度 dX 必须由所有 GPU 的贡献求和:
dX = dX0 + dX1 + dX2 + dX3
因此需要 Allreduce (sum)。
模型并行的核心思想#
MP 不是把模型分块独立训练,而是:
多个 GPU 协同完成一次前向和反向传播,需要同步激活与梯度。
必须用:
-
Allgather(拼接激活)
-
Allreduce(求和输入梯度)
-
ReduceScatter(Megatron-LM 的优化)
MP 的特点#
-
参数切分
-
数据完整(每个 GPU 看到整个 batch)
-
各 GPU 只更新自己的参数分片
-
参数本来就是不同的,无需保持一致
-
但激活 / 梯度必须通信同步
4. 混合并行(DP + MP)#
真实大模型训练中通常同时使用 DP 与 MP。
GPU 集群结构可以看成一个 2D 网格:
-
行:DP 组(处理不同数据)
-
列:MP 组(模型切分)
通信模式:
-
DP 维度:梯度同步(Allreduce)
-
MP 维度:协同计算(Allgather / Allreduce / ReduceScatter)
DP+MP 是 GPT-3、PaLM、Megatron-LM 等大模型训练的核心。
5. 总结#
数据并行(DP)#
-
数据切分
-
模型复制
-
梯度同步
-
参数一致
模型并行(MP)#
-
模型切分
-
数据完整
-
激活 / 梯度通信
-
参数分片独立更新
DP + MP#
-
二维并行模式
-
大模型训练的基础