自动销售设备 发表于 2025-2-7 14:06:00

聊聊GRPO算法——从Open R1来看如何训练DeepSeek R1模型

概述

首发自个人公众号:阿郎小哥的随笔驿站
DeepSeek R1系列建议阅读之前的系列文章:
聊聊DeepSeek R1的一些总结
聊聊DeepSeek R1的开源复现库——Open R1之合成数据
聊聊DeepSeek R1的知识蒸馏与应用思考
简介

GRPO 是一种在线学习算法,这意味着它通过在训练期间使用受训模型自身生成的数据来迭代改进。GRPO 目标背后的直觉是最大化生成补全的优势,同时确保模型保持接近参考策略。
GRPO 的四个主要步骤:生成补全、计算优势、估计 KL 散度和计算损失。

与传统的RL方法不同,后者通常依赖外部评估者(批评者)来引导学习,GRPO通过评估一组响应之间的相对关系来优化模型。这种方法提高了训练效率,使GRPO在需要复杂问题解决和长链思维的推理任务中表现尤为出色。
步骤分解

步骤1:选择查询
• 从训练数据集$ P(Q) $中选择一个查询$ (q) $。
• 示例:假设查询是“8 + 5的和是多少?”
步骤2:生成一组响应
• 模型针对该查询生成一组$ G $个响应。
• 示例:模型生成以下响应:
• o1:“答案是13。”
• o2:“十三。”
• o3:“是12。”
• o4:“和是13。”
步骤3:计算每个响应的奖励
• 什么是奖励?奖励通过量化响应的质量来引导模型的学习。
• GRPO中的奖励类型:
• 准确性奖励:基于响应的正确性(例如,解答数学题)。
• 格式奖励:确保响应符合结构化要求(例如,推理过程需要包含在标签中)。
• 语言一致性奖励:惩罚语言混杂或格式不一致的响应。
• 根据每个响应的好坏,赋予一个奖励($ r_i $)。
例如,奖励可能取决于:
• 准确性:答案是否正确?
• 格式:响应是否结构良好?
示例:
• r1 = 1.0(正确且格式良好)
• r2 = 0.9(正确但较不正式)
• r3 = 0.0(错误答案)
• r4 = 1.0(正确且格式良好)
步骤4:比较响应(群体优势)
• 计算每个响应相对于群体的优势$ (A_i) $,paper中相关术语如下:

用简单的方式理解,就是这样:

• 比较结果优于群体平均水平的响应会获得正分,而表现较差的响应会得到负分。
• 这种方式在群体内部激发竞争,推动模型生成更好的响应。
步骤5:使用裁剪更新策略

示例:如果新策略开始给o1分配过高的概率,裁剪机制确保不会过度强调这个响应。
这种方式保证了即使在像推理这样复杂的任务中,策略优化也能保持稳定和可靠。
步骤6:通过KL散度惩罚偏差

GRPO实现

Open R1

在Open R1的复现路径中

实现了基于GRPO算法的训练,脚本如下
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=7 src/open_r1/grpo.py --config recipes/qwen/Qwen2.5-1.5B-Instruct/grpo/confg_full.yamlconfg_full.yaml
# 基座模型model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5Bmodel_revision: maintorch_dtype: bfloat16# 训练数据集dataset_name: AI-MO/NuminaMath-TIRdataset_configs:- all# Num processes is less by 1 as vLLM is using 1 GPUnum_processes: 7# GRPO训练器参数bf16: trueuse_vllm: truevllm_device: autovllm_gpu_memory_utilization: 0.7do_eval: trueeval_strategy: stepseval_steps: 100gradient_accumulation_steps: 16gradient_checkpointing: truegradient_checkpointing_kwargs:use_reentrant: falsehub_model_id: Qwen2.5-1.5B-Open-R1-GRPOhub_strategy: every_savelearning_rate: 2.0e-05log_level: infologging_steps: 10logging_strategy: stepslr_scheduler_type: cosinemax_prompt_length: 512max_completion_length: 1024max_steps: -1num_train_epochs: 1output_dir: data/Qwen2.5-1.5B-Open-R1-GRPOoverwrite_output_dir: trueper_device_eval_batch_size: 4   per_device_train_batch_size: 1push_to_hub: truereport_to:- wandbsave_strategy: "no"seed: 42warmup_ratio: 0.1Open R1提供了grpo算法的实现——grpo.py,删减了部分无关代码,关键的程序逻辑如下:
@dataclassclass GRPOScriptArguments(ScriptArguments):    reward_funcs: list = field(      default_factory=lambda: ["accuracy", "format"],      metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},    )def accuracy_reward(completions, solution, **kwargs):    """Reward function that checks if the completion is the same as the ground truth."""    contents = ["content"] for completion in completions]    rewards = []    for content, sol in zip(contents, solution):      gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=)      if len(gold_parsed) != 0:            # We require the answer to be provided in correct latex (no malformed operators)            answer_parsed = parse(                content,                extraction_config=[                  LatexExtractionConfig(                        normalization_config=NormalizationConfig(                            nits=False,                            malformed_operators=False,                            basic_latex=True,                            equations=True,                            boxed=True,                            units=True,                        ),                        # Ensures that boxed is tried first                        boxed_match_priority=0,                        try_extract_without_anchor=False,                  )                ],                extraction_mode="first_match",            )            # Reward 1 if the content is the same as the ground truth, 0 otherwise            reward = float(verify(answer_parsed, gold_parsed))      else:            # If the gold solution is not parseable, we reward 1 to skip this example            reward = 1.0            print("Failed to parse gold solution: ", sol)      rewards.append(reward)    return rewardsdef format_reward(completions, **kwargs):    """Reward function that checks if the completion has a specific format."""    pattern = r"^<think>.*?</think><answer>.*?</answer>$"    completion_contents = ["content"] for completion in completions]    matches =     return reward_funcs_registry = {    "accuracy": accuracy_reward,    "format": format_reward,}SYSTEM_PROMPT = (    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "    "<think> reasoning process here </think><answer> answer here </answer>")def main(script_args, training_args, model_args):       # Load the dataset    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)    # Get reward functions    reward_funcs = for func in script_args.reward_funcs]    # Format into conversation    def make_conversation(example):      return {            "prompt": [                {"role": "system", "content": SYSTEM_PROMPT},                {"role": "user", "content": example["problem"]},            ],      }    dataset = dataset.map(make_conversation)    for split in dataset:      if "messages" in dataset.column_names:            dataset = dataset.remove_columns("messages")    logger.info("*** Initializing model kwargs ***")    torch_dtype = (      model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)    )    model_kwargs = dict(      revision=model_args.model_revision,      trust_remote_code=model_args.trust_remote_code,      attn_implementation=model_args.attn_implementation,      torch_dtype=torch_dtype,      use_cache=False if training_args.gradient_checkpointing else True,    )    training_args.model_init_kwargs = model_kwargs    #############################    # Initialize the GRPO trainer    #############################    trainer = GRPOTrainer(      model=model_args.model_name_or_path,      reward_funcs=reward_funcs,      args=training_args,      train_dataset=dataset,      eval_dataset=dataset if training_args.eval_strategy != "no" else None,      peft_config=get_peft_config(model_args),      callbacks=get_callbacks(training_args, model_args),    )    ###############    # Training loop    ###############    logger.info("*** Train ***")    checkpoint = None    if training_args.resume_from_checkpoint is not None:      checkpoint = training_args.resume_from_checkpoint    elif last_checkpoint is not None:      checkpoint = last_checkpoint    train_result = trainer.train(resume_from_checkpoint=checkpoint)    metrics = train_result.metrics    metrics["train_samples"] = len(dataset)    trainer.log_metrics("train", metrics)    trainer.save_metrics("train", metrics)    trainer.save_state()    ##################################    # Save model and create model card    ##################################    trainer.save_model(training_args.output_dir)    # Save everything else on main process    kwargs = {      "dataset_name": script_args.dataset_name,      "tags": ["open-r1"],    }    if trainer.accelerator.is_main_process:      trainer.create_model_card(**kwargs)      # Restore k,v cache for fast inference      trainer.model.config.use_cache = True      trainer.model.config.save_pretrained(training_args.output_dir)    ##########    # Evaluate    ##########    if training_args.do_eval:      logger.info("*** Evaluate ***")      metrics = trainer.evaluate()      metrics["eval_samples"] = len(dataset)      trainer.log_metrics("eval", metrics)      trainer.save_metrics("eval", metrics)    #############    # push to hub    #############    if training_args.push_to_hub:      logger.info("Pushing to hub...")      trainer.push_to_hub(**kwargs)if __name__ == "__main__":    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))    script_args, training_args, model_args = parser.parse_args_and_config()    main(script_args, training_args, model_args)代码分析如下:
首先就是加载数据集,但数据集在加载时,会有指定的提示词,即代码中的make_conversation函数,该函数构造指定的prompt引导模型的输出,格式如下:
{    "prompt": [      {"role": "system", "content": SYSTEM_PROMPT},      {"role": "user", "content": example["problem"]},    ],}对于SYSTEM_PROMPT,描述如下:
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "    "<think> reasoning process here </think><answer> answer here </answer>"总的来说就是,引导模型先思考推理过程,再按格式将推理过程与回复放入指定标签<think>、<answer>内。
接下来是reward函数,grpo算法有两种奖励:准确性奖励与格式正确奖励;如下
def accuracy_reward(completions, solution, **kwargs):    """Reward function that checks if the completion is the same as the ground truth."""    contents = ["content"] for completion in completions]    rewards = []    for content, sol in zip(contents, solution):      gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=)      if len(gold_parsed) != 0:            # We require the answer to be provided in correct latex (no malformed operators)            answer_parsed = parse(                content,                extraction_config=[                  LatexExtractionConfig(                        normalization_config=NormalizationConfig(                            nits=False,                            malformed_operators=False,                            basic_latex=True,                            equations=True,                            boxed=True,                            units=True,                        ),                        # Ensures that boxed is tried first                        boxed_match_priority=0,                        try_extract_without_anchor=False,                  )                ],                extraction_mode="first_match",            )            # Reward 1 if the content is the same as the ground truth, 0 otherwise            reward = float(verify(answer_parsed, gold_parsed))      else:            # If the gold solution is not parseable, we reward 1 to skip this example            reward = 1.0            print("Failed to parse gold solution: ", sol)      rewards.append(reward)    return rewardsdef format_reward(completions, **kwargs):    """Reward function that checks if the completion has a specific format."""    pattern = r"^<think>.*?</think><answer>.*?</answer>$"    completion_contents = ["content"] for completion in completions]    matches =     return reward_funcs_registry = {    "accuracy": accuracy_reward,    "format": format_reward,}最后就是训练,GRPOTrainer是transformers库提供的基于Trainer的训练类,传入指定的参数即可实现基于GRPO算法的实现;其中比较关键的是reward、train_dataset。
############################## Initialize the GRPO trainer#############################trainer = GRPOTrainer(    model=model_args.model_name_or_path,    reward_funcs=reward_funcs,    args=training_args,    train_dataset=dataset,    eval_dataset=dataset if training_args.eval_strategy != "no" else None,    peft_config=get_peft_config(model_args),    callbacks=get_callbacks(training_args, model_args),)计算训练的checkpoint与循环周期,则会在Trainer类中通过gradient_accumulation_steps(梯度累积步数)、num_train_epochs(训练轮数)以及 per_device_train_batch_size(每个设备的训练批次大小)这些参数计算训练周期。
################ Training loop###############logger.info("*** Train ***")checkpoint = Noneif training_args.resume_from_checkpoint is not None:    checkpoint = training_args.resume_from_checkpointelif last_checkpoint is not None:    checkpoint = last_checkpointtrain_result = trainer.train(resume_from_checkpoint=checkpoint)小结

总的来说,Open R1的GRPO训练,是基于GRPOTrainer指定prompt/dataset与reward等参数实现GRPO的训练。也就是说,在指定的训练数据集下,通过prompt引导模型的输出,然后基于grpo算法及其reward对 模型的输出与训练数据集的output 做奖惩打分(通过KL散度比较),计算loss,再反向传播。循环反复;最终完成模型的RL训练,达到让模型能做到CoT式的回复,即生成补全、计算优势、估计 KL 散度和计算损失的步骤,如最开始的图所示。
对于GRPOTrainer类的源码及文档可参考:

[*]grpo_trainer 源码
[*]grpo_trainer github文档
[*]grpo_trainer huggingface文档
[*]DeepSeek背后的数学:深入解析GRPO
首发自个人公众号:阿郎小哥的随笔驿站
页: [1]
查看完整版本: 聊聊GRPO算法——从Open R1来看如何训练DeepSeek R1模型