Megatron-LM源码系列(五): FP16使用
1. FP16参数指定
- 训练模型要使用fp16时,训练启动参数中指定
--fp16
, 对应megatron/arguments.py
中的定义如下:
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
- 在计算
lm-cross-entropy
时默认是使用fp32来计算的,在开启--fp16
选项的前提下可以通过指定--fp16-lm-cross-entropy
来使用fp16计算lm-loss-entropy
,对应megatron/arguments.py
中的定义如下:
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
- 在megatron中跟fp16还有关系的一个参数是
args.fp32_residual_connection
,这里设置了的话会在计算残差连接的时候转为fp32再进行计算,这里残差连接在网络中对应是Embedding模块。
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
validate_args
函数用于check参数有效性,fp16相关实现如下:
def validate_args(args, defaults={}):
......
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
......
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
......
如果指定了fp16,这里的args.fp16
为True,对应的args.params_dtype
参数类型为torch.half
。
2. ParallelAttention模块中fp16计算
2.1 训练部分
ParallelAttention中有self.query_key_value
、self.core_attention
和self.dense
等子模块,fp16对训练的影响会应用在子模块中。
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
...
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
...
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
...
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
bias=args.add_bias_linear,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
**_args_to_kwargs())
对于self.query_key_value
和self.dense
模块,fp16的设置能过参数中的**_args_to_kwargs()
进行传递。
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": args.perform_initialization,
"gradient_accumulation_fusion": args.gradient_accumulation_fusion,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
对于self.core_attention
部分,fp16的设置是在CoreAttention
的__init__
中self.fp16 = args.fp16
。
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
...
2.2 推理部分
在ParallelAttention
模块本身中fp16会影响推理部分
class ParallelAttention(MegatronModule):
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
...
self.params_dtype = args.params_dtype
...
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None,
rotary_pos_emb=None):
...
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
...
- 当指定了fp16以后,在
ParallelAttention
模型__init__
初始化时会设置参数类型self.params_dtype
为fp16 - 在提前分配memory时
_allocate_memory
中会用torch.empty
创建用于推理的大buffer,类型是fp16 - 在指定推理参数
inference_params
时,forward函数中会调用_allocate_memory
3. CoreAttention模块中fp16计算
当设了fp16以后,在CoreAttention
的forward计算的input就是fp16类型,在init中设置fp16 flag主要是用于计算中用到的FusedScaleMaskSoftmax
模块的输出结果类型转换。
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
...
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
...
当FusedScaleMaskSoftmax
执行时,kernel支持fp16时会直接调用fusion算子forward_fused_softmax
;对于不支持的规模时,会调用forward_torch_softmax
进行模拟,输出的类型就根据self.input_in_float16
来进行cast转换。
class FusedScaleMaskSoftmax(nn.Module):
...
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
4. ColumnParallelLinear模块中fp16计算
在ColumnParallelLinear
初始化时创建Parameter中的类型直接按params_dtype(即fp16)
来设。
class ColumnParallelLinear(torch.nn.Module):
def __init__(self, ...,
params_dtype=torch.float32,
...,
):
...
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
...
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
...
5. lm-cross-entropy计算
以gpt2模型为例,在megatron/model/gpt_model.py
文件中的post_language_model_processing
函数, 如果指定了fp16_lm_cross_entropy
,那么在计算cross entropy
时会把output
先转为float32
再进行计算loss。
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
参考
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!