基于Ray和vLLM构建70B+模型的开源RLHF全量训练框架

news/2024/7/10 18:38:02 标签: 开源

背景

ChatGPT 已经问世一年+了,在训练 ChatGPT 中必不可少的一环是 RLHF 训练,目前开源社区已经有了不少 RLHF 训练框架比如,TRL, DeepSpeedChat 或者最近热门的 LLaMA Factory。这些框架往往是基于 ZeRO 等并行方式,将 RLHF 算法中的四个模型切片后放到同一个 GPU 上。在模型规模越来越大的今天,这种调度方式无法满足 70B+ 甚至仅 13B+模型的全量 RLHF 训练,必须通过合并 Actor Critic 模型或者使用 LoRA 等方式妥协内存使用量。而这些PEFT的方式往往意味着模型效果的妥协。

于是乎开源项目:

https://github.com/OpenLLMAI/OpenRLHF

诞生了,我们基于 Ray 和 vLLM 重新设计了模型调度方案:

  1. 对于 7B 这种小模型,我们将所有模型放到同一张GPU上

  2. 对于 13B~34B 的中等模型,我们基于 Ray 将 PPO 中的四个模型放到不同的GPU上实现全量微调

  3. 对于 34B+的大模型,我们用 vLLM 的 TP 并行加载 Actor 模型,其他模型仍然用 Ray 的方式分散在不同的GPU上

ZeRO2 + Adam Offload + Pinned Memory

我们针对小于 34B 的模型使用 ZeRO2 + Adam Offload + Pinned Memory 的优化方案,我们的基本想法是

  1. 我们发现 RLHF 训练流程中 80% 的时间都被用于 GPT 模型的样本生成和推理,这是因为 GPT 模型的自回归解码具有 O(n^2) 复杂度,并且通常是 Memory Bound 的。

  2. 最简单的提升推理效率的方式是避免通过加大矩阵乘法的尺寸来避免 Memory Bound 和增强 GPU 计算效率,但大的矩阵乘法意味着大的batch_size,导致KV Cache对内存需求很大。

  3. 所以我们想到通过 Optimizer 的 Offload 将 Adam 优化器权重放到 CPU 内存中来节省内存,并且通过 Pinned Memory 避免梯度聚合时候的GPU-CPU通信效率问题。此时我们不仅可以用节省的内存来加大batch_size,而且可以用 ZeRO2 来避免模型切片造成的极大通信开销。

  4. 对于 13B+ 的模型我们会发现基于 ZeRO2 在 A100 的 80G 内存上无法塞下四个模型,所以我们基于 Ray 将模型分别放到不同的 GPU上。不过对于 Actor 我们会分配更多的GPU来减少 GPU 空闲。

通过这种优化策略后优化后,我们在13B模型上做测试,发现我们实现了 4倍于 DeepSpeedChat 的训练效率。

Ray + vLLM 方案架构

但是对于 34B+ 的模型我们发现即使用 Ray 把模型放到不同的卡上也没有办法放得下去

所以我们想到对于 Actor 推理模块我们基于 vLLM 的 TP 并行和 Dynamic Batching 能力做了分布式推理的优化,然后其他模块(即 Actor/Critic的训练模块和Reward/RefActor的推理模块)因为只参一次 forward 或者 backward 我们采用 ZeRO3 的方式进行并行训练。

每次 PPO 训练,vLLM 推理引擎都会收到 DeepSpeed ZeRO3 训练框架更新后的权重,我们通过 NVIDIA NCCL 高性能通信实现了这个过程。鉴于 vLLM 的高性能推理能力,我们实现的不错的性能收益。更进一步,我们可以融合 Actor 的训练节点和推理节点实现节点复用来避免 GPU 空闲,因为这两个模块并不会同时工作。

至此我们通过 Ray 和 vLLM 实现了 70B+ 模型的 RLHF训练方案,并且我们的方案是无缝兼容 Huggingface Transformers 库的,无需像 Megatron-LM 一样手动修改模型结构。

PPO Implementation Tricks

除了系统架构方面的优化,我们进一步整合了 RLHF 算法方面的优化。根据两篇 PPO 经论文:

https://arxiv.org/abs/2005.12729

https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/

PPO算法在实现细节方面有非常多的讲究和调参技巧,我们在 蜗牛在花园跑酷:如何正确复现 Instruct GPT / RLHF? 一文中论述过部分的实现细节和优化技巧。在 OpenRLHF 中我们集成了这些所有 Implementation Tricks,从而实现了 PPO 训练算法的稳定训练和收敛。

多种对齐算法支持

我们不仅实现了 PPO,而且提供了 DPO/Rejection Sampling/Conditonal SFT 等 Alignemnt 算法的支持。

详情参考 OpenRLHF 项目 Readme.md

Quick Start 快速教程

我们只需要安装好环境依赖后,使用 Ray 提交训练任务即可。OpenRLHF 的模型和数据集完美兼容 HuggingFace 格式,包括热门的 MoE 模型 Mixtral 8*7b,只需要指定模型名字或者本地目录地址即可。

址即可。

# 启动 Ray
nohup ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --block &> ray.log &

# 提交 Ray 任务
ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json='{"working_dir": "/openrlhf", "pip": "/openrlhf/requirements.txt"}' \
    --no-wait \
    -- python3 examples/train_ppo_ray.py \
    --ref_num_nodes 1 \               # ref policy 节点数量
    --ref_num_gpus_per_node 2 \       # ref policy GPU数量
    --reward_num_nodes 1 \            # reward model 节点数量
    --reward_num_gpus_per_node 2 \    # reward model GPU数量
    --critic_num_nodes 1 \            # critic 节点数量
    --critic_num_gpus_per_node 4 \    # critic GPU数量
    --actor_num_nodes 1 \             # actor  训练节点数量
    --actor_num_gpus_per_node 4 \     # actor  训练GPU数量
    --vllm_num_engines 2 \            # actor 推理节点数量
    --vllm_tensor_parallel_size 2 \   # actor 推理GPU数量
    --pretrain meta-llama/Llama-2-70b-chat-hf \            # Actor 预训练模型
    --reward_pretrain meta-llama/Llama-2-70b-chat-hf \     # Reward 预训练模型
    --save_path /mnt/bn/wuxibin/cache/ckpt/llama_70b \     # 模型保存路径
    --micro_train_batch_size 1 \
    --train_batch_size 128 \
    --micro_rollout_batch_size 2 \
    --rollout_batch_size 1024 \
    --max_epochs 1 \
    --prompt_max_len 1024 \
    --generate_max_len 1024 \
    --zero_stage 3 \
    --bf16 \
    --actor_learning_rate 5e-7 \
    --critic_learning_rate 9e-6 \
    --init_kl_coef 0.01 \
    --prompt_data Open-Orca/OpenOrca,Dahoas/full-hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward \  # 数据集
    --prompt_data_probs 0.4,0.5,0.1 \                                                              # 数据集混合概率
    --max_samples 80000 \                                                                          # 最大样本数量
    --normalize_reward \                                                                           # Reward Normalization
    --actor_init_on_gpu \
    --adam_offload \                                             
    --flash_attn \
    --gradient_checkpointing

对于 SFT/Reward 模型的训练,我们也提供了相应的实现。只需要直接运行 deepspeed 命令即可

# Reward Model training
deepspeed ./train_rm.py \
     --save_path ./ckpt/7b_llama \
     --save_steps -1 \
     --logging_steps 1 \
     --eval_steps -1 \
     --train_batch_size 128 \
     --micro_train_batch_size 1 \
     --pretrain OpenLLMAI/Llama-2-7b-sft-model-ocra-500k \
     --bf16 \
     --max_epochs 1 \
     --max_len 2048 \
     --zero_stage 3 \
     --learning_rate 9e-6 \
     --dataset Anthropic/hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward,lmsys/chatbot_arena_conversations,openai/webgpt_comparisons \
     --dataset_probs 0.72,0.08,0.12,0.08 \
     --flash_attn \
     --gradient_checkpointing

# SFT model training
deepspeed ./train_sft.py \
    --max_len 2048 \
    --dataset Open-Orca/OpenOrca \
    --dataset_probs 1.0 \
    --train_batch_size 128 \
    --micro_train_batch_size 2 \
    --max_samples 500000 \
    --pretrain meta-llama/Llama-2-7b-hf \
    --save_path ./ckpt/7b_llama \
    --save_steps -1 \
    --logging_steps 1 \
    --eval_steps -1 \
    --zero_stage 2 \
    --max_epochs 1 \
    --bf16 \
    --flash_attn \
    --learning_rate 5e-6 \
    --gradient_checkpointing

http://www.niftyadmin.cn/n/5339095.html

相关文章

Rockchip linux USB 驱动开发

Linux USB 驱动架构 Linux USB 协议栈是一个分层的架构,如下图 5-1 所示,左边是 USB Device 驱动,右边是 USB Host 驱动,最底层是 Rockchip 系列芯片不同 USB 控制器和 PHY 的驱动。 Linux USB 驱动架构 USB PHY 驱动开发 USB 2…

2023年12月青少年机器人技术等级考试(三级)理论综合试卷

2023年12月青少年机器人技术等级考试(三级)理论综合试卷 单选题 第 1 题 单选题 下列选项中,关于光敏电阻描述正确的是?( ) A. 光敏电阻是由导体材料制作而成 B. 光照射光敏电阻时,光照越强…

Jira 宣布Data Center版涨价5%-15%,6年内第8次提价

近日,Atlassian官方面向合作伙伴发布2024年涨价通知: 自2024年2月15日起,旗下核心产品Jira Software、Confluence、Jira Service Management的DC版本(Data Center版本)价格提高5%-15%(涨幅与坐席数阶梯相关…

设计模式篇---中介者模式

文章目录 概念结构实例总结 概念 中介者模式:用一个中介对象来封装一系列的对象交互。中介者使各对象不需要显示地相互引用,从而使其耦合松散,而且可以独立地改变它们之间的交互。 就好比世界各个国家之间可能会产生冲突,但是当产…

Java中的代理模式(二)JDK动态代理

大家好👋,我是极客涛😎,上一篇中我们对代理模式有两大类,静态代理和动态代理,对于静态代理相信大家都信手拈来。对于动态代理还有两种实现,一种是java原生的Jdk代理,一种是Cglib方式…

Flink(十四)【Flink SQL(中)查询】

前言 接着上次写剩下的查询继续学习。 Flink SQL 查询 环境准备: # 1. 先启动 hadoop myhadoop start # 2. 不需要启动 flink 只启动yarn-session即可 /opt/module/flink-1.17.0/bin/yarn-session.sh -d # 3. 启动 flink sql 的环境 sql-client ./sql-client.sh …

leetcode:三数之和---双指针

问题: 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请 你返回所有和为 0 且不重复的三元组。 注意:答案中不可以包含重复…

Redis(01)——常用指令

基础指令 select 数字:切换到其他数据库flushdb:清空当前数据库flushall:清空所有数据库dbsize:查看数据库大小exists key1[key2 …]:判断当前的key是否存在keys *:查看所有的keyexpire key 时间&#xff…