diff --git a/README.md b/README.md index ae28e39..468cc88 100644 --- a/README.md +++ b/README.md @@ -42,14 +42,27 @@ - [Models](#Models) - [Datasets](#Datasets) - [Star History](#Star-History) +- [Join Us](#Join-Us) ## News -🔥🔥🔥 [2024/01/17] We released MFTCoder v0.3.0, mainly for MFTCoder-accelerate. It now supports new models like Mixtral(MoE), DeepSeek-coder, chatglm3. It supports FSDP as an option. It also supports Self-paced Loss as a solution for convergence balance in Multitask Fine-tuning. +🔥🔥🔥 [2024/10/31] We released **MFTCoder v0.5** mainly for MFTCoder-accelerate, which is now supporting preference alignment methods like **DPO/RPO/ORPO** in the new **xxpo** module, adding full-parameter continue-training in the additional **mpt** module along with its **offline_tokenization** module, updating selfpaced method to new convergence balance(CoBa) method for MFT in the original **pefts** module. -🔥🔥🔥 [2024/01/17] [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) has been released, achieving a pass@1 (greedy decoding) score of 78.7% on HumanEval. It lists as top-1 LLM on Bigcode Leardboard in terms of win-rate, the official result is going to be published later. +🔥🔥🔥 [2024/10/31] Our paper [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) has been accepted by EMNLP-2024, which achieves balanced convergence across various tasks. -🔥🔥🔥 [2024/01/17] [CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8X7B) has been released, achieving a pass@1 (greedy decoding) score of 56.1% on HumanEval. +🔥🔥🔥 [2024/05/20] We released **MFTCoder v0.4**, mainly for MFTCoder-accelerate. It supports **QLoRA + DeepSpeed Zero3** and **QLoRA + FSDP** as options allowing you training very large models. It now supports new models like Qwen2, Qwen2-MoE, Starcoder2, Gemma, etc. + +🔥🔥🔥 [2024/05/20] Our paper [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) has been accepted by KDD2024. + +🔥🔥🔥 [2024/05/20] [CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) has been released, achieving a pass@1 (greedy decoding) score of 73.2% on HumanEval. + +🔥🔥 [2024/01/30] The model [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) fine-tuned with MFTCoder ranks first in HuggingFace [Big Code Models LeaderBoard](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard) + +🔥🔥 [2024/01/17] We released MFTCoder v0.3.0, mainly for MFTCoder-accelerate. It now supports new models like Mixtral(MoE), DeepSeek-coder, chatglm3. It supports FSDP as an option. It also supports Self-paced Loss as a solution for convergence balance in Multitask Fine-tuning. + +🔥🔥 [2024/01/17] [CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) has been released, achieving a pass@1 (greedy decoding) score of 78.7% on HumanEval. It lists as top-1 LLM on Bigcode Leardboard in terms of win-rate, the official result is going to be published later. + +🔥🔥 [2024/01/17] [CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8X7B) has been released, achieving a pass@1 (greedy decoding) score of 56.1% on HumanEval. 🔥🔥 [2023/11/07] [MFTCoder Paper](https://arxiv.org/abs/2311.02303) has been released on Arxiv, which discloses technique details of multi-task-fine-tuning. @@ -69,6 +82,7 @@ | **CodeFuse-DeepSeek-33B** | **78.7%** | 2024/01 | | **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | | **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | +| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 | | WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | | GPT-4(zero-shot) | 67.0% | 2023/03 | | PanGu-Coder2 15B | 61.6% | 2023/08 | @@ -84,17 +98,17 @@ ## Articles -[MFT Arxiv paper](https://arxiv.org/abs/2311.02303) +[MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning (KDD2024)](https://arxiv.org/abs/2311.02303) ## Introduction **High Accuracy and efficiency Multi-task Fine-tuning framework for Code LLMs.** -**CodeFuse-MFTCoder** is an open-source project of CodeFuse for accurate and efficient Multi-task Fine-tuning(MFT) on Large Language Models(LLMs), especially on Code-LLMs(large language model for code tasks). +**MFTCoder** is an open-source project of CodeFuse for accurate and efficient Multi-task Fine-tuning(MFT) on Large Language Models(LLMs), especially on Code-LLMs(large language model for code tasks). Moreover, we open source Code LLM models and code-related datasets along with the MFTCoder framework. In MFTCoder, we released two codebases for finetuning Large Language Models: -- ```MFTCoder-accelerate``` is a framework with accelerate and DeepSpeed/FSDP. All tech-stacks are open-source and vibrant. We highly recommend you try this framework and make your fintuning accurate and efficient. +- **```MFTCoder-accelerate```** is a framework with accelerate and DeepSpeed/FSDP. All tech-stacks are open-source and vibrant. We highly recommend you try this framework and make your fintuning accurate and efficient. - ```MFTCoder-atorch``` is based on the [ATorch frameworks](https://github.com/intelligent-machine-learning/dlrover), which is a fast distributed training framework of LLM. The aim of this project is to foster collaboration and share advancements in large language models, particularly within the domain of code development. @@ -121,13 +135,13 @@ The main components of this project include: ## Requirements -To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 11.7) along with the necessary drivers. Additionally, make sure you have installed torch (version 2.0.1). +To begin, ensure that you have successfully installed CUDA (version >= 11.4, preferably 12.1) along with the necessary drivers. Additionally, make sure you have installed torch (version >= 2.1.0). Next, we have provided an init_env.sh script to simplify the installation of required packages. Execute the following command to run the script: ```bash sh init_env.sh ``` -We highly recommend training with flash attention(version >= 2.1.0, preferably 2.3.6), please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention +We highly recommend training with flash attention(version >= 2.3.0), please refer to the following link for installation instructions: https://github.com/Dao-AILab/flash-attention ## Training @@ -135,7 +149,7 @@ As mentioned above, we open source two training frameworks. You could refer to t If you are familiar with open source ```transformers```, ```DeepSpeed``` or ```FSDP```, we highly recommend you try: -🚀🚀 [MFTCoder-accelerate: Accelerate + Deepspeed/FSDP Codebase for MFT(Multi-task Finetuning)](mftcoder_accelerate/README.md) +🚀🚀 [**MFTCoder-accelerate: Accelerate + Deepspeed/FSDP Codebase for MFT(Multi-task Finetuning)**](mftcoder_accelerate/README.md) If you want to explore some new framework like atorch, you could check: @@ -148,23 +162,23 @@ If you want to explore some new framework like atorch, you could check: We are excited to release the following two CodeLLMs trained by MFTCoder, now available on both HuggingFace and ModelScope: -| Model | HuggingFace Links | ModelScope Links | Base Model | Num of examples trained | Batch Size | Seq Length | -|--------------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|------|------------|------------| -| 🔥🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 60万 | 80 | 4096 | -| 🔥🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 60万 | 80 | 4096 | -| 🔥🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | -| 🔥🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | -| 🔥🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 60万 | 80 | 4096 | -| 🔥🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 | -| 🔥🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 | - +| Model | HuggingFace Links | ModelScope Links | Base Model | Num of examples trained | Batch Size | Seq Length | +|----------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------|----------------------|-------------------------|------------|------------| +| 🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 600K | 80 | 4096 | +| 🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 600K | 80 | 4096 | +| 🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 600K | 80 | 4096 | +| 🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | +| 🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 600K | 80 | 4096 | +| 🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 1.1 Million | 256 | 4096 | +| 🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 1.1 Million | 256 | 4096 | +| 🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 700K | 128 | 4096 | ## Datasets We are also pleased to release two code-related instruction datasets, meticulously selected from a range of datasets to facilitate multitask training. Moving forward, we are committed to releasing additional instruction datasets covering various code-related tasks. | Dataset | Description | |-----------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [⭐ Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) | Based on open-evol-instruction-80k, filter out low-quality, repeated, and similar instructions to HumanEval, thus get high-quality code instruction dataset. | +| [⭐ Evol-instruction-66k](https://huggingface.co/datasets/codefuse-ai/Evol-instruction-66k) | Based on open-evol-instruction-80k, filter out low-quality, repeated, and similar instructions to HumanEval, thus get high-quality code instruction dataset. | | [⭐ CodeExercise-Python-27k](https://huggingface.co/datasets/codefuse-ai/CodeExercise-Python-27k) | python code exercise instruction dataset | @@ -193,3 +207,13 @@ If you find our work useful or helpful for your R&D works, please feel free to c +## Join-US + +We are the AI Native team within the Platform Technology Business Group at Ant Group, dedicated to the intelligentization of Ant Group's platform engineering. Established for over three years, our team has played a pivotal role in supporting the intelligent operation and maintenance of Ant Group's cloud computing infrastructure. Our mission is to build algorithm services and platforms with a wide user base through world-class technological innovation and impact, supporting the implementation of internal and external products and businesses. +Embracing an innovation-driven ethos, our team not only supports business implementation but also propels technological influence. Over the past three years, we have published more than 20 papers at top conferences like ICLR, NeurIPS, KDD, and ACL. Our innovative business outcomes have earned us two Ant Technology's highest T-Star awards and one SuperMA award from Ant Group. Our open-source project CodeFuse has received 4K stars as of February 2024, and our models have been downloaded over 1.5 million times on Huggingface and Modelscope. + +**We are on the lookout for top talents to join our vibrant team! If you're eager to develop your career in an environment filled with energy, innovation, and a culture of excellence, we welcome you to explore our career opportunities for both campus and experienced hires. Join us and be a part of creating the next milestone in the industry.** + +**Campus Recruitment**: https://hrrecommend.antgroup.com/guide.html?code=8uoP5mlus5DqQYbE_EnqcE2FD5JZH21MwvMUIb9mb6X3osXPuBraG54SyM8GLn_7 + +**Experienced Hires**: https://talent.antgroup.com/off-campus-position?positionId=1933830 diff --git a/README_cn.md b/README_cn.md index 07fea2a..3102d9f 100644 --- a/README_cn.md +++ b/README_cn.md @@ -41,14 +41,25 @@ - [训练](#训练) - [模型](#模型) - [数据集](#数据集) +- [加入我们](#加入我们) ## 新闻 -🔥🔥🔥 [2024/01/17] **MFTCoder-v0.3.0**发布。新增对Mixtral(MoE), DeepSeek等模型的支持;新增支持FSDP(Fully Sharded Data Parallel);新增Self-paced Loss, 支持多任务收敛均衡。 感兴趣详见微信公众号CodeFuse的文章[MFTCoder 重磅升级v0.3.0发布](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) +🔥🔥🔥 [2024/10/31] **MFTCoder-v0.5**发布,新增**xxpo**模块支持偏好对齐DPO/RPO/ORPO;新增**mpt**和**offline_tokenization**模块支持全量参数的加训;在原本的**pefts**模块(MFT)更新selfpaced收敛均衡技术并更名CoBa。 -🔥🔥🔥 [2024/01/17] 开源了[CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B)模型,在HumanEval pass@1(greedy decoding)上可以达到78.7%。该模型在Big Code榜单的结果近期发布,请关注公众号获取最新信息。 +🔥🔥🔥 [2024/10/31] 我们的论文 [CoBa: Convergence Balancer for Multitask Finetuning of Large Language Models](https://arxiv.org/abs/2410.06741) 已被 EMNLP 2024 接收,可以实现多任务收敛均衡。 -🔥🔥🔥 [2024/01/17] 开源了[CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B)模型,在HumanEval pass@1(greedy decoding)上可以达到56.1%。感兴趣详见微信公众号CodeFuse的文章[MFTCoder提升Mixtral-8x7B混合专家模型的代码能力实践](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) +🔥🔥🔥 [2024/05/20] **MFTCoder-v0.4**发布。新增支持**QLoRA+ DeepSpeed Zero3**, **QLoRA + FSDP**训练模式,可以更好的支持微调更大的模型,比如Qwen1.5-70B等。新增对Qwen2, Qwen2-MoE, Starcoder2, Gemma等模型的支持。 + +🔥🔥🔥 [2024/05/20] 我们的论文 [MFTCoder: Boosting Code LLMs with Multitask Fine-Tuning](https://arxiv.org/abs/2311.02303) 已被 KDD 2024 接收. + +🔥🔥🔥 开源了[CodeFuse-StarCoder2-15B](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B)模型,在HumanEval上可以达到73.2%,多代码语言能力均衡. + +🔥🔥 [2024/01/17] **MFTCoder-v0.3.0**发布。新增对Mixtral(MoE), DeepSeek等模型的支持;新增支持FSDP(Fully Sharded Data Parallel);新增Self-paced Loss, 支持多任务收敛均衡。 感兴趣详见微信公众号CodeFuse的文章[MFTCoder 重磅升级v0.3.0发布](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) + +🔥🔥 [2024/01/17] 开源了[CodeFuse-DeepSeek-33B](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B)模型,在HumanEval pass@1(greedy decoding)上可以达到78.7%。该模型在Big Code榜单的结果近期发布,请关注公众号获取最新信息。 + +🔥🔥 [2024/01/17] 开源了[CodeFuse-Mixtral-8x7B](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B)模型,在HumanEval pass@1(greedy decoding)上可以达到56.1%。感兴趣详见微信公众号CodeFuse的文章[MFTCoder提升Mixtral-8x7B混合专家模型的代码能力实践](https://mp.weixin.qq.com/s/xI3f0iUKq9TIIKZ_kMtcQg) 🔥🔥 [2023/11/07] [MFTCoder论文](https://arxiv.org/abs/2311.02303)在Arxiv公布,介绍了多任务微调的技术细节。 @@ -68,6 +79,7 @@ | **CodeFuse-DeepSeek-33B** | **78.7%** | 2024/01 | | **CodeFuse-CodeLlama-34B** | **74.4%** | 2023/09 | | **CodeFuse-CodeLlama-34B-4bits** | **73.8%** | 2023/09 | +| **CodeFuse-StarCoder2-15B** | **73.2%** | 2023/05 | | WizardCoder-Python-34B-V1.0 | 73.2% | 2023/08 | | GPT-4(zero-shot) | 67.0% | 2023/03 | | PanGu-Coder2 15B | 61.6% | 2023/08 | @@ -117,12 +129,12 @@ ## 环境 -首先, 你需要将CUDA(>=11.4, 推荐11.7)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.0.0) +首先, 你需要将CUDA(>=11.4, 推荐12.1)及其相关驱动安装成功,并确保其工作正常, 并且安装基本的torch(>=2.1.0) 在requirements.txt下固定了几个主要的python包的版本,执行如下脚本即可: ```bash sh init_env.sh ``` -我们强烈建议您安装flash attention(>=2.1.0, 推荐2.3.6), 安装请参考 https://github.com/Dao-AILab/flash-attention +我们强烈建议您安装flash attention(>=2.3.0), 安装请参考 https://github.com/Dao-AILab/flash-attention ## 训练 如果你熟悉大模型训练的各种主流开源资源,例如 ```transformers```, ```DeepSpeed```, ```FSDP```等, 为了用开源项目快速上手高性能微调,我们建议您尝试: @@ -144,11 +156,11 @@ sh init_env.sh | 🔥🔥🔥 CodeFuse-DeepSeek-33B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-DeepSeek-33B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-DeepSeek-33B) | DeepSeek-coder-33B | 60万 | 80 | 4096 | | 🔥🔥🔥 CodeFuse-Mixtral-8x7B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-Mixtral-8x7B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-Mixtral-8x7B) | Mixtral-8x7B | 60万 | 80 | 4096 | | 🔥🔥🔥 CodeFuse-CodeLlama-34B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B) | CodeLlama-34b-Python | 60万 | 80 | 4096 | -| 🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | +| 🔥🔥🔥 CodeFuse-CodeLlama-34B-4bits | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits) | CodeLlama-34b-Python | | | 4096 | | 🔥🔥🔥 CodeFuse-StarCoder-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder-15B) | StarCoder-15B | 60万 | 80 | 4096 | -| 🔥🔥🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 | -| 🔥🔥🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 | - +| 🔥🔥🔥 CodeFuse-QWen-14B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B) | Qwen-14b | 110万 | 256 | 4096 | +| 🔥🔥🔥 CodeFuse-CodeGeex2-6B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-CodeGeex2-6B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeGeex2-6B) | CodeGeex2-6B | 110万 | 256 | 4096 | +| 🔥🔥🔥 CodeFuse-StarCoder2-15B | [h-link](https://huggingface.co/codefuse-ai/CodeFuse-StarCoder2-15B) | [m-link](https://modelscope.cn/models/codefuse-ai/CodeFuse-StarCoder2-15B) | Starcoder2-15B | 70万 | 128 | 4096 | @@ -173,3 +185,15 @@ sh init_env.sh } ``` +## 加入我们 + +我们是平台技术事业群AI Native团队,负责蚂蚁蚂蚁集团平台工程的智能化,团队成立3年多以来,支持了蚂蚁集团云计算基础设施智能化运维的升级改造。团队的Mission是,通过世界级的技术创新和影响,构建有广泛用户的算法服务和平台,支撑内外部产品和业务落地。团队秉承创新基因,在支撑业务落地的同时,推动技术影响。3年以来在ICLR、NeurIPS、KDD、ACL等顶会发表论文20余篇,创新业务结果获得两次蚂蚁技术最高奖T-Star,1次蚂蚁集团最高奖SuperMA。开源项目CodeFuse获得4K点赞(2024年2月),Huggingface和modelscope上模型累积下载量超过150万次。 + +**我们正在寻找行业中的佼佼者加入我们的团队!如果您希望在一个充满活力、创新和卓越文化的环境中发展您的职业生涯,欢迎您查看我们的社招&校招机会,加入我们,一起创造下一个行业里程碑。** + +**校招**:https://hrrecommend.antgroup.com/guide.html?code=8uoP5mlus5DqQYbE_EnqcE2FD5JZH21MwvMUIb9mb6X3osXPuBraG54SyM8GLn_7 + +**社招**:https://talent.antgroup.com/off-campus-position?positionId=1933830 + +## 联系我们 +![img_wx.png](./assets/CodeFuse-AI群.png) diff --git "a/assets/CodeFuse-AI\347\276\244.png" "b/assets/CodeFuse-AI\347\276\244.png" new file mode 100644 index 0000000..4e0c0a1 Binary files /dev/null and "b/assets/CodeFuse-AI\347\276\244.png" differ diff --git a/init_env.sh b/init_env.sh index 7964e99..834b38d 100644 --- a/init_env.sh +++ b/init_env.sh @@ -1,4 +1,4 @@ -pip install torch==2.0.1 && \ +pip install torch==2.1.0 && \ pip install tensorboard==2.11.0 && \ pip install packaging && \ -pip install -r requirements.txt \ No newline at end of file +pip install -r requirements.txt diff --git a/mftcoder_accelerate/README.md b/mftcoder_accelerate/README.md index 649811d..87b4b63 100644 --- a/mftcoder_accelerate/README.md +++ b/mftcoder_accelerate/README.md @@ -7,21 +7,28 @@ [[中文]](README_cn.md) [**English**] ## 1. Updates +🔥 MFTCoder-accelerate now supports DPO/ORPO training through xxpo module. -🔥 MFTCoder-accelerate supports Full-parameters/LoRA using accelerate + FSDP Framework; +🔥 MFTCoder-accelerate now supports continue training through mpt module along with offline_tokenization module. -🔥 MFTCoder-accelerate supports MFT/SFT on more new mainstream open-source base models: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3; +🔥 MFTCoder-accelerate supports MFT with latest implementation of CoBa Loss (selfpaced Loss) for better Convergence Balance. -🔥 MFTCoder-accelerate supports Self-Paced Loss for Convergence Balance; +🔥 MFTCoder-accelerate now support these modes: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, Full-parameter + DeepSpeed ZeRO3, QLoRA + FSDP, Full-parameter + FSDP. -🔥 MFTCoder-accelerate supports Full-parameters/QLoRA/LoRA using accelerate + DeepSpeed Framework; +🔥 MFTCoder-accelerate supports QLoRA + DeepSpeed ZeRO3 and QLoRA + FSDP, which both work for larger models. + +🔥 MFTCoder-accelerate supports MFT/SFT on more new mainstream open-source base models: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3. + +🔥 MFTCoder-accelerate supports Self-Paced Loss for Convergence Balance. + +🔥 MFTCoder-accelerate supports Full-parameters/QLoRA/LoRA using accelerate + DeepSpeed Framework. 🔥 MFTCoder-accelerate supports Multitask Fine-Tuning(MFT), which is able to balance diffenrent tasks in data level. 🔥 MFTCoder-accelerate supports finetuning most of mainstream open-source base models: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen. ## 2. Data Format -### 2.1 Training Data Format +### 2.1 MFT Training Data Format The training data is required to be a uniformed JSONL format, in which each line of data has the following "chatML"-style JSON format. The "chat_rounds" field is required, and other fields can be added or removed based on specific needs. The reason why we selected "chatML" style as our training and inference data format is that "chatML" style is compatible with both "conversation" and "instruction/response" scenarios. @@ -56,7 +63,7 @@ For the keys of roles in "chat_rounds", you could use "system/human/bot" tuple o } ``` -### 2.2 Default Inference Data Format +### 2.2 Default MFTCoder Inference Template Inference data format is the real string format consumed by tokenizers and then LLMs. It is also the string format to which the training data is converted before tokenization. The default inference data format contains strings concatenated by conversation data(system, human and bot contents) in the training data format. It is used as the data "seen"(before tokenization) by the model in training process. @@ -86,6 +93,56 @@ User nth round input ``` When applying inference, you always make your input string end with ```bot\n``` to request the model generating answers. +### 2.3 DPO训练数据格式 +The training data is required to be a uniformed JSONL format, in which each line of data has the following JSON format. The "chosen" and "rejected" fields are required as ```chosen``` and ```rejected``` in DPO training and both includes "chatml-style" contents(only last content of bot differs). +```json +{ + "chosen":[ + { + "role": "system", + "content": "You are a expert in coding and help answer code questions" + }, + { + "role": "human", + "content": "Write a python function of quick sort" + }, + { + "role": "bot", + "content": "Below is the function of quick sort: ..." + }, + { + "role": "human", + "content": "Explain the code" + }, + { + "role": "bot", + "content": "OK, this code ..." + } + ], + "rejected":[ + { + "role": "system", + "content": "You are a expert in coding and help answer code questions" + }, + { + "role": "human", + "content": "Write a python function of quick sort" + }, + { + "role": "bot", + "content": "Below is the function of quick sort: ..." + }, + { + "role": "human", + "content": "Explain the code" + }, + { + "role": "bot", + "content": "Sorry, I can not answer..." + } + ] +} +``` ## 3. Model Training @@ -113,6 +170,12 @@ mftcoder_accelerate | *pefts* | + *xxpo* + | + *mpt* + | + *offline_tokenization* + | tokenizer | utils @@ -121,7 +184,11 @@ mftcoder_accelerate ``` 我们将训练中使用的各种组件抽取出来,以便后续的扩展和优化, 详见```src```目录下的实现。 -训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` +MFT训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` + +DPO/ORPO训练入口文件是```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py``` + +MPT(全量加训)训练入口文件是```mftcoder_accelerate/src/mpt/mpt_accelerate.py``` 参数配置存储在```mftcoder_accelerate/src/configs```目录下,方便统一管理和更改。 @@ -130,8 +197,13 @@ mftcoder_accelerate cd mftcoder_accelerate/src ``` -You can find the implementations in the ```mftcoder_accelerate/src``` directory. -The entry directory for fine-tuning training is ```mftcoder_accelerate/src```, and the entry file for training is ```mftcoder_accelerate/src/pefts/mft_accelerate.py```. +You can find the implementations in the ```mftcoder_accelerate/src``` directory +The entry file for MFT training is ```mftcoder_accelerate/src/pefts/mft_accelerate.py```. + +The entry file for DPO/ORPO training is ```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py```. + +The entry file for MPT(Continue Training) is ```mftcoder_accelerate/src/mpt/mpt_accelerate.py```. You need finish offline tokenization of your data via ```mftcoder_accelerate/src/run_offline_tokenization.sh```, which is different from the online tokenizaion used in MFT/DPO. + Configurations are stored in the ```mftcoder_accelerate/src/configs``` directory for easy management and modification. **_As a result, before you start training, you should first change your dir by_** @@ -139,7 +211,7 @@ Configurations are stored in the ```mftcoder_accelerate/src/configs``` directory cd mftcoder_accelerate/src ``` -### 3.1 Tokenization +### 3.1 MFT Tokenization During training, we concatenate multi-turn dialogues into the following format (also known as the inference data format mentioned before) and then tokenize it. In default format, ```human\n``` starts the user's input (i.e., prompt),```bot\n``` starts the assistant's output (i.e., response) @@ -175,10 +247,14 @@ DeepSpeed config in accelerate_ds_config.yaml. accelerate launch --config_file accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "DeepSpeed" ``` or -DeepSpeed config in command line arguments +DeepSpeed Zero2 config in command line arguments ```bash sh ds_single_launch.sh ``` +DeepSpeed Zero3 config in command line arguments +```bash +sh ds_zero3_single_launch.sh +``` #### Launch via FSDP FSDP config in accelerate_fsdp_config.yaml. @@ -188,7 +264,13 @@ accelerate launch --config_file accelerate_fsdp_config.yaml pefts/mft_accelerate or FSDP config in command line arguments ```bash -sh ds_single_launch.sh +sh fsdp_single_launch.sh +``` + +#### MultiNode Launch +Refer to the deepspeed multi-node launch script below. +```bash +sh ds_multinode_launch.sh ``` #### Traing Arguments @@ -260,6 +342,17 @@ Frequently used arguments are provided in ```configs/***_train_config``` and exp - **role_markers**: {"system": "\system\n", "user": "\human\n", "assistant": "\bot\n} as default(null). You could set your preferred role_markers as the templates startting "system", "user" and "assistant". e.g. {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} +#### CoBa Arguments Configuration +- **coba_warmup_steps**: The number of warm-up steps for CoBa. During the warm-up period, all task weights are equal, and after the warm-up, weights begin to be adjusted dynamically. It is generally recommended to set this close to the total number of validation batches. +- **coba_history_length**: The historical window length of validation loss maintained by CoBa, used to fit the convergence slope at the current step. It is generally recommended to set this between 2 times and 5 times the **coba_warmup_steps**. Typically, the larger this value, the smaller the changes in weights will be. +- **coba_tau**: The temperature coefficient for the Divergence Factor (DF). It is generally set to 5. +- **coba_update_interval**: The frequency at which CoBa updates weights. It is commonly set to 1, meaning weights are updated at every step. +- **coba_sample_valid_num**: The number of validation batches to be sampled by CoBa at each step. Theoretically, when this value equals the total number of validation batches, the fitted convergence slope most closely approximates the actual situation. However, considering computational requirements, it is recommended to set it to 1. + +#### DPO Arguments Configuration +- **xxpo**: preference optimization type, "dpo" or "orpo". +- **beta**: DPO beta, smaller beta allows larger distance between dpo model and ref model. +- **rpo_alpha**: The coefficient of the ```chosen``` NLL loss added to dpo loss. ## 4. Model Usage @@ -328,6 +421,8 @@ beam_num: Set a smaller value such as 1 or 3. ```beam_num=1``` represents greedy If OOM happened,you can reduce parameters such as per_device_train_batch_size and seq_length. Since you are dealing with large models (6B, 13B, 34B, 70B, etc.), you are already using gradient checkpointing technology by default, which significantly reduces GPU memory consumption. However, this may slightly slow down the training speed. +QLoRA + DeepSpeed Zero3 is recommended for larger models to avoid OOM. + #### Q2:install packages Please refer to init_env.sh and requirements.txt We highly recommend you install Flash Attention 2 (flash_attn>=2.1.0, 2.3.6 used by us) first to get memory-efficient and fast training. @@ -339,7 +434,8 @@ CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file accelerate_ds_config.ya ``` #### Q4:Whats is a recommended Distributed Training? -For LoRA/QLoRA, we recommend DeepSpeed(ZeRO2) as the underlying framework, because it is easy and stable to use, moreover it is more compatable for different settings. -And FSDP does not support Quantization(integer type in training). +For LoRA, we recommend DeepSpeed ZeRO2 as the underlying framework, because it is easy and stable to use, moreover it is more compatable for different settings. + +For QLoRA, DeepSpeed ZeRO2 and DeepSpeed ZeRO3 are both good, moreover DeepSpeed ZeRO3 is a good choice for very large models. -For Full-parameter finetuning, FSDP is usually faster, and may help you with very large models by sharding parameters and gradients. \ No newline at end of file +For Full-parameter finetuning, DeepSpeed ZeRO3 and FSDP are faster, and may help you with very large models by sharding parameters and gradients. \ No newline at end of file diff --git a/mftcoder_accelerate/README_cn.md b/mftcoder_accelerate/README_cn.md index c208a4a..39631c5 100644 --- a/mftcoder_accelerate/README_cn.md +++ b/mftcoder_accelerate/README_cn.md @@ -7,20 +7,30 @@ [**中文**] [[English]](README.md) ## 1. 更新 -🔥 MFTCoder-accelerate 新增支持accelerate + FSDP框架, 支持全量微调和LoRA; +🔥 MFTCoder-accelerate 增加了xxpo模块,支持dpo训练。 -🔥 MFTCoder-accelerate 支持最新更多主流开源模型: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3; +🔥 MFTCoder-accelerate 增加了mpt模块,借助offline_tokenization模块,支持全量参数加训。 -🔥 MFTCoder-accelerate 新增self-paced Loss, 用于收敛均衡; +🔥 MFTCoder-accelerate 增加了CoBa Loss的最新实现(原selfpaced Loss), 让收敛均衡更进一步。 -🔥 MFTCoder-accelerate 支持使用accelerate + DeepSpeed框架下支持 全量参数/QLoRA/LoRA微调; +🔥 MFTCoder-accelerate 最新支持的训练模式包括: QLoRA/LoRA + DeepSpeed ZeRO2, QLoRA + DeepSpeed ZeRO3, 全量 + DeepSpeed ZeRO3, QLoRA + FSDP, 全量 + FSDP。 -🔥 MFTCoder-accelerate 在训练中支持了多任务微调MFT, 可以同时平衡多个任务的训练,训练的模型支持多任务推理; +🔥 MFTCoder-accelerate 新增支持QLoRA + DeepSpeed ZeRO3, 支持QLoRA + FSDP, 可以训练更大的模型。 -🔥 MFTCoder-accelerate 在训练中支持多种模型基座: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen等 +🔥 MFTCoder-accelerate 新增支持accelerate + FSDP框架, 支持全量微调和LoRA。 + +🔥 MFTCoder-accelerate 支持最新更多主流开源模型: mistral, mixtral-8x7b(Mixture of Experts), deepseek, chatglm3。 + +🔥 MFTCoder-accelerate 新增self-paced Loss, 用于收敛均衡。 + +🔥 MFTCoder-accelerate 支持使用accelerate + DeepSpeed框架下支持 全量参数/QLoRA/LoRA微调。 + +🔥 MFTCoder-accelerate 在训练中支持了多任务微调MFT, 可以同时平衡多个任务的训练,训练的模型支持多任务推理。 + +🔥 MFTCoder-accelerate 在训练中支持多种模型基座: codellama, llama2, llama, starcoder, codegeex2, chatglm2, qwen等。 ## 2. 数据格式 -### 2.1 训练数据格式 +### 2.1 MFT训练数据格式 训练数据为jsonl格式,每一行的数据格式如下,其中chat_rounds字段是必需的,可以根据实际需求添加或删除其他字段。 可以参考项目中的xxx.jsonl文件。 ```json @@ -76,6 +86,57 @@ """ ``` +### 2.3 DPO训练数据格式 +训练数据为jsonl格式,每一行的数据格式如下,其中chosen字段和rejected字段分别代表偏好对齐中的```chosen```和```rejected```,其内部依然是MFT的chatml格式,并且只有最后一轮对话的bot content不同。 +```json +{ + "chosen":[ + { + "role": "system", + "content": "你是一个智能代码助手,可以回复用户与代码相关的问题" + }, + { + "role": "human", + "content": "写一个快速排序" + }, + { + "role": "bot", + "content": "以下是一个快速排序算法xxxxxx" + }, + { + "role": "human", + "content": "解释一下这段代码" + }, + { + "role": "bot", + "content": "好的,这段代码xxx" + } + ], + "rejected":[ + { + "role": "system", + "content": "你是一个智能代码助手,可以回复用户与代码相关的问题" + }, + { + "role": "human", + "content": "写一个快速排序" + }, + { + "role": "bot", + "content": "以下是一个快速排序算法xxxxxx" + }, + { + "role": "human", + "content": "解释一下这段代码" + }, + { + "role": "bot", + "content": "对不起,我不会" + } + ] +} +``` + ## 3. 模型训练 目前支持全量参数(Full-parameters)指令微调、QLoRA指令微调,LoRA指令微调。 @@ -100,6 +161,12 @@ mftcoder_accelerate | *pefts* | + *xxpo* + | + *mpt* + | + *offline_tokenization* + | tokenizer | utils @@ -108,7 +175,11 @@ mftcoder_accelerate ``` 我们将训练中使用的各种组件抽取出来,以便后续的扩展和优化, 详见```src```目录下的实现。 -训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` +MFT训练入口文件是```mftcoder_accelerate/src/pefts/mft_accelerate.py``` + +DPO/ORPO训练入口文件是```mftcoder_accelerate/src/xxpo/xxpo_accelerate.py``` + +MPT(全量加训)训练入口文件是```mftcoder_accelerate/src/mpt/mpt_accelerate.py```. MPT加训需要提前做好数据的tokenziation,通过```mftcoder_accelerate/src/run_offline_tokenization.sh```,你可以将数据通过cpu进行离线的tokenization。这和MFT/DPO中使用的在线tokenziation不同。 参数配置存储在```mftcoder_accelerate/src/configs```目录下,方便统一管理和更改。 @@ -120,7 +191,7 @@ cd mftcoder_accelerate/src ### 3.1 数据tokenization -训练时,我们将多轮对话拼接成如下格式(也是上文中的推理数据格式),然后进行tokenize。 +MFT/DPO训练时,我们将多轮对话拼接成如下格式(也是上文中的推理数据格式),然后进行tokenize。 其中,默认情况下: ```human\n```作为human/user的起始符,```bot\n```作为bot/assistant的起始符,```{EOS_TOKEN}``` 表示eos_token。 @@ -142,19 +213,24 @@ QLoRA通过4-bit的nf4量化,且加入更多adapter,在大幅减少显存消 QLoRA论文指出,该方法可以在一张V100上对33B的模型进行微调,并且性能逼近全量参数微调。 执行如下命令即可进行 Lora/QLora/全量 微调: -#### Launch via Deepspeed +#### Deepspeed 单机启动 DeepSpeed配置在accelerate_ds_config.yaml中。 ```bash accelerate launch --config_file accelerate_ds_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "DeepSpeed" ``` 或者 -DeepSpeed配置在脚本中通过命令行输入。 +DeepSpeed Zero2 配置在脚本中通过命令行输入。 ```bash sh ds_single_launch.sh ``` -#### Launch via FSDP +DeepSpeed Zero3 配置在脚本中通过命令行输入 +```bash +sh ds_zero3_single_launch.sh +``` + +#### FSDP 单机启动 FSDP配置在accelerate_fsdp_config.yaml中。 ```bash accelerate launch --config_file accelerate_fsdp_config.yaml pefts/mft_accelerate.py --train_config configs/xxx_train_config.json --distributed_type "FSDP" @@ -166,6 +242,12 @@ FSDP配置在脚本中通过命令行输入。 sh fsdp_single_launch.sh ``` +#### 多机启动 +多机启动请参考如下deepspeed多机启动脚本 +```bash +sh ds_multinode_launch.sh +``` + #### 训练参数 _**训练需要的参数配置在```configs/*_train_config```中,主要参数说明如下:**_ @@ -202,6 +284,18 @@ _**训练需要的参数配置在```configs/*_train_config```中,主要参数 - **saving_limit**:整数,ckpt存储数量上限, 全量训练必须设置。默认null即不限制数量。 - **role_markers**: null,即使用{"system": "\system\n", "user": "\human\n", "assistant": "\bot\n"}。 你可以自定义 "system", "user" and "assistant"的模板, 用于定制自己的问答或者对话模板,比如 {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} +#### CoBa相关参数配置 +- **coba_warmup_steps**: CoBa的warm-up步数。在warm-up期间,各任务权重相等,warm-up之后,开始动态调整权重。一般建议设置为与valid batch总数量相近即可。 +- **coba_history_length**: CoBa维护的valid loss的历史窗口长度,用于拟合当前步收敛斜率。一般建议设置为2倍**coba_warmup_steps**至5倍**coba_warmup_steps**之间。一般该值越大,权重的变化幅度就会越小。 +- **coba_tau**: 发散因子(DF)的温度系数。一般设置为5即可。 +- **coba_update_interval**: CoBa更新权重的频率。一般设置为1,即每一步都对权重做更新。 +- **coba_sample_valid_num**: CoBa每一步要取的valid batch数。理论上当该值等于valid batch总数量时,拟合出的收敛斜率最逼近真实情况,但考虑到计算需求,建议设置为1。 + +#### DPO 相关参数配置 +- **xxpo**: 偏好对齐方法, "dpo" 或者 "orpo"。 +- **beta**: DPO beta, beta 越小,允许对齐后的dpo模型与ref模型的距离越远。 +- **rpo_alpha**: 加到dop损失的```chosen``` NLL损失的系数,0的话就是原始DPO。 +- ## 4. 模型使用 ### 4.1 权重合并 @@ -256,7 +350,7 @@ print(gen_text) ## 5. FAQ #### 问题1:OOM如何解决? 如果发生OOM,可以缩小per_device_train_batch_size、seq_length等参数来缓解。由于面对的模型普遍较大(6b, 13b, 34b, 70b等)我们已经默认使用gradient_checkpointing技术,可以大幅降低显存占用,但训练速度会稍慢一些。 - +如果是模型太大,可以使用QLoRA + DeepSpeed ZeRO3(配置 zero stage = 3),这个方案可以在卡数足够的情况下,微调更大的模型。 #### 问题2:安装包错误 参考init_env.sh和requirements.txt @@ -276,14 +370,14 @@ CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file pefts/accelerate_ds_con 如果你可以自行安装环境并使用torch>=2.1.1,可以尝试设置参数"attn_implementation"为 "sdpa"。这样会尝试使用transformers兼容的torch.nn.functional.scaled_dot_product_attention。支持的模型还不全面。 #### 问题5:推荐的分布式框架是怎样的? -对于LoRA/QLoRA, 我们推荐使用DeepSpeed作为底层分布式框架,它具有易用性和兼容性好的特点,并且速度很快。 -FSDP 不支持QLoRA, 因为bitsandbytes暂不支持FSDP。 +对于LoRA, 我们推荐使用DeepSpeed Zero2作为底层分布式框架,它具有易用性和兼容性好的特点,并且速度很快, 模型加载模式上类似DDP。 +对于QLoRA, DeepSpeed Zero2 适合中小模型, DeepSpeed Zero3 适合很大的模型。 -对于全量微调,我们推荐使用FSDP, 因为它在全量训练时可以发挥fully sharding的优势,达到更快的训练速度。 +对于全量微调,可以使用DeepSpeed Zero3, 或者FSDP。二者都是Fully Sharding模式,即模型加载平分在每张卡。 #### 问题6:当前支持的模型中,有什么区别 国产大模型比如chatglm2, chatglm3, baichuan2, qwen, aquila2等,使用的是和模型共同发布的modeling_xxx.py. -其它被transformers官方支持的大模型,由于已经升级支持flash attention等,所以全面切换到官方的modeling支持训练,之前的自定义modeling会被deprecated +其它被transformers官方支持的大模型,比如llama, qwen2, starcoder2, mistral等,全面切换到官方的modeling支持训练,之前的自定义modeling会被deprecated。 diff --git a/mftcoder_accelerate/inference/hf_inference.py b/mftcoder_accelerate/inference/hf_inference.py index 16b8933..67f9ba0 100644 --- a/mftcoder_accelerate/inference/hf_inference.py +++ b/mftcoder_accelerate/inference/hf_inference.py @@ -2,91 +2,85 @@ # @author Chaoyu Chen # @date 2024/1/4 # @module hf_inference.py - +""" +# @author qumu +# @date 2023/9/19 +# @module hf_inference.py +""" import os import sys import torch import textwrap -from transformers import ( - AutoConfig, - AutoTokenizer, - AutoModelForCausalLM, - StoppingCriteria, - StoppingCriteriaList -) +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList from peft import PeftModel -def load_model_tokenizer(path, model_type=None, peft_path=None, torch_dtype=torch.bfloat16, quantization=None, - eos_token=None, pad_token=None): +def load_model_tokenizer( + path, + model_type=None, + peft_path=None, + torch_dtype=torch.bfloat16, + quantization=None, + eos_token=None, + pad_token=None, + batch_size=1, +): """ - load model and tokenizer by transfromers + load model and tokenizer by transfromers """ # load tokenizer first tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) tokenizer.padding_side = "left" - config, unused_kwargs = AutoConfig.from_pretrained( - path, - trust_remote_code=True, - return_unused_kwargs=True - ) + config, unused_kwargs = AutoConfig.from_pretrained(path, trust_remote_code=True, return_unused_kwargs=True) print("unused_kwargs:", unused_kwargs) print("config input:\n", config) - # eos token优先级: 1. 用户输入eos_token 2. config中的eos_token_id 3. config中的eos_token + # eos token parsing if eos_token: eos_token = eos_token eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) print(f"eos_token {eos_token} from user input") + elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id: + print(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer") + eos_token_id = tokenizer.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(eos_token_id) + elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token: + print(f"Initial eos_token {tokenizer.eos_token} from tokenizer") + eos_token = tokenizer.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) + elif hasattr(config, "eos_token_id") and config.eos_token_id: + print(f"Initial eos_token_id {config.eos_token_id} from config.json") + eos_token_id = config.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id) + elif hasattr(config, "eos_token") and config.eos_token: + print(f"Initial eos_token {config.eos_token} from config.json") + eos_token = config.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token) else: - if hasattr(config, "eos_token_id") and config.eos_token_id: - print(f"eos_token_id {config.eos_token_id} from config.json") - eos_token_id = config.eos_token_id - eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id) - elif hasattr(config, "eos_token") and config.eos_token: - print(f"eos_token {config.eos_token} from config.json") - eos_token = config.eos_token - eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token) - else: - raise ValueError( - "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json") - - # pad token优先级: 1. 用户输入 pad_token 2. config中的pad_token_id 3. config中的pad_token - if pad_token: - pad_token = pad_token - pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) - print(f"pad_token {pad_token} from user input") - else: - if hasattr(config, "pad_token_id") and config.pad_token_id: - print(f"pad_token_id {config.pad_token_id} from config.json") - pad_token_id = config.pad_token_id - pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id) - elif hasattr(config, "pad_token") and config.pad_token: - print(f"pad_token {config.pad_token} from config.json") - pad_token = config.pad_token - pad_token_id = tokenizer.convert_tokens_to_ids(config.pad_token) - else: - print(f"pad_token {eos_token} duplicated from eos_token") - pad_token = eos_token - pad_token_id = eos_token_id - - # update tokenizer eos_token and pad_token - tokenizer.eos_token_id = eos_token_id - tokenizer.eos_token = eos_token - tokenizer.pad_token_id = pad_token_id - tokenizer.pad_token = pad_token + raise ValueError( + "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json" + ) + + try: + tokenizer.eos_token = eos_token + tokenizer.eos_token_id = eos_token_id + # set pad_token to be same as eos_token, it is ok because is will be masked out. + tokenizer.pad_token = eos_token + tokenizer.pad_token_id = eos_token_id + except: + print(f"[WARNING]Cannot set tokenizer.eos_token") print(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}") print(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}") - print(tokenizer) + print(type(tokenizer)) base_model = AutoModelForCausalLM.from_pretrained( path, config=config, - load_in_8bit=(quantization == '8bit'), - load_in_4bit=(quantization == '4bit'), + load_in_8bit=(quantization == "8bit"), + load_in_4bit=(quantization == "4bit"), device_map="auto", torch_dtype=torch_dtype, trust_remote_code=True, @@ -114,9 +108,10 @@ def load_model_tokenizer(path, model_type=None, peft_path=None, torch_dtype=torc def hf_inference(model, tokenizer, text_list, args=None, max_new_tokens=512, do_sample=True, **kwargs): """ - transformers models inference by huggingface + transformers models inference by huggingface """ - inputs = tokenizer(text_list, return_tensors='pt', padding=True, add_special_tokens=False).to("cuda") + # text_list = [tokenizer.apply_chat_template([{"role": "user", "content": text}], tokenize=False) for text in text_list] + inputs = tokenizer(text_list, return_tensors="pt", padding=True, add_special_tokens=False).to("cuda") # inputs["attention_mask"][0][:100] = 0 # print(inputs) print("================================Prompts and Generations=============================") @@ -128,15 +123,15 @@ def hf_inference(model, tokenizer, text_list, args=None, max_new_tokens=512, do_ do_sample=do_sample, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, - **kwargs + **kwargs, ) - gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True) + gen_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True) for i in range(len(text_list)): - print('=========' * 10) - print(f'Prompt:\n{text_list[i]}') - gen_text[i] = gen_text[i].replace(tokenizer.pad_token, '') - print(f'Generation:\n{gen_text[i]}') + print("=========" * 10) + print(f"Prompt:\n{text_list[i]}") + gen_text[i] = gen_text[i].replace(tokenizer.pad_token, "") + print(f"Generation:\n{gen_text[i]}") # print(f"Outputs ids:\n{outputs[i]}") sys.stdout.flush() @@ -155,11 +150,9 @@ def hf_inference(model, tokenizer, text_list, args=None, max_new_tokens=512, do_ # if you use base + adaptor for inference, provide peft_path or left it None for normal inference base_model = "path/to/basemodel" peft_path = None - model, tokenizer = load_model_tokenizer(base_model, - model_type='', - peft_path=peft_path, - eos_token='', - pad_token='') + model, tokenizer = load_model_tokenizer( + base_model, model_type="", peft_path=peft_path, eos_token="", pad_token="" + ) # hf_inference(model, tokenizer, prompts, do_sample=False, num_beams=1, num_return_sequences=1) hf_inference(model, tokenizer, prompts, do_sample=True, temperature=0.8) diff --git a/mftcoder_accelerate/src/configs/selfpaced_train_config.json b/mftcoder_accelerate/src/configs/coba_train_config.json similarity index 78% rename from mftcoder_accelerate/src/configs/selfpaced_train_config.json rename to mftcoder_accelerate/src/configs/coba_train_config.json index 98007e7..63167f1 100644 --- a/mftcoder_accelerate/src/configs/selfpaced_train_config.json +++ b/mftcoder_accelerate/src/configs/coba_train_config.json @@ -5,16 +5,17 @@ "pretrained_model_path": "$MODEL_NAME_OR_PATH", "model_type": "$MODEL_TYPE", "load_raw_dataset": true, - "data_split": "98,2,0", + "data_split": "95,5,0", "padding_mode": "padding", "use_dynamic_padding": true, "tokenize_mode": "sft", "tokenizer_type": "AutoTokenizer", - "weighted_loss_mode": "selfpaced", - "selfpaced_interval": 1, - "selfpaced_history_length": 100, - "selfpaced_sample_valid_num": 1, - "selfpaced_scale_factor": 50, + "weighted_loss_mode": "coba", + "coba_warmup_steps": 100, + "coba_history_length": 200, + "coba_tau": 5, + "coba_update_interval": 1, + "coba_sample_valid_num": 1, "attn_implementation": "flash_attention_2", "seq_length": 4096, "seed": 1234, @@ -23,8 +24,8 @@ "lora_rank": 96, "lora_alpha": 32, "lora_dropout": 0.05, - "per_device_train_batch_size": 2, - "per_device_eval_batch_size": 2, + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, "learning_rate": 5e-5, "min_lr": 5e-6, "weight_decay": 0.1, @@ -42,4 +43,4 @@ "early_stopping": true, "early_stopping_stall_num": 5, "saving_limit": null -} \ No newline at end of file + } \ No newline at end of file diff --git a/mftcoder_accelerate/src/configs/dpo_train_config.json b/mftcoder_accelerate/src/configs/dpo_train_config.json new file mode 100644 index 0000000..5a93db9 --- /dev/null +++ b/mftcoder_accelerate/src/configs/dpo_train_config.json @@ -0,0 +1,34 @@ +{ + "xxpo": "dpo", + "data_paths": "$DATA_PATHS", + "output_dir": "$OUTPUT_DIR", + "tb_dir": "$TensorBoard_DIR", + "pretrained_model_path": "$MODEL_NAME_OR_PATH", + "model_type": "$MODEL_TYPE", + "data_split": "99,1", + "attn_implementation": "flash_attention_2", + "beta": 0.1, + "rpo_alpha": 0.5, + "peft_type": "lora", + "lora_rank": 64, + "lora_alpha": 128, + "lora_dropout": 0.0, + "per_device_train_batch_size": 1, + "per_device_eval_batch_size": 1, + "tokenizer_type": "AutoTokenizer", + "dataset_num_proc": 1, + "learning_rate": 5e-7, + "weight_decay": 0.01, + "gradient_accumulation_steps": 8, + "lr_scheduler_type": "cosine", + "warmup_steps": 100, + "num_train_epochs": 2, + "seed": 1105, + "max_prompt_length": 2048, + "max_length": 4096, + "logging_steps": 20, + "save_steps": 500, + "eval_steps": 500, + "epoch_checkpointing": false, + "saving_limit": 5 +} \ No newline at end of file diff --git a/mftcoder_accelerate/src/data/blendable_dataset.py b/mftcoder_accelerate/src/data/blendable_dataset.py index 3dd6139..84b9756 100644 --- a/mftcoder_accelerate/src/data/blendable_dataset.py +++ b/mftcoder_accelerate/src/data/blendable_dataset.py @@ -43,7 +43,7 @@ def __init__(self, datasets, weights): # recompute weights weights = self.calc_weights() - + # Build indices. start_time = time.time() assert num_datasets < 255 @@ -63,9 +63,7 @@ def __init__(self, datasets, weights): print( "> RANK {} elapsed time for building blendable dataset indices: " - "{:.2f} (sec)".format( - torch.distributed.get_rank(), time.time() - start_time - ) + "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time) ) def calc_weights(self): @@ -73,7 +71,7 @@ def calc_weights(self): total_cnt = sum(dataset_sample_cnt) weights = np.array([(cnt + 0.0) / total_cnt for cnt in dataset_sample_cnt], dtype=np.float64) return weights - + def __len__(self): return self.size diff --git a/mftcoder_accelerate/src/data/data_utils.py b/mftcoder_accelerate/src/data/data_utils.py index fa79f32..8d168bd 100644 --- a/mftcoder_accelerate/src/data/data_utils.py +++ b/mftcoder_accelerate/src/data/data_utils.py @@ -32,10 +32,7 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): start_time = time.time() indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) - print_rank_0( - " > finished creating indexed dataset in {:4f} " - "seconds".format(time.time() - start_time) - ) + print_rank_0(" > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time)) print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) return indexed_dataset @@ -53,20 +50,22 @@ def build_train_valid_test_datasets( build_index_mappings=True, shuffle_before_split=False, weighted_loss_mode=None, - ds_weights=[1., 1., 1.], - train_mode='sft', + ds_weights=[1.0, 1.0, 1.0], + train_mode="sft", ): """Build train, valid, and test datasets.""" # Indexed dataset. - assert os.path.exists(data_prefix + "_input_ids.bin"), f"Input tokens datafile not found: {data_prefix}_input_ids.bin" + assert os.path.exists( + data_prefix + "_input_ids.bin" + ), f"Input tokens datafile not found: {data_prefix}_input_ids.bin" # Indexed dataset. input_ids_indexed_dataset = get_indexed_dataset_(data_prefix + "_input_ids", data_impl, skip_warmup) - if train_mode == 'sft': + if train_mode == "sft": loss_mask_indexed_dataset = get_indexed_dataset_(data_prefix + "_loss_mask", data_impl, skip_warmup) else: - print(f'pretrain mode, loss mask is ones') + print(f"pretrain mode, loss mask is ones") loss_mask_indexed_dataset = None total_num_of_documents = input_ids_indexed_dataset.sizes.shape[0] @@ -79,9 +78,7 @@ def print_split_stats(name, index): print_rank_0(" {}:".format(name)) print_rank_0( " document indices in [{}, {}) total of {} " - "documents".format( - splits[index], splits[index + 1], splits[index + 1] - splits[index] - ) + "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) ) print_split_stats("train", 0) @@ -100,11 +97,9 @@ def build_dataset(index, name, ds_weight=1.0): dataset = None if splits[index + 1] > splits[index]: if shuffle_before_split: - documents = shuffle_doc_index[splits[index]:splits[index + 1]] + documents = shuffle_doc_index[splits[index] : splits[index + 1]] else: - documents = np.arange( - start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 - ) + documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) dataset = GPT2PromptDataset( name, @@ -130,11 +125,13 @@ def build_dataset(index, name, ds_weight=1.0): return train_dataset, valid_dataset, test_dataset, total_num_of_documents -def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False): +def build_multiple_train_valid_test_datasets( + args, train_valid_test_num_samples, use_shared_fs=True, data_impl="mmap", mmap_warmup=False +): """Build multiple train, valid, and test datasets.""" - data_prefixes = list(args.data_paths[1:-1].split(',')) + data_prefixes = list(args.data_paths[1:-1].split(",")) - data_weights = list(map(float, args.data_weights[1:-1].split(','))) + data_weights = list(map(float, args.data_weights[1:-1].split(","))) print("data weights: ") print(data_weights) use_shared_fs = use_shared_fs @@ -143,7 +140,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, seq_length = args.seq_length # seq_length = args.block_size seed = args.seed - skip_warmup = (not mmap_warmup) + skip_warmup = not mmap_warmup weight_by_num_documents = args.weight_by_num_documents shuffle_before_split = args.shuffle_before_split weighted_loss_mode = args.weighted_loss_mode @@ -183,9 +180,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, factor = 1 if weight_by_num_documents: # gets the number of documents in each data path - get_num_docs_list = lambda datasets: [ - dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets - ] + get_num_docs_list = lambda datasets: [dataset.input_ids_indexed_dataset.sizes.shape[0] for dataset in datasets] train_num_docs, valid_num_docs, test_num_docs = ( get_num_docs_list(train_datasets), get_num_docs_list(valid_datasets), @@ -201,7 +196,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, ) assert sum(train_weights) != 0.0, "found train weights to be 0.0" assert sum(valid_weights) != 0.0, "found valid weights to be 0.0" - + train_weights, train_num_samples = get_normalized_weights_and_num_samples( train_weights, train_valid_test_num_samples[0] ) @@ -265,7 +260,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, if num_tokens: factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length) factor /= sum([1.0 / w for w in train_ds_weights]) / len(train_ds_weights) - + print_rank_0(f"> common denomination factor for CE loss: {factor}") # Blend. @@ -274,7 +269,7 @@ def build_multiple_train_valid_test_datasets(args, train_valid_test_num_samples, i = 0 for ds in train_datasets: ds.update_ds_weight(ds.ds_weight / factor) - print(f'loss weight of dataset {i} after update: {ds.ds_weight}') + print(f"loss weight of dataset {i} after update: {ds.ds_weight}") i += 1 blending_train_dataset = BlendableDataset(train_datasets, train_weights) blending_valid_dataset = None @@ -318,9 +313,7 @@ def get_train_valid_test_split_(splits_string, size): return splits_index -def get_normalized_weights_and_num_samples( - weights: List[float], num_samples: int -) -> Tuple[List[float], List[int]]: +def get_normalized_weights_and_num_samples(weights: List[float], num_samples: int) -> Tuple[List[float], List[int]]: # Normalize weights weight_sum = sum(weights) assert weight_sum > 0.0 @@ -346,12 +339,7 @@ def get_datasets_normalized_weights_and_num_samples( # samples left to feed to the network. weighted_num_samples = [] for weight in weights: - weighted_num_samples.append( - [ - int(math.ceil(val * weight * 1.005)) - for val in num_samples - ] - ) + weighted_num_samples.append([int(math.ceil(val * weight * 1.005)) for val in num_samples]) return weights, weighted_num_samples diff --git a/mftcoder_accelerate/src/data/gpt2_dataset.py b/mftcoder_accelerate/src/data/gpt2_dataset.py index 05aa632..12eeb87 100644 --- a/mftcoder_accelerate/src/data/gpt2_dataset.py +++ b/mftcoder_accelerate/src/data/gpt2_dataset.py @@ -41,7 +41,7 @@ def __init__( use_shared_fs=True, weighted_loss_mode=None, ds_weight=1.0, - train_mode='sft', + train_mode="sft", ): self.name = name @@ -50,9 +50,9 @@ def __init__( self.weighted_loss_mode = weighted_loss_mode self.ds_weight = ds_weight - - self.task_name = data_prefix.split('/')[-1] - + + self.task_name = data_prefix.split("/")[-1] + self.task_id = TASK2ID[self.task_name] # Checks @@ -114,14 +114,10 @@ def __getitem__(self, idx): else: # Otherwise, get the rest of the initial document. - input_ids_list = [ - self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + input_ids_list = [self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] if self.loss_mask_indexed_dataset is not None: - loss_mask_list = [ - self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + loss_mask_list = [self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] else: loss_mask_list = [] @@ -133,16 +129,12 @@ def __getitem__(self, idx): # And finally add the relevant portion of last document. input_ids_list.append( - self.input_ids_indexed_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) + self.input_ids_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) ) if self.loss_mask_indexed_dataset is not None: loss_mask_list.append( - self.loss_mask_indexed_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) + self.loss_mask_indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) ) input_ids = np.concatenate(input_ids_list) @@ -246,18 +238,12 @@ def __getitem__(self, idx): ) else: # Otherwise, get the rest of the initial document. - sample_list = [ - self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) # And finally add the relevant portion of last document. - sample_list.append( - self.indexed_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) - ) + sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) sample = np.concatenate(sample_list) return {"text": np.array(sample, dtype=np.int64)} @@ -313,10 +299,7 @@ def _build_index_mappings( or (not os.path.isfile(sample_idx_filename)) or (not os.path.isfile(shuffle_idx_filename)) ): - print_rank_0( - " > WARNING: could not find index map files, building " - "the indices on rank 0 ..." - ) + print_rank_0(" > WARNING: could not find index map files, building " "the indices on rank 0 ...") # doc-idx. start_time = time.time() doc_idx = _build_doc_idx(documents, num_epochs, np_rng) @@ -338,13 +321,9 @@ def _build_index_mappings( # 我理解这里的num_samples应该是和入参的num_samples重名,这里只是为了计算构建所有索引的长度,从而决定是用int64还是int32 num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length if 2 * (num_samples + 1) < np.iinfo(np.int32).max: - sample_idx = helpers.build_sample_idx_int32( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch - ) + sample_idx = helpers.build_sample_idx_int32(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) else: - sample_idx = helpers.build_sample_idx_int64( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch - ) + sample_idx = helpers.build_sample_idx_int64(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) np.save(sample_idx_filename, sample_idx, allow_pickle=True) print_rank_0( " > elapsed time to build and save sample-idx mapping " @@ -360,7 +339,7 @@ def _build_index_mappings( " > elapsed time to build and save shuffle-idx mapping" " (seconds): {:4f}".format(time.time() - start_time) ) - + torch.distributed.barrier() # TODO: model parallel # This should be a barrier but nccl barrier assumes @@ -370,7 +349,7 @@ def _build_index_mappings( # torch.distributed.all_reduce(counts) # torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) # assert counts[0].item() == torch.distributed.get_world_size( - # group=mpu.get_io_parallel_group() + # group=mpu.get_io_parallel_group() # ) # Load mappings. @@ -381,9 +360,7 @@ def _build_index_mappings( sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") - print_rank_0( - " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) - ) + print_rank_0(" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)) print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) print_rank_0(" total number of epochs: {}".format(num_epochs)) diff --git a/mftcoder_accelerate/src/data/helpers.cpython-38-x86_64-linux-gnu.so b/mftcoder_accelerate/src/data/helpers.cpython-38-x86_64-linux-gnu.so deleted file mode 100755 index 6fbc1b7..0000000 Binary files a/mftcoder_accelerate/src/data/helpers.cpython-38-x86_64-linux-gnu.so and /dev/null differ diff --git a/mftcoder_accelerate/src/data/indexed_dataset.py b/mftcoder_accelerate/src/data/indexed_dataset.py index 9a26379..12ea9c2 100644 --- a/mftcoder_accelerate/src/data/indexed_dataset.py +++ b/mftcoder_accelerate/src/data/indexed_dataset.py @@ -44,17 +44,13 @@ def infer_dataset_impl(path): return None else: print(f"Dataset does not exist: {path}") - print( - "Path should be a basename that both .idx and .bin can be appended to get full filenames." - ) + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") return None def make_builder(out_file, impl, vocab_size=None): if impl == "mmap": - return MMapIndexedDatasetBuilder( - out_file, dtype=__best_fitting_dtype(vocab_size) - ) + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) else: return IndexedDatasetBuilder(out_file) @@ -62,9 +58,7 @@ def make_builder(out_file, impl, vocab_size=None): def make_dataset(path, impl, skip_warmup=False): if not IndexedDataset.exists(path): print(f"Dataset does not exist: {path}") - print( - "Path should be a basename that both .idx and .bin can be appended to get full filenames." - ) + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") return None if impl == "infer": impl = infer_dataset_impl(path) @@ -145,8 +139,7 @@ def read_index(self, path): with open(index_file_path(path), "rb") as f: magic = f.read(8) assert magic == self._HDR_MAGIC, ( - "Index file doesn't match expected format. " - "Make sure that --dataset-impl is configured properly." + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." ) version = f.read(8) assert struct.unpack(" 0: - self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 - + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + def add_data(self, data): self.data.append(data) - + def commit(self, archive_name=None): # TODO: streaming cctx = zstandard.ZstdCompressor(level=3) @@ -354,15 +373,18 @@ def commit(self, archive_name=None): if archive_name is None: archive_name = str(int(time.time())) - res = b''.join(map(lambda x: ("%016d" % len(x)).encode('UTF-8') + x, map(lambda x: x.encode('UTF-8'), self.data))) + res = b"".join( + map(lambda x: ("%016d" % len(x)).encode("UTF-8") + x, map(lambda x: x.encode("UTF-8"), self.data)) + ) cdata = cctx.compress(res) - with open(self.out_dir + '/data_' + str(self.i) + '_' + archive_name + '.dat.zst', 'wb') as fh: + with open(self.out_dir + "/data_" + str(self.i) + "_" + archive_name + ".dat.zst", "wb") as fh: fh.write(cdata) self.i += 1 self.data = [] + class JSONArchive: def __init__(self, out_dir): self.out_dir = out_dir @@ -370,17 +392,17 @@ def __init__(self, out_dir): self.data = [] self.i = 0 if os.path.exists(out_dir) and len(os.listdir(out_dir)) > 0: - self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 - + self.i = max(map(lambda x: int(x.split("_")[1].split(".")[0]), os.listdir(out_dir))) + 1 + def add_data(self, data): self.data.append(data) - + def commit(self): cctx = zstandard.ZstdCompressor(level=3) - - cdata = cctx.compress(json.dumps(self.data).encode('UTF-8')) - with open(self.out_dir + '/data_' + str(self.i) + '_' + str(int(time.time())) + '.json.zst', 'wb') as fh: + + cdata = cctx.compress(json.dumps(self.data).encode("UTF-8")) + with open(self.out_dir + "/data_" + str(self.i) + "_" + str(int(time.time())) + ".json.zst", "wb") as fh: fh.write(cdata) self.i += 1 - self.data = [] \ No newline at end of file + self.data = [] diff --git a/mftcoder_accelerate/src/data/multi_task_dataset.py b/mftcoder_accelerate/src/data/multi_task_dataset.py index 1a8612a..63c4b27 100644 --- a/mftcoder_accelerate/src/data/multi_task_dataset.py +++ b/mftcoder_accelerate/src/data/multi_task_dataset.py @@ -2,11 +2,14 @@ # @author Chaoyu Chen # @date 2023/8/18 +Load dataset in a distributed way. """ + import os import json import math import time +import glob import numpy as np import torch from functools import partial @@ -27,12 +30,12 @@ def __init__( self.name = name self.input_dataset = input_dataset - self.num_samples = len(self.input_dataset['input_ids']) + self.num_samples = len(self.input_dataset["input_ids"]) self.seq_length = seq_length self.weighted_loss_mode = weighted_loss_mode self.ds_weight = ds_weight - self.task_name = data_prefix.split('/')[-1] + self.task_name = data_prefix.split("/")[-1] self.task_id = TASK2ID[self.task_name] # Checks @@ -47,8 +50,7 @@ def __getitem__(self, idx): try: # Get the shuffled index. idx = idx % self.num_samples - idx_data = {key: self.input_dataset[key][idx] - for key in self.input_dataset} + idx_data = {key: self.input_dataset[key][idx] for key in self.input_dataset} if self.weighted_loss_mode: idx_data["weight"] = np.array([self.ds_weight], dtype=np.float32) @@ -115,9 +117,7 @@ def __init__(self, datasets, weights, global_num_samples, local_num_samples): print( "> RANK {} elapsed time for building blendable dataset indices: " - "{:.2f} (sec)".format( - torch.distributed.get_rank(), time.time() - start_time - ) + "{:.2f} (sec)".format(torch.distributed.get_rank(), time.time() - start_time) ) def calc_weights(self): @@ -166,7 +166,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, encoder = UniformEncoder(args, args.tokenize_mode) encoder.initializer() - data_prefixes = list(args.data_paths[1:-1].split(',')) + data_prefixes = list(args.data_paths[1:-1].split(",")) splits = [] splits_string = args.data_split @@ -179,7 +179,7 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, while len(splits) < 3: splits.append(0.0) splits = splits[:3] - print(f'data splits: {splits}') + print(f"data splits: {splits}") all_train_datasets = [] all_valid_datasets = [] @@ -195,45 +195,48 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, # 不同数据集在不同文件夹下 for dataset_index in range(len(data_prefixes)): - files = os.listdir(data_prefixes[dataset_index]) + # files = os.listdir(data_prefixes[dataset_index]) + # get all jsonl files and corresponding reading handler + if data_prefixes[dataset_index].endswith(".jsonl"): + files = [data_prefixes[dataset_index]] + else: + files = glob.glob(os.path.join(data_prefixes[dataset_index], "**/*.jsonl"), recursive=True) + cur_dataset_input_ids = [] cur_dataset_loss_mask = [] # support multiple jsonl files under task dir - for file in files: - file_name = data_prefixes[dataset_index] + '/' + file - if os.path.isdir(file_name): - continue - fin = open(file_name, 'r') - print(f'[Global Rank {global_rank}] open file {file_name}') - - if args.padding_mode == 'padding' or args.padding_mode == 'pack': + for file_name in files: + fin = open(file_name, "r") + print(f"[Global Rank {global_rank}] open file {file_name}") + + if args.padding_mode == "padding" or args.padding_mode == "pack" or args.padding_mode == "concat": for i, line in enumerate(fin): # pre-sharding if shard_data and i % world_size != global_rank: continue - data = json.loads(line.rstrip('\n\r')) + data = json.loads(line.rstrip("\n\r")) features, length = encoder.encode(data, verbose=(i < 1)) # features, length = encoder.encode(data) # may have more samples - for idx in range(len(features['input_ids'])): - cur_dataset_input_ids.append(features['input_ids'][idx]) - cur_dataset_loss_mask.append(features['loss_mask'][idx]) + for idx in range(len(features["input_ids"])): + cur_dataset_input_ids.append(features["input_ids"][idx]) + cur_dataset_loss_mask.append(features["loss_mask"][idx]) fin.close() else: i = 0 for line in fin: - data = json.loads(line.rstrip('\n\r')) + data = json.loads(line.rstrip("\n\r")) features, length = encoder.encode(data) # 一个document可能编码不出sample,可能编码出多个sample - for idx in range(len(features['input_ids'])): + for idx in range(len(features["input_ids"])): # post-sharding if shard_data and i % world_size != global_rank: i += 1 continue i += 1 - cur_dataset_input_ids.append(features['input_ids'][idx]) - cur_dataset_loss_mask.append(features['loss_mask'][idx]) + cur_dataset_input_ids.append(features["input_ids"][idx]) + cur_dataset_loss_mask.append(features["loss_mask"][idx]) fin.close() cur_dataset_input_ids = np.array(cur_dataset_input_ids, dtype=np.float32) @@ -249,56 +252,51 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, train_ratio = splits[0] / 100.0 train_num = int(math.ceil(train_ratio * cur_dataset_sample_num)) # split train/valid - cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[: train_num], cur_dataset_input_ids[train_num:] - cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[: train_num], cur_dataset_loss_mask[train_num:] + cur_train_input_ids, cur_valid_input_ids = cur_dataset_input_ids[:train_num], cur_dataset_input_ids[train_num:] + cur_train_loss_mask, cur_valid_loss_mask = cur_dataset_loss_mask[:train_num], cur_dataset_loss_mask[train_num:] local_train_num += train_num - local_valid_num += (cur_dataset_sample_num - train_num) - - cur_train_dataset = { - 'input_ids': cur_train_input_ids, - 'loss_mask': cur_train_loss_mask - } - cur_valid_dataset = { - 'input_ids': cur_valid_input_ids, - 'loss_mask': cur_valid_loss_mask - } + local_valid_num += cur_dataset_sample_num - train_num + + cur_train_dataset = {"input_ids": cur_train_input_ids, "loss_mask": cur_train_loss_mask} + cur_valid_dataset = {"input_ids": cur_valid_input_ids, "loss_mask": cur_valid_loss_mask} print(f"[Global Rank {global_rank}]shape of cur train dataset: {cur_train_dataset['input_ids'].shape}") - print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") + if local_valid_num > 0: + print(f"[Global Rank {global_rank}]shape of cur valid dataset: {cur_valid_dataset['input_ids'].shape}") cur_train_ds = GPT2FromRawDataset( - 'train', + "train", data_prefixes[dataset_index], cur_train_dataset, args.seq_length, weighted_loss_mode=args.weighted_loss_mode, - ds_weight=splits[0] - ) - cur_valid_ds = GPT2FromRawDataset( - 'valid', - data_prefixes[dataset_index], - cur_valid_dataset, - args.seq_length, - weighted_loss_mode=args.weighted_loss_mode, - ds_weight=splits[1] + ds_weight=splits[0], ) - all_train_datasets.append(cur_train_ds) - all_valid_datasets.append(cur_valid_ds) all_train_datasets_length.append(len(cur_train_ds)) - all_valid_datasets_length.append(len(cur_valid_ds)) - - print(f'[Global Rank {global_rank}]num tokens: {num_tokens}') - print(f'[Global Rank {global_rank}]effective token rate: {effective_token_rate}') + if local_valid_num > 0: + cur_valid_ds = GPT2FromRawDataset( + "valid", + data_prefixes[dataset_index], + cur_valid_dataset, + args.seq_length, + weighted_loss_mode=args.weighted_loss_mode, + ds_weight=splits[1], + ) + all_valid_datasets.append(cur_valid_ds) + all_valid_datasets_length.append(len(cur_valid_ds)) + else: + cur_valid_ds = None + + print(f"[Global Rank {global_rank}]num tokens: {num_tokens}") + print(f"[Global Rank {global_rank}]effective token rate: {effective_token_rate}") num_tokens = [] ds_fn = partial(ds_weights_by_num_docs_sft) - train_loss_weights, valid_loss_weights = ( - ds_fn(all_train_datasets_length), - ds_fn(all_valid_datasets_length), - ) - + train_loss_weights = ds_fn(all_train_datasets_length) print(f"> train loss weights in rank {global_rank}: {train_loss_weights}") - print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") + if all_valid_datasets_length: + valid_loss_weights = ds_fn(all_valid_datasets_length) + print(f"> valid loss weights in rank {global_rank}: {valid_loss_weights}") factor = 1 # calcualte common factor based on token cnt and total sample cnt @@ -306,51 +304,65 @@ def load_dataset_from_jsonl(args, shard_data=False, world_size=1, global_rank=0, factor = sum(num_tokens) / (sum(total_sample_cnt) * args.seq_length) factor /= sum([1.0 / w for w in train_loss_weights]) / len(train_loss_weights) print(f"> common denomination factor for CE loss in rank {global_rank}: {factor}") - + train_sample_weights = [x / sum(all_train_datasets_length) for x in all_train_datasets_length] - valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] print(f"> train sample weights in rank {global_rank}: {train_sample_weights}") - print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") + if all_valid_datasets_length: + valid_sample_weights = [x / sum(all_valid_datasets_length) for x in all_valid_datasets_length] + print(f"> valid sample weights in rank {global_rank}: {valid_sample_weights}") # recompute global_train_num and global_valid_num - + torch.distributed.barrier() device = f"cuda:{local_rank}" - + global_train_num_samples_tensor = torch.tensor(local_train_num, dtype=torch.int32) global_train_num_samples_tensor = global_train_num_samples_tensor.to(device) torch.distributed.all_reduce(global_train_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) global_train_num = global_train_num_samples_tensor.item() - - global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) - global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) - torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) - global_valid_num = global_valid_num_samples_tensor.item() print(f"> global train num in rank {global_rank}: {global_train_num}") - print(f"> global valid num in rank {global_rank}: {global_valid_num}") - + + if local_valid_num > 0: + global_valid_num_samples_tensor = torch.tensor(local_valid_num, dtype=torch.int32) + global_valid_num_samples_tensor = global_valid_num_samples_tensor.to(device) + torch.distributed.all_reduce(global_valid_num_samples_tensor, op=torch.distributed.ReduceOp.SUM) + global_valid_num = global_valid_num_samples_tensor.item() + print(f"> global valid num in rank {global_rank}: {global_valid_num}") + torch.distributed.barrier() - for i in range(len(all_train_datasets)): - print(f'loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}') blending_train_dataset = None if all_train_datasets: + for i in range(len(all_train_datasets)): + print( + f"loss weight of train dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}" + ) args.do_train = True for i in range(len(all_train_datasets)): all_train_datasets[i].update_ds_weight(train_loss_weights[i] / factor) - print(f'loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}') - blending_train_dataset = GPT2BlendableDataset(all_train_datasets, train_sample_weights, global_train_num, local_train_num) + print( + f"loss weight of train dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}" + ) + blending_train_dataset = GPT2BlendableDataset( + all_train_datasets, train_sample_weights, global_train_num, local_train_num + ) - for i in range(len(all_train_datasets)): - print(f'loss weight of valid dataset {i} before update in rank {global_rank}: {all_train_datasets[i].ds_weight}') blending_valid_dataset = None if all_valid_datasets: + for i in range(len(all_valid_datasets)): + print( + f"loss weight of valid dataset {i} before update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" + ) args.do_valid = True for i in range(len(all_valid_datasets)): all_valid_datasets[i].update_ds_weight(valid_loss_weights[i] / factor) - print(f'loss weight of valid dataset {i} after update in rank {global_rank}: {all_train_datasets[i].ds_weight}') - blending_valid_dataset = GPT2BlendableDataset(all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num) - + print( + f"loss weight of valid dataset {i} after update in rank {global_rank}: {all_valid_datasets[i].ds_weight}" + ) + blending_valid_dataset = GPT2BlendableDataset( + all_valid_datasets, valid_sample_weights, global_valid_num, local_valid_num + ) + return blending_train_dataset, blending_valid_dataset @@ -359,11 +371,13 @@ def compile_helper(): is invoked on a single process.""" import os import subprocess + path = os.path.abspath(os.path.dirname(__file__)) ret = subprocess.run(["make", "-C", path]) if ret.returncode != 0: print("Making C++ dataset helpers module failed, exiting.") import sys + sys.exit(1) else: print("Making C++ dataset helpers module successfully.") diff --git a/mftcoder_accelerate/src/data/preprocess_data.py b/mftcoder_accelerate/src/data/preprocess_data.py index 5f6d286..f7226bd 100644 --- a/mftcoder_accelerate/src/data/preprocess_data.py +++ b/mftcoder_accelerate/src/data/preprocess_data.py @@ -8,33 +8,32 @@ import sys import ftfy import glob -print("In preprocess_data.py, sys path:", sys.path) + +# print("In preprocess_data_new.py, sys path:", sys.path) from tokenizer import build_tokenizer -CHAT_COL = 'chat_rounds' -ROLE_COL = 'role' -CONTENT_COL = 'content' +CHAT_COL = "chat_rounds" +ROLE_COL = "role" +CONTENT_COL = "content" -SYSTEM_COL = 'system' -PROMPT_COL = 'prompt' -ANSWER_COL = 'answer' +SYSTEM_COL = "system" +PROMPT_COL = "prompt" +ANSWER_COL = "answer" -TEXT_COL = 'text' +TEXT_COL = "text" -table = {ord(f): ord(t) for f, t in zip( - u',。!?:【】()%#@&1234567890', - u',.!?:[]()%#@&1234567890')} +table = {ord(f): ord(t) for f, t in zip(",。!?:【】()%#@&1234567890", ",.!?:[]()%#@&1234567890")} def content_format(content: str): # Replace non-breaking space with space - content = content.replace('\u202f', ' ').replace('\xa0', ' ') + content = content.replace("\u202f", " ").replace("\xa0", " ") # change chinese punctuation to english ones # text = text.translate(table) - - content += '\n' + # if not content.endswith("\n"): + content += "\n" return content @@ -102,6 +101,13 @@ def is_question_answer_format(data): return False +def is_query_answer_format(data): + if "query" in data and "answer" in data: + return True + else: + return False + + class Encoder(object): tokenizer = None @@ -112,7 +118,7 @@ def initializer(self): # Use Encoder class as a container for global data Encoder.tokenizer = build_tokenizer(self.args) # self.tokenizer = build_tokenizer(self.args) - + def pure_encode(self, content): return Encoder.tokenizer.encode(content, add_special_tokens=False) @@ -126,14 +132,13 @@ def encode(self, text): if len(text_ids) > 0: doc_ids.append(text_ids) if self.args.append_eod: - doc_ids[-1].append(Encoder.tokenizer.eod_id) + doc_ids[-1].append(Encoder.tokenizer.eos_token_id) ids[key] = doc_ids return ids, len(text) class UniformEncoder(Encoder): - - def __init__(self, args, mode='sft'): + def __init__(self, args, mode="sft"): super().__init__(args) self.verbose = False self.mode = mode @@ -149,47 +154,47 @@ def __init__(self, args, mode='sft'): def encode(self, data, verbose=False): self.verbose = verbose - encode_res = { - "input_ids": [], - "loss_mask": [] - } + encode_res = {"input_ids": [], "loss_mask": []} if is_prompt_answer_format(data): - data_type = 'prompt_answer' + data_type = "prompt_answer" elif is_prompt_response_format(data): - data_type = 'prompt_response' + data_type = "prompt_response" elif is_input_output_format(data): - data_type = 'input_output' + data_type = "input_output" elif is_instruction_output_format(data): - data_type = 'instruction_output' + data_type = "instruction_output" elif is_instruction_response_format(data): - data_type = 'instruction_response' + data_type = "instruction_response" elif is_question_response_format(data): - data_type = 'question_response' + data_type = "question_response" elif is_question_answer_format(data): - data_type = 'question_answer' + data_type = "question_answer" + elif is_query_answer_format(data): + data_type = "query_answer" elif is_chatml_format(data): - data_type = 'chatML' + data_type = "chatML" elif is_text_format(data): - data_type = 'text' + data_type = "text" else: raise ValueError( f"data_type does not support" f"please use chatML or prompt/answer, prompt/response, question/response, " - f"instruction/output, input/output, instruction/output or text(only for pretrain)") - + f"instruction/output, input/output, instruction/output or text(only for pretrain)" + ) + length = 0 - if data_type == 'chatML': - for chat in data['chat_rounds']: - length += len(chat['content']) - elif data_type == 'text': - length += len(data['text']) + if data_type == "chatML": + for chat in data["chat_rounds"]: + length += len(chat["content"]) + elif data_type == "text": + length += len(data["text"]) else: - # update key + # update key global PROMPT_COL, ANSWER_COL - PROMPT_COL, ANSWER_COL = tuple(data_type.split('_')) + PROMPT_COL, ANSWER_COL = tuple(data_type.split("_")) length = len(data[PROMPT_COL]) + len(data[ANSWER_COL]) - + for token_res in self._tokenize_fields(data, data_type=data_type): for k, v in token_res.items(): encode_res[k].append(v) @@ -197,30 +202,30 @@ def encode(self, data, verbose=False): return encode_res, length def _tokenize_fields(self, data, data_type): - if self.mode == 'sft': + if self.mode == "sft": if self.args.role_markers: system_marker = self.args.role_markers["system"] user_marker = self.args.role_markers["user"] assistant_marker = self.args.role_markers["assistant"] else: - system_marker = 'system\n' - user_marker = 'human\n' - assistant_marker = 'bot\n' - elif self.mode == 'pretrain': - system_marker = '' - user_marker = '' - assistant_marker = '' + system_marker = "system\n" + user_marker = "human\n" + assistant_marker = "bot\n" + elif self.mode == "pretrain": + system_marker = "" + user_marker = "" + assistant_marker = "" else: raise ValueError(f"tokenize_mode does not support {self.mode}, please use sft or pretrain") - sft_end_marker_ids = [Encoder.tokenizer.eod_id] + sft_end_marker_ids = [Encoder.tokenizer.eos_token_id] # uniform SST,SFT,MFT input_ids = [] loss_mask = [] - if data_type == 'chatML': + if data_type == "chatML": chat = data[CHAT_COL] - if chat[0][ROLE_COL] == 'system': + if chat[0][ROLE_COL] == "system": sys_content_ids = self.pure_encode(system_marker + content_format(chat[0][CONTENT_COL])) chat = chat[1:] input_ids += sys_content_ids @@ -230,15 +235,17 @@ def _tokenize_fields(self, data, data_type): role = r[ROLE_COL] content = r[CONTENT_COL] content = content_format(content) - if (role == 'human' or role == 'user') != (i % 2 == 0): - raise ValueError("Conversation roles must alternate user/assistant/user/assistant/... or human/bot/human/bot/...')") - + if (role == "human" or role == "user") != (i % 2 == 0): + raise ValueError( + "Conversation roles must alternate user/assistant/user/assistant/... or human/bot/human/bot/...')" + ) + # compute loss only for assistant's content and eos token afterward - if role == 'human' or role == 'user': + if role == "human" or role == "user": content_ids = self.pure_encode(user_marker + content + assistant_marker) input_ids += content_ids loss_mask += [0] * len(content_ids) - elif role == 'bot' or role == 'assistant': + elif role == "bot" or role == "assistant" or role == "gpt": content_ids = self.pure_encode(content) + sft_end_marker_ids input_ids += content_ids loss_mask += [1] * len(content_ids) @@ -255,7 +262,7 @@ def _tokenize_fields(self, data, data_type): input_ids += text_ids loss_mask += [1] * len(text_ids) else: - system = data.get(SYSTEM_COL, '') + system = data.get(SYSTEM_COL, "") prompt = data[PROMPT_COL] answer = data[ANSWER_COL] @@ -270,28 +277,28 @@ def _tokenize_fields(self, data, data_type): loss_mask += [0] * len(prompt_ids) + [1] * len(answer_ids) # print(self.mode) - if self.mode == 'pretrain': + if self.mode == "pretrain": # change loss mask to all 1s input_ids = input_ids loss_mask = [1] * len(loss_mask) - elif self.mode == 'sft': + elif self.mode == "sft": # do nothing input_ids = input_ids loss_mask = loss_mask - + if self.verbose: print(f"original data:\n{data}") print(f"decoding back:\n{Encoder.tokenizer.decode(input_ids)}") assert len(input_ids) == len(loss_mask) - if self.args.padding_mode == 'padding': + if self.args.padding_mode == "padding": if len(input_ids) <= self.seq_length: yield self.padding(input_ids, loss_mask) # drop if too long else: yield {} - elif self.args.padding_mode == 'concat': + elif self.args.padding_mode == "concat": input_ids = self.remain_input_ids + input_ids loss_mask = self.remain_loss_mask + loss_mask if len(input_ids) < self.seq_length: @@ -303,15 +310,15 @@ def _tokenize_fields(self, data, data_type): cursor = 0 while cursor + self.seq_length <= len(input_ids): yield { - "input_ids": input_ids[cursor: cursor + self.seq_length], - "loss_mask": loss_mask[cursor: cursor + self.seq_length] + "input_ids": input_ids[cursor : cursor + self.seq_length], + "loss_mask": loss_mask[cursor : cursor + self.seq_length], } cursor = cursor + self.stride self.remain_input_ids = input_ids[cursor:] self.remain_loss_mask = loss_mask[cursor:] assert len(self.remain_input_ids) == len(self.remain_loss_mask) yield {} - elif self.args.padding_mode == 'pack': + elif self.args.padding_mode == "pack": if len(input_ids) > self.seq_length: yield {} elif len(self.remain_input_ids) + len(input_ids) > self.seq_length: @@ -326,19 +333,16 @@ def _tokenize_fields(self, data, data_type): yield {} def padding(self, input_ids, loss_mask): - pad_id = Encoder.tokenizer.pad_id + pad_id = Encoder.tokenizer.pad_token_id assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}" input_ids += [pad_id] * (self.seq_length - len(input_ids)) loss_mask += [0] * (self.seq_length - len(loss_mask)) - return { - "input_ids": input_ids, - "loss_mask": loss_mask - } + return {"input_ids": input_ids, "loss_mask": loss_mask} -def find_jsonl_fnames(inputs): +def find_jsonl_fnames(paths): fnames = [] - for p in inputs.split(","): + for p in paths: if not os.path.isdir(p): if p.endswith(".jsonl"): print(f"loading from {p}") diff --git a/mftcoder_accelerate/src/ds_multinode_launch.sh b/mftcoder_accelerate/src/ds_multinode_launch.sh new file mode 100755 index 0000000..dca0670 --- /dev/null +++ b/mftcoder_accelerate/src/ds_multinode_launch.sh @@ -0,0 +1,44 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2024/5/20 +# Description: # Launch script on Multiple Nodes + +# Run this script on all Nodes. + +# You need to export your number of nodes and number of GPUs per node first. +N_NODE=4 +N_GPU_PER_NODE=8 + +# You need to export $MACHINE_RANK, $MASTER_ADDR, $MASTER_PORT automatically for each Node. + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines $N_NODE \ + --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \ + --use_deepspeed \ + --deepspeed_multinode_launcher 'standard' \ + --zero_stage 2 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'none' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag false \ + --zero3_save_16bit_model false \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank $MACHINE_RANK \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" --distributed_type "deepspeed" \ No newline at end of file diff --git a/mftcoder_accelerate/src/ds_single_launch.sh b/mftcoder_accelerate/src/ds_single_launch.sh index 54f1528..d6c84bb 100755 --- a/mftcoder_accelerate/src/ds_single_launch.sh +++ b/mftcoder_accelerate/src/ds_single_launch.sh @@ -1,11 +1,14 @@ #!/bin/sh # Author: Chaoyu Chen -# Last Modified: 2024/12/11 +# Last Modified: 2023/12/11 # Description: An alternative(Command line) way to launch DeepSpeed training # Launch script on single node N_GPU_PER_NODE=8 +# config path +CONFIG="configs/xxx_train_config.json" + # envs used inside training export OMP_NUM_THREADS=4 export TOKENIZERS_PARALLELISM=False @@ -30,6 +33,6 @@ accelerate launch \ --same_network \ --machine_rank 0 \ --rdzv_backend 'static' \ - pefts/mft_accelerate.py --train_config configs/"xxx_train_config.json" \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ --distributed_type "deepspeed" \ > MFTCoder-training-"$TODAY".log 2>&1 & diff --git a/mftcoder_accelerate/src/ds_zero3_single_launch.sh b/mftcoder_accelerate/src/ds_zero3_single_launch.sh new file mode 100755 index 0000000..5f581c9 --- /dev/null +++ b/mftcoder_accelerate/src/ds_zero3_single_launch.sh @@ -0,0 +1,38 @@ +#!/bin/sh +# Author: Chaoyu Chen +# Last Modified: 2024/5/20 +# Description: An alternative(Command line) way to launch DeepSpeed training + +# Launch script on single node +N_GPU_PER_NODE=8 + +# config path +CONFIG="configs/xxx_train_config.json" + +# envs used inside training +export OMP_NUM_THREADS=4 +export TOKENIZERS_PARALLELISM=False + +TODAY=$(date +%Y-%m%d-%H%M) + +# accelerate launch --config_file accelerate_ds_config.yaml \ +accelerate launch \ + --num_machines 1 \ + --num_processes $N_GPU_PER_NODE \ + --use_deepspeed \ + --zero_stage 3 \ + --offload_optimizer_device 'cpu' \ + --offload_param_device 'cpu' \ + --gradient_accumulation_steps 1 \ + --gradient_clipping 1.0 \ + --zero3_init_flag true \ + --zero3_save_16bit_model true \ + --main_training_function 'main' \ + --mixed_precision 'bf16' \ + --dynamo_backend 'no' \ + --same_network \ + --machine_rank 0 \ + --rdzv_backend 'static' \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ + --distributed_type "deepspeed" \ + > MFTCoder-training-"$TODAY".log 2>&1 & diff --git a/mftcoder_accelerate/src/fsdp_single_launch.sh b/mftcoder_accelerate/src/fsdp_single_launch.sh index 1959274..2dc8f89 100755 --- a/mftcoder_accelerate/src/fsdp_single_launch.sh +++ b/mftcoder_accelerate/src/fsdp_single_launch.sh @@ -1,11 +1,19 @@ #!/bin/sh # Author: Chaoyu Chen -# Last Modified: 2024/12/11 +# Last Modified: 2023/12/11 # Description: An alternative(command line) way to launch FSDP training # Launch script on single node N_GPU_PER_NODE=8 +# config path +CONFIG="configs/xxx_train_config.json" + +# fsdp_transformer_layer_cls_to_wrap, choose the DecoderLayer +WRAP_MODULE="LlamaDecoderLayer" + + + # envs used inside training export OMP_NUM_THREADS=4 export TOKENIZERS_PARALLELISM=False @@ -21,7 +29,7 @@ accelerate launch \ --fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \ --fsdp_state_dict_type=FULL_STATE_DICT \ --fsdp_backward_prefetch_policy=BACKWARD_PRE \ - --fsdp_transformer_layer_cls_to_wrap=LlamaDecoderLayer \ + --fsdp_transformer_layer_cls_to_wrap=$WRAP_MODULE \ --fsdp_offload_params=false \ --main_training_function=main \ --mixed_precision=bf16 \ @@ -29,7 +37,7 @@ accelerate launch \ --same_network \ --machine_rank=0 \ --rdzv_backend=static \ - pefts/mft_accelerate.py --train_config configs/"xxx_train_config.json" \ + pefts/mft_accelerate.py --train_config "$CONFIG" \ --distributed_type "fsdp" \ > MFTCoder-training-"$TODAY".log 2>&1 & diff --git a/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py new file mode 100644 index 0000000..82e0f5d --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/configuration_deepseek.py @@ -0,0 +1,206 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V2. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import DeepseekV2Model, DeepseekV2Config + + >>> # Initializing a Deepseek-V2 style configuration + >>> configuration = DeepseekV2Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size = 1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts = None, + n_routed_experts = None, + ep_size = 1, + routed_scaling_factor = 1.0, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'gready', + n_group = None, + topk_group = None, + num_experts_per_tok = None, + moe_layer_freq = 1, + first_k_dense_replace = 0, + norm_topk_prob = False, + scoring_func = 'softmax', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py new file mode 100644 index 0000000..d1d5e88 --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/modeling_deepseek.py @@ -0,0 +1,1925 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV2Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV2Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) + + +class DeepseekV2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV2MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + elif self.topk_method == "group_limited_greedy": + group_scores = ( + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * self.routed_scaling_factor + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class DeepseekV2MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + # save dtype before computation + input_dtype = hidden_states.dtype + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + # keep dtype same after moe forward + return y.to(input_dtype) + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 +class DeepseekV2FlashAttention2(DeepseekV2Attention): + """ + DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + # print(f"dtype of hidden_states: {hidden_states.dtype}") + # print(f"dtype of q_proj: {self.q_proj.weight.dtype}") + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV2Attention, + "flash_attention_2": DeepseekV2FlashAttention2, +} + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV2MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV2MLP(config) + ) + self.input_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + # print(f"1. dtype of residual: {residual.dtype}") + + hidden_states = self.input_layernorm(hidden_states) + # print(f"2. dtype of hidden_states before attn: {hidden_states.dtype}") + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + # print(f"3. dtype of hidden_states after attn: {hidden_states.dtype}") + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + # print(f"4. dtype of hidden_states after post layernorm: {hidden_states.dtype}") + hidden_states = self.mlp(hidden_states) + # print(f"5. dtype of hidden_states after mlp: {hidden_states.dtype}") + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2PreTrainedModel(PreTrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2Model(DeepseekV2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." + ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM + + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py new file mode 100644 index 0000000..d243771 --- /dev/null +++ b/mftcoder_accelerate/src/model/deepseek_v2/tokenization_deepseek_fast.py @@ -0,0 +1,38 @@ +from typing import List, Optional, Union + + +from transformers.models.llama import LlamaTokenizerFast + + +class DeepseekTokenizerFast(LlamaTokenizerFast): + + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). + """ + if isinstance(ids, int): + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + token = self._tokenizer.id_to_token(index) + tokens.append(token if token is not None else "") + return tokens + + def _convert_id_to_token(self, index: int) -> Optional[str]: + token = self._tokenizer.id_to_token(int(index)) + return token if token is not None else "" diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp new file mode 100644 index 0000000..8458a9b --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_256.cpp @@ -0,0 +1,198 @@ +#include +#include +#include + +// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_256.cpp +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant8matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant8matmul_batched_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_column_compression_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_column_compression( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant4matmul_batched_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant4matmul_batched_column_compression_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched_column_compression( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_column_compression_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_old_cuda(vec, mat, mul, scales, zeros); +} + + +void vecquant4matmul_batched_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_old_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_column_compression_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_column_compression_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant4matmul_batched_column_compression_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant4matmul_batched_column_compression_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_batched_column_compression_old_cuda(vec, mat, mul, scales, zeros); +} + + + +void vecquant8matmul_batched_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_faster( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_faster_cuda(vec, mat, mul, scales, zeros); +} + + +void vecquant8matmul_batched_faster_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_faster_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_faster_old_cuda(vec, mat, mul, scales, zeros); +} + +void vecquant8matmul_batched_column_compression_faster_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_column_compression_faster( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_faster_cuda(vec, mat, mul, scales, zeros); +} + + +void vecquant8matmul_batched_column_compression_faster_old_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +); + +void vecquant8matmul_batched_column_compression_faster_old( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_batched_column_compression_faster_old_cuda(vec, mat, mul, scales, zeros); +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched", &vecquant8matmul_batched, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_old", &vecquant8matmul_batched_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_faster", &vecquant8matmul_batched_faster, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_faster_old", &vecquant8matmul_batched_faster_old, "Vector 8-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant4matmul_batched_old", &vecquant4matmul_batched_old, "Vector 4-bit old Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul_batched_column_compression", &vecquant8matmul_batched_column_compression, "Vector 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant8matmul_batched_column_compression_old", &vecquant8matmul_batched_column_compression_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant8matmul_batched_column_compression_faster", &vecquant8matmul_batched_column_compression_faster, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant8matmul_batched_column_compression_faster_old", &vecquant8matmul_batched_column_compression_faster_old, "Vector old 8-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant4matmul_batched_column_compression_old", &vecquant4matmul_batched_column_compression_old, "Vector old 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); + m.def("vecquant4matmul_batched", &vecquant4matmul_batched, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant4matmul_batched_column_compression", &vecquant4matmul_batched_column_compression, "Vector 4-bit Batched Quantized Matrix Multiplication (CUDA) with weight's column compressed (desc_act)"); +} diff --git a/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu new file mode 100644 index 0000000..b7932cd --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/cache_autogptq_cuda_kernel_256.cu @@ -0,0 +1,1708 @@ +#define _CRT_SECURE_NO_WARNINGS +#include +#include +#include +#include +#include +#include + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM) +// adapted from https://github.com/PanQiWei/AutoGPTQ/blob/main/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu +__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { + unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + unsigned short hsum = reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); + hsum += val; + old = reinterpret_cast(address) & 2 + ? (old & 0xffff) | (hsum << 16) + : (old & 0xffff0000) | hsum; + old = atomicCAS(address_as_ui, assumed, old); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); +} +__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) { + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} +#endif + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant8BatchMatMulColumnCompressionKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +template +__global__ void VecQuant4BatchMatMulColumnCompressionKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +template +__global__ void VecQuant8BatchMatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant4BatchMatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + + + +template +__global__ void VecQuant8BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + +__global__ void VecQuant8BatchMatMulKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + + + +__global__ void VecQuant8BatchMatMulKernel_faster_old( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width +); + + +template +__global__ void VecQuant4BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +); + + +template +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + + +template +__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + + +__global__ void VecQuant8BatchMatMulKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width +); + + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +); + +const int BLOCKWIDTH = 128; +const int BLOCKHEIGHT8 = 32; +const int BLOCKHEIGHT4 = 16; +const int BLOCKHEIGHT_OLD4 = 128; +//const int BLOCKHEIGHT_OLD8 = 128; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + +void vecquant8matmul_batched_column_compression_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3) * 4; + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_batched_cuda", ([&] { + VecQuant8BatchMatMulColumnCompressionKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, height, width + ); + }) + ); + +} + +template +__global__ void VecQuant8BatchMatMulColumnCompressionKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + int weight_total = batch * heads * height * width / 4; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKWIDTH + int h = BLOCKWIDTH * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + int i_w = (w / 4); + int w_bit = (w % 4) * 8; + + int w_index = (batch_shift * height + h + k) * width / 4 + i_w; + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * height + h + k]; + scalar_t zero = zeros[batch_shift * height + h + k]; + w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xFF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + +void vecquant8matmul_batched_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_batched_cuda", ([&] { + VecQuant8BatchMatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, vec_height, height, width, zero_width + ); + }) + ); + +} + +template +__global__ void VecQuant8BatchMatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKHEIGHT8 + int h = BLOCKHEIGHT8 * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= vec_height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + // i is index of mat of block first row + int i = width * h + w; + // if (i >= width * height) { + // return; + // } + int k; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + int w_index = batch_shift * height * width + i + (k_w * width); + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * width + w]; + scalar_t zero; + if (zero_width == width) { + zero = zeros[batch_shift * width + w]; + } else { + zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + } + w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xFF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h * 4 + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 4; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); + + weight[k] = scale * (w_tmp - zero); + } + + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + + + +void vecquant4matmul_batched_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_cuda", ([&] { + VecQuant4BatchMatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, vec_height, height, width, zero_width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKHEIGHT4 + int h = BLOCKHEIGHT4 * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= vec_height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + // i is index of mat of block first row + int i = width * h + w; + int k; + scalar_t w_tmp; + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){ + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + int w_index = batch_shift * height * width + i + (k_w * width); + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * width + w]; + scalar_t zero; + if (zero_width == width) { + zero = zeros[batch_shift * width + w]; + } else { + zero = scalar_t(((as_unsigned(zeros[batch_shift * zero_width + z_w]) >> z_mod) & 0xF)); + } + w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h * 8 + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + +void vecquant4matmul_batched_column_compression_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3) * 8; + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_cuda", ([&] { + VecQuant4BatchMatMulColumnCompressionKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, height, width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulColumnCompressionKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + int weight_total = batch * heads * height * width / 8; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKWIDTH + int h = BLOCKWIDTH * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + int i_w = (w / 8); + int w_bit = (w % 8) * 4; + + int w_index = (batch_shift * height + h + k) * width / 8 + i_w; + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * height + h + k]; + scalar_t zero = zeros[batch_shift * height + h + k]; + w_tmp = ((as_unsigned(mat[w_index]) >> w_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + +void vecquant8matmul_batched_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_batched_old_cuda", ([&] { + VecQuant8BatchMatMulKernel_old<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, vec_height, height, width, zero_width + ); + }) + ); +} + + +template +__global__ void VecQuant8BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKHEIGHT8 + int h = BLOCKWIDTH * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= vec_height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + // i is index of mat of block first row + int i = width * h + w; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + int k_w = k; + int w_index = batch_shift * height * width + i + (k_w * width); + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * width + w]; + scalar_t zero = zeros[batch_shift * width + w]; + w_tmp = as_unsigned(mat[w_index]); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + +void vecquant8matmul_batched_faster_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulKernel_faster<<>>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, vec_height, height, width, zero_width + ); +} + + + +__global__ void VecQuant8BatchMatMulKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + //int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + int h = BLOCKWIDTH * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ float blockvec[BLOCKWIDTH]; + int i = width * h + w; + int k; + float w_tmp; + + float weight[BLOCKWIDTH]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + int k_w = k; + int w_index = batch_shift * height * width + i + (k_w * width); + float scale = __half2float(scales[batch_shift * width + w]); + float zero = __half2float(zeros[batch_shift * width + w]); + w_tmp = as_unsigned(mat[w_index]); + weight[k] = scale *(w_tmp-zero); + } + + float res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = __half2float(vec[vec_index]); + } else { + blockvec[tid] = 0; + } + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < vec_height; ++k){ + float temp_res = weight[k]*blockvec[k]; + res += temp_res; + } + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} + + + + +void vecquant8matmul_batched_column_compression_faster_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulColumnCompressionKernel_faster<<>>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, height, width + ); + +} + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + //int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + int h = BLOCKWIDTH * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ float blockvec[BLOCKWIDTH]; + int k; + float w_tmp; + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH; ++k){ + int w_index = (batch_shift * height + h + k) * width + w; + float scale = __half2float(scales[batch_shift * height + h + k]); + float zero = __half2float(zeros[batch_shift * height + h + k]); + w_tmp = mat[w_index]; + weight[k] = scale * (w_tmp-zero); + } + + float res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = __half2float(vec[vec_index]); + } else { + blockvec[tid] = 0; + } + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k]*blockvec[k]; + } + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} + + + +void vecquant8matmul_batched_column_compression_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_batched_column_compression_old_cuda", ([&] { + VecQuant8BatchMatMulColumnCompressionKernel_old<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, height, width + ); + }) + ); + +} + +template +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKWIDTH + int h = BLOCKWIDTH * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + int w_index = (batch_shift * height + h + k) * width + w; + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * height + h + k]; + scalar_t zero = zeros[batch_shift * height + h + k]; + w_tmp = mat[w_index]; + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h + k < height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + +void vecquant4matmul_batched_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + int zero_width = zeros.size(2); + + dim3 blocks( + (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_old_cuda", ([&] { + VecQuant4BatchMatMulKernel_old<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, vec_height, height, width, zero_width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width, + int zero_width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKHEIGHT_OLD4 + int h = BLOCKHEIGHT_OLD4 * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= vec_height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + // i is index of mat of block first row + int i = width * h + w; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){ + int k_w = (k / 2); + int k_bit = (k % 2) * 4; + int w_index = batch_shift * height * width + i + (k_w * width); + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * width + w]; + scalar_t zero = zeros[batch_shift * width + w]; + w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * vec_height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h*2 + k < vec_height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + + + +void vecquant4matmul_batched_column_compression_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int height = vec.size(3); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKHEIGHT_OLD4 - 1) / BLOCKHEIGHT_OLD4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_batched_column_compression_old_cuda", ([&] { + VecQuant4BatchMatMulColumnCompressionKernel_old<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), + batch, heads, vec_row, height, width + ); + }) + ); + +} + +template +__global__ void VecQuant4BatchMatMulColumnCompressionKernel_old( + const scalar_t* __restrict__ vec, + const uint8_t* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const scalar_t* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int height, + int width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + // h is index of height with step being BLOCKWIDTH + int h = BLOCKHEIGHT_OLD4 * blockIdx.x; + // w is index of width with step being 1 + int w = BLOCKWIDTH * blockIdx.y + tid; + if (w >= width && tid >= height) { + return; + } + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int k; + scalar_t w_tmp; + + float weight[BLOCKWIDTH]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){ + int k_w = (k / 2); + int k_bit = (k % 2) * 4; + int w_index = (batch_shift * height + h + k) * width + k_w; + if (w_index >= weight_total || w >= width) { + weight[k] = 0; + } else { + scalar_t scale = scales[batch_shift * height + h + k]; + scalar_t zero = zeros[batch_shift * height + h + k]; + w_tmp = ((as_unsigned(mat[w_index]) >> k_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int vr = 0; vr < vec_row; ++vr){ + res = 0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + if (vec_index < input_total) { + blockvec[tid] = vec[vec_index]; + } else { + blockvec[tid] = 0; + } + + __syncthreads(); + for (k = 0; k < BLOCKWIDTH && h*2 + k < height; ++k){ + // res is the dot product of BLOCKWIDTH elements (part of width) + res += weight[k] * blockvec[k]; + } + // add res to the final result, final matrix shape: (batch, vec_row, width) + int out_index = (batch_shift * vec_row + vr) * width + w; + if (out_index < out_total) { + atomicAdd(&mul[out_index], res); + } + __syncthreads(); + } + } + } +} + + + + + +void vecquant8matmul_batched_faster_old_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); + int vec_height = vec.size(3); + int height = mat.size(2); + int width = mat.size(3); + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulKernel_faster_old<<>>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, vec_height, height, width + ); +} + + +__global__ void VecQuant8BatchMatMulKernel_faster_old( + const half* __restrict__ vec, + const uint8_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, + int vec_height, + int height, + int width +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * vec_height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + const int BLOCKWIDTH_half = BLOCKWIDTH/2; + + int h = BLOCKWIDTH * blockIdx.x; //head_dim, dim=-1 + int w = BLOCKWIDTH * blockIdx.y + tid; //seq-len, +0-256 ,dim=-2 + /* + if (w >= width && tid >= vec_height) { + return; + } + */ + __shared__ half blockvec[BLOCKWIDTH]; //256 + int i = width * h + w; + int k; + + half w_tmp1 = __float2half(0); + half w_tmp2 = __float2half(0); + + half2 weight[BLOCKWIDTH_half]; + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + //int zero_index = batch_shift; + for (k = 0; k < BLOCKWIDTH_half; ++k){ + int w_index1 = batch_shift * height * width + i + (2 * k * width); // [batch,head,h+k, w] + int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width); + int zero_index = batch_shift * width + w; // [batch,head, w] + if (w_index1 >= weight_total || w >= width || (2 * k + h) >= height) { + weight[k] = __float2half2_rn(0); + } else { + float zero_f=__half2float(zeros[zero_index]); + float scale_f= __half2float(scales[zero_index]); + if (w_index2 >= weight_total){ + w_tmp1 = __float2half((as_unsigned(mat[w_index1]) -zero_f)*scale_f); + w_tmp2 = __float2half(0); + weight[k] = __halves2half2(w_tmp1,w_tmp2); + //printf("zero_index is %d w is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,w,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k])); + }else{ + w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1])); + w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2])); + + //weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero,zero)),__halves2half2(scale,scale)); + weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f))); + //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k])); + } + } + } + + + for (int vr = 0; vr < vec_row; ++vr){ + float res=0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + int out_index = (batch_shift * vec_row + vr) * width + w; + if (vec_index < input_total) { + //blockvec[tid] = __half2float(vec[vec_index]);// [batch, head, vr, tid(seq_len dim+)] + blockvec[tid] = vec[vec_index]; + //printf("width is %d height is %d h is %d w is %d vec_index is %d out_index is %d vec_row is %d vec_height is %d,vr is %d tid is %d blockvec is %f\n",width,height, h,w,vec_index,out_index,vec_row,vec_height,vr,tid,blockvec[tid]); + } else { + blockvec[tid] = __float2half(0); + } + __syncthreads(); + if (out_index < out_total) { + for (k = 0; k < BLOCKWIDTH_half; ++k){ + half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1])); + res += __low2float(res2) + __high2float(res2); + } + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} + + +void vecquant8matmul_batched_column_compression_faster_old_cuda( + torch::Tensor vec, // [batch,heads, seq_q, seq_v] + torch::Tensor mat, // [batch,heads, seq_v, head_dim] + torch::Tensor mul, // [batch,heads, seq_q,head_dim] + torch::Tensor scales, // [batch,heads, head_dim] + torch::Tensor zeros +) { + int batch = vec.size(0); + int heads = vec.size(1); + int vec_row = vec.size(2); //ql + int height = mat.size(2); //vl + int width = mat.size(3); //head_dim + + dim3 blocks( + (height + BLOCKWIDTH - 1) / BLOCKWIDTH, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + VecQuant8BatchMatMulColumnCompressionKernel_faster_old<<>>( + (half*) vec.data_ptr(), + (uint8_t*) mat.data_ptr(), + (half*) mul.data_ptr(), + (half*) scales.data_ptr(), + (half*) zeros.data_ptr(), + batch, heads, vec_row, height, width + ); + +} + + +__global__ void VecQuant8BatchMatMulColumnCompressionKernel_faster_old( + const half* __restrict__ vec, // [batch,heads, seq_q, seq_v] + const uint8_t* __restrict__ mat, // [batch,heads, seq_v, head_dim] + half* __restrict__ mul, // [batch,heads, seq_q,head_dim] + const half* __restrict__ scales, // [batch,heads, seq_v] + const half* __restrict__ zeros, + int batch, + int heads, + int vec_row, //seq_q + int height, //seq_v + int width //head_dim +) { + int weight_total = batch * heads * height * width; + int input_total = batch * heads * vec_row * height; + int out_total = batch * heads * vec_row * width; + int tid = threadIdx.x; + int h = BLOCKWIDTH * blockIdx.x; // vl + int w = BLOCKWIDTH * blockIdx.y + tid; //head_dim + block + if (w >= width && tid >= height) { + return; + } + __shared__ half blockvec[BLOCKWIDTH]; + int k; + half w_tmp1 = __float2half(0); + half w_tmp2 = __float2half(0); + int i = width * h + w; + const int BLOCKWIDTH_half = BLOCKWIDTH/2; + half2 weight[BLOCKWIDTH_half]; + + for (int b = 0; b < batch; ++b){ + for (int head = 0; head < heads; ++head){ + int batch_shift = b * heads + head; + //int zero_index = batch_shift; + for (k = 0; k < BLOCKWIDTH_half; ++k){ + int w_index1 = batch_shift * height * width + i + (2 * k) * width; // [batch,head, h+k, w] + int w_index2 = batch_shift * height * width + i + ((2 * k + 1) * width); + int zero_index1 = batch_shift * height + h + 2*k; // [batch,head, w] + int zero_index2 = batch_shift * height + h + 2*k+1; // [batch,head, w] + + if (w_index1 >= weight_total || (2 * k + h)>=height) { + weight[k]=__float2half2_rn(0); + } else{ + //int zero_index = batch_shift + h; // [batch,head, w] + //float scale_f1 = __half2float(scales[zero_index1]); + //float zero_f1 = __half2float(zeros[zero_index1]); + if (w_index2>=weight_total){ + w_tmp1 = __float2half((as_unsigned(mat[w_index1]) - __half2float(zeros[zero_index1]))* __half2float(scales[zero_index1])); + w_tmp2 = __float2half(0); + weight[k] = __halves2half2(w_tmp1,w_tmp2); + //printf("zero_index is %d k is %d w is %d head is %d height is %d width is %d w_index1 is %d w_tmp1 is %f w_tmp2 is %f zero is %f scale is %f low is %f high is %f \n ",zero_index,k,w,head,height, width,w_index1,__half2float(w_tmp1),__half2float(w_tmp2),zero_f,scale_f,__low2float(weight[k]),__high2float(weight[k])); + }else{ + w_tmp1 = __int2half_rn(as_unsigned(mat[w_index1])); + w_tmp2 = __int2half_rn(as_unsigned(mat[w_index2])); + half zero1=zeros[zero_index1]; + half zero2=zeros[zero_index2]; + half scale1=scales[zero_index1]; + half scale2=scales[zero_index2]; + weight[k] = __hmul2(__hsub2(__halves2half2(w_tmp1,w_tmp2), __halves2half2(zero1,zero2)),__halves2half2(scale1,scale2)); + //weight[k] = __hfma2(__halves2half2(w_tmp1,w_tmp2), __float2half2_rn(scale_f), __float2half2_rn(-(scale_f * zero_f))); + //printf("zero_index1 is %d zero_index2 is %d k is %d head is %d w is %d h is %d height is %d width is %d w_index1 is %d w_index2 is %d zero is %f scale is %f low is %f high is %f \n ",zero_index1,zero_index2,k,head,w,h,height, width,w_index1,w_index2,__half2float(zero1),__half2float(scale1),__low2float(weight[k]),__high2float(weight[k])); + } + } + } + + + for (int vr = 0; vr < vec_row; ++vr){ + float res=0; + int vec_index = (batch_shift * vec_row + vr) * height + blockIdx.x * BLOCKWIDTH + tid; + int out_index = (batch_shift * vec_row + vr) * width + w; + + if (vec_index < input_total) { + //blockvec[tid] = __half2float(vec[vec_index]); + blockvec[tid] = vec[vec_index]; + //printf("vec_index is %d out_index is %d vec_row is %d ,vr is %d tid is %d blockvec is %f\n",vec_index,out_index,vec_row,vr,tid,blockvec[tid]); + } else { + blockvec[tid] = __float2half(0); + //blockvec[tid] = 0; + } + __syncthreads(); + if (out_index < out_total) { + for (k = 0; k < BLOCKWIDTH_half; ++k){ + half2 res2 = __hmul2(weight[k],__halves2half2(blockvec[2*k],blockvec[2*k+1])); + res += __low2float(res2) + __high2float(res2); + } + atomicAdd(&mul[out_index], __float2half(res)); + } + __syncthreads(); + } + } + } +} diff --git a/mftcoder_accelerate/src/model/qwen/configuration_qwen.py b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py index 2ccfc92..f8fe2cb 100644 --- a/mftcoder_accelerate/src/model/qwen/configuration_qwen.py +++ b/mftcoder_accelerate/src/model/qwen/configuration_qwen.py @@ -35,6 +35,9 @@ def __init__( intermediate_size=22016, no_bias=True, tie_word_embeddings=False, + use_cache_quantization=False, + use_cache_kernel=False, + softmax_in_fp32=False, **kwargs, ): self.vocab_size = vocab_size @@ -59,6 +62,9 @@ def __init__( self.use_logn_attn = use_logn_attn self.use_flash_attn = use_flash_attn self.no_bias = no_bias + self.use_cache_quantization = use_cache_quantization + self.use_cache_kernel = use_cache_kernel + self.softmax_in_fp32 = softmax_in_fp32 super().__init__( tie_word_embeddings=tie_word_embeddings, **kwargs diff --git a/mftcoder_accelerate/src/model/qwen/cpp_kernels.py b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py new file mode 100644 index 0000000..d9cee70 --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/cpp_kernels.py @@ -0,0 +1,55 @@ +from torch.utils import cpp_extension +import pathlib +import os +import subprocess + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") + +# Check if cuda 11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) +if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + if int(bare_metal_minor) >= 7: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_90,code=sm_90') + +# Build path +srcpath = pathlib.Path(__file__).parent.absolute() +buildpath = srcpath / 'build' +_create_build_dir(buildpath) + +def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3', ], + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag, + verbose=1 + ) + +extra_flags = [] + +cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp", + "./cache_autogptq_cuda_kernel_256.cu"] +cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags) diff --git a/mftcoder_accelerate/src/model/qwen/modeling_qwen.py b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py index 05264b4..45c0d16 100644 --- a/mftcoder_accelerate/src/model/qwen/modeling_qwen.py +++ b/mftcoder_accelerate/src/model/qwen/modeling_qwen.py @@ -3,14 +3,16 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import copy import importlib import math +import pathlib from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator import torch import torch.nn.functional as F import torch.utils.checkpoint -from torch.cuda.amp import autocast +import warnings from torch.nn import CrossEntropyLoss from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList @@ -35,6 +37,8 @@ SUPPORT_CUDA = torch.cuda.is_available() SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 +SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 + from .configuration_qwen import QWenConfig from .qwen_generation_utils import ( @@ -74,10 +78,10 @@ apply_rotary_emb_func = None rms_norm = None flash_attn_unpadded_func = None - +flash_attn_func = None def _import_flash_attn(): - global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func + global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func try: from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func apply_rotary_emb_func = __apply_rotary_emb_func @@ -98,20 +102,49 @@ def _import_flash_attn(): try: import flash_attn + _flash_attn_func = None if not hasattr(flash_attn, '__version__'): from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func else: if int(flash_attn.__version__.split(".")[0]) >= 2: + if int(flash_attn.__version__.split(".")[1]) >= 1: + from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func else: from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func flash_attn_unpadded_func = __flash_attn_unpadded_func + flash_attn_func = _flash_attn_func except ImportError: logger.warn( "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " "https://github.com/Dao-AILab/flash-attention" ) +def quantize_cache_v(fdata, bits, qmax, qmin): + # b, s, head, h-dim->b, head, s, h-dim + qtype = torch.uint8 + device = fdata.device + shape = fdata.shape + + fdata_cal = torch.flatten(fdata, 2) + fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) + fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) + # Compute params + if qmax.device != fmax.device: + qmax = qmax.to(device) + qmin = qmin.to(device) + scale = (fmax - fmin) / (qmax - qmin) + zero = qmin - fmin / scale + scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + # Quantize + res_data = fdata / scale + zero + qdata = torch.clamp(res_data, qmin, qmax).to(qtype) + return qdata.contiguous(), scale, zero + +def dequantize_cache_torch(qdata, scale, zero): + data = scale * (qdata - zero) + return data class FlashSelfAttention(torch.nn.Module): def __init__( @@ -151,6 +184,12 @@ def forward(self, q, k, v, attention_mask=None): assert all((i.is_cuda for i in (q, k, v))) batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = k.shape[1] + seqlen_out = seqlen_q + + if flash_attn_func is not None and batch_size == 1: + dropout_p = self.dropout_p if self.training else 0 + output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal) + return output q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] cu_seqlens_q = torch.arange( @@ -161,12 +200,13 @@ def forward(self, q, k, v, attention_mask=None): device=q.device, ) - if attention_mask is not None: + if batch_size > 1 and attention_mask is not None: k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask) - v = v[indices_k] - if seqlen_q == seqlen_k: + if q.size(0) == v.size(0): q = q[indices_k] cu_seqlens_q = cu_seqlens_k + seqlen_q = seqlen_k + v = v[indices_k] else: cu_seqlens_k = torch.arange( 0, @@ -196,8 +236,8 @@ def forward(self, q, k, v, attention_mask=None): softmax_scale=self.softmax_scale, causal=is_causal, ) - if attention_mask is not None and seqlen_q == seqlen_k: - output = self.pad_input(output, indices_k, batch_size, seqlen_q) + if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k: + output = self.pad_input(output, indices_k, batch_size, seqlen_out) else: new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:] output = output.view(new_shape) @@ -254,99 +294,100 @@ def __init__(self, config): self.register_buffer("logn_tensor", logn_tensor, persistent=False) self.attn_dropout = nn.Dropout(config.attn_dropout_prob) - - def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) + self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False + self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False + self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False + cache_dtype = torch.float + if self.bf16: + cache_dtype=torch.bfloat16 + elif config.fp16: + cache_dtype = torch.float16 + self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) + self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) + + if config.use_cache_quantization and config.use_cache_kernel: + # pre check if the support files existing + module_root = pathlib.Path(__file__).parent + src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") + if any(not (module_root/src).is_file() for src in src_files): + warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") + self.cache_kernels = None + else: + try: + from .cpp_kernels import cache_autogptq_cuda_256 + self.cache_kernels = cache_autogptq_cuda_256 + except ImportError: + warnings.warn("Failed to import KV cache kernels.") + self.cache_kernels = None + + def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None): + device = query.device + if self.use_cache_quantization: + qk, qk_scale, qk_zero = key + if self.use_cache_kernel and self.cache_kernels is not None: + shape = query.shape[:-1] + (qk.shape[-2],) + attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_faster_old( + query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), + qk.transpose(-1, -2).contiguous(), + attn_weights, + qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), + qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) + # attn_weights = attn_weights.to(query.dtype).contiguous() + else: + key = dequantize_cache_torch(qk, qk_scale, qk_zero) + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + else: + attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], - value.size(-1) ** 0.5, - dtype=attn_weights.dtype, - device=attn_weights.device, - ) + if self.use_cache_quantization: + size_temp = value[0].size(-1) + else: + size_temp = value.size(-1) + attn_weights = attn_weights / (size_temp ** 0.5) - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = registered_causal_mask[ - :, :, key_length - query_length : key_length, :key_length - ] mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) - - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def _upcast_and_reordered_attn( - self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None - ): - bsz, num_heads, q_seq_len, dk = query.size() - _, _, k_seq_len, _ = key.size() - - attn_weights = torch.empty( - bsz * num_heads, - q_seq_len, - k_seq_len, - dtype=torch.float32, - device=query.device, - ) - - scale_factor = 1.0 - if self.scale_attn_weights: - scale_factor /= float(value.size(-1)) ** 0.5 - - with autocast(enabled=False): - q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape( - -1, dk, k_seq_len - ) - attn_weights = torch.baddbmm( - attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor + if causal_mask is not None: + attn_weights = torch.where( + causal_mask, attn_weights.to(attn_weights.dtype), mask_value ) - attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) - - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = registered_causal_mask[ - :, :, key_length - query_length : key_length, :key_length - ] - mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + if self.softmax_in_fp32: + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if attn_weights.dtype != torch.float32: - raise RuntimeError( - "Error with upcasting, attn_weights does not have dtype torch.float32" - ) - attn_weights = attn_weights.type(value.dtype) + attn_weights = attn_weights.type(query.dtype) attn_weights = self.attn_dropout(attn_weights) if head_mask is not None: attn_weights = attn_weights * head_mask - attn_output = torch.matmul(attn_weights, value) + if self.use_cache_quantization: + qv, qv_scale, qv_zero = value + if self.use_cache_kernel and self.cache_kernels is not None: + shape = attn_weights.shape[:-1] + (query.shape[-1],) + attn_output = torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( + attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), + qv.contiguous(), # dtype: int32 + attn_output, + qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), + qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) + if attn_output.dtype != query.dtype: + attn_output = attn_output.to(query.dtype) + attn_weights = attn_weights.to(query.dtype) + else: + value = dequantize_cache_torch(qv, qv_scale, qv_zero) + attn_output = torch.matmul(attn_weights, value) + else: + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2) return attn_output, attn_weights @@ -363,8 +404,7 @@ def _merge_heads(self, tensor, num_heads, attn_head_size): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb_list: Optional[List[torch.Tensor]] = None, - registered_causal_mask: Optional[torch.Tensor] = None, + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -373,7 +413,6 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, ): - mixed_x_layer = self.c_attn(hidden_states) query, key, value = mixed_x_layer.split(self.split_size, dim=2) @@ -405,20 +444,49 @@ def forward( query = torch.cat(query_list, dim=0) key = torch.cat(key_list, dim=0) + if self.use_cache_quantization: + key = quantize_cache_v(key.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + value = quantize_cache_v(value.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + + if layer_past is not None: past_key, past_value = layer_past[0], layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) + if self.use_cache_quantization: + # use_cache_quantization: + # present=((q_key,key_scale,key_zero_point), + # (q_value,value_scale,value_zero_point)) + key = (torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2)) + value = (torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), dim=2), + torch.cat((past_value[2], value[2]), dim=2)) + else: + # not use_cache_quantization: + # present=(key,value) + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) if use_cache: present = (key, value) else: present = None - if self.use_logn_attn and not self.training: - seq_start = key.size(1) - query.size(1) - seq_end = key.size(1) - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :] + key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) + if key_size > self.seq_length and self.use_logn_attn and not self.training: + if self.use_cache_quantization: + seq_start = key[0].size(2) - query.size(1) + seq_end = key[0].size(2) + else: + seq_start = key.size(1) - query.size(1) + seq_end = key.size(1) + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) query = query * logn_tensor.expand_as(query) if ( @@ -428,29 +496,46 @@ def forward( and query.is_cuda ): q, k, v = query, key, value - context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask) - - # b s h d -> b s (h d) - context_layer = context_layer.flatten(2,3).contiguous() - + attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) else: + key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) + if query.size(1) == key_size: + causal_mask = torch.tril( + torch.ones((key_size, key_size), dtype=torch.bool, device=query.device) + ).view(1, 1, key_size, key_size) + else: + causal_mask = None query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) + if not self.use_cache_quantization: + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) if ( - registered_causal_mask is None + causal_mask is None and self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32 and not query.is_cuda ): raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED) - attn_output, attn_weight = self._attn( - query, key, value, registered_causal_mask, attention_mask, head_mask - ) - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) + + if not self.use_cache_quantization and SUPPORT_TORCH2: + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, -1, query.size(2), -1) + if causal_mask is not None: + attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) + else: + attention_mask = causal_mask + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ).transpose(1, 2) + attn_weight = None + else: + attn_output, attn_weight = self._attn( + query, key, value, causal_mask, attention_mask, head_mask + ) + context_layer = self._merge_heads( + attn_output, self.num_heads, self.head_dim + ) attn_output = self.c_proj(context_layer) @@ -462,6 +547,8 @@ def forward( and not self.is_fp32 ): raise ValueError("Cannot output attentions while using flash-attn") + elif not self.use_cache_quantization and SUPPORT_TORCH2: + raise ValueError("Cannot output attentions while using scaled_dot_product_attention") else: outputs += (attn_weight,) @@ -487,6 +574,7 @@ def forward(self, hidden_states): output = self.c_proj(intermediate_parallel) return output + class QWenBlock(nn.Module): def __init__(self, config): super().__init__() @@ -508,8 +596,7 @@ def __init__(self, config): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb_list: Optional[List[torch.Tensor]] = None, - registered_causal_mask: Optional[torch.Tensor] = None, + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -523,7 +610,6 @@ def forward( attn_outputs = self.attn( layernorm_output, rotary_pos_emb_list, - registered_causal_mask=registered_causal_mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask, @@ -557,6 +643,7 @@ class QWenPreTrainedModel(PreTrainedModel): is_parallelizable = False supports_gradient_checkpointing = True _no_split_modules = ["QWenBlock"] + _skip_keys_device_placement = "past_key_values" def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -597,6 +684,7 @@ def __init__(self, config): self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.embed_dim = config.hidden_size + self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False self.gradient_checkpointing = False self.use_dynamic_ntk = config.use_dynamic_ntk @@ -622,21 +710,6 @@ def __init__(self, config): self.use_flash_attn = config.use_flash_attn self.is_fp32 = not (config.bf16 or config.fp16) - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - ): - self.registered_causal_mask = None - else: - max_positions = config.max_position_embeddings - self.register_buffer( - "registered_causal_mask", - torch.tril( - torch.ones((max_positions, max_positions), dtype=torch.bool) - ).view(1, 1, max_positions, max_positions), - persistent=False, - ) self.h = nn.ModuleList( [ @@ -721,8 +794,10 @@ def forward( past_length = 0 past_key_values = tuple([None] * len(self.h)) else: - past_length = past_key_values[0][0].size(-2) - + if self.use_cache_quantization: + past_length = past_key_values[0][0][0].size(2) + else: + past_length = past_key_values[0][0].size(-2) if position_ids is None: position_ids = torch.arange( past_length, @@ -750,7 +825,10 @@ def forward( kv_seq_len = hidden_states.size()[1] if past_key_values[0] is not None: # past key values[0][0] shape: bs * seq_len * head_num * dim - kv_seq_len += past_key_values[0][0].shape[1] + if self.use_cache_quantization: + kv_seq_len += past_key_values[0][0][0].shape[2] + else: + kv_seq_len += past_key_values[0][0].shape[1] if self.training or not self.use_dynamic_ntk: ntk_alpha_list = [1.0] @@ -768,11 +846,9 @@ def forward( ntk_alpha = self.get_ntk_alpha(kv_seq_len) ntk_alpha_list.append(ntk_alpha) self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list - - rotary_pos_emb_list = [] - for ntk_alpha in ntk_alpha_list: - rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) - rotary_pos_emb_list.append(rotary_pos_emb) + rotary_pos_emb_list = [ + self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list + ] hidden_states = self.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) @@ -805,7 +881,6 @@ def custom_forward(*inputs): create_custom_forward(block), hidden_states, rotary_pos_emb_list, - self.registered_causal_mask, None, attention_mask, head_mask[i], @@ -817,7 +892,6 @@ def custom_forward(*inputs): hidden_states, layer_past=layer_past, rotary_pos_emb_list=rotary_pos_emb_list, - registered_causal_mask=self.registered_causal_mask, attention_mask=attention_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, @@ -861,11 +935,6 @@ def __init__(self, config): assert ( config.bf16 + config.fp16 + config.fp32 <= 1 ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" - logger.warn( - "Warning: please make sure that you are using the latest codes and checkpoints, " - "especially if you used Qwen-7B before 09.25.2023." - "请使用最新模型和代码,尤其如果你在9月25日前已经开始使用Qwen-7B,千万注意不要使用错误代码和模型。" - ) autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 @@ -927,22 +996,13 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs ): - token_type_ids = kwargs.get("token_type_ids", None) if past_key_values: input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + if input_ids.size(0) == 1: + attention_mask = None else: - position_ids = None + attention_mask = kwargs.get("attention_mask", None) if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -953,9 +1013,7 @@ def prepare_inputs_for_generation( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, "attention_mask": attention_mask, - "token_type_ids": token_type_ids, } ) return model_inputs @@ -1042,7 +1100,6 @@ def chat( query: str, history: Optional[HistoryType], system: str = "You are a helpful assistant.", - append_history: bool = True, stream: Optional[bool] = _SENTINEL, stop_words_ids: Optional[List[List[int]]] = None, generation_config: Optional[GenerationConfig] = None, @@ -1054,6 +1111,10 @@ def chat( assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT if history is None: history = [] + else: + # make a copy of the user's input such that is is left untouched + history = copy.deepcopy(history) + if stop_words_ids is None: stop_words_ids = [] @@ -1091,8 +1152,11 @@ def chat( errors='replace' ) - if append_history: - history.append((query, response)) + # as history is a copy of the user inputs, + # we can always return the new turn to the user. + # separating input history and output history also enables the user + # to implement more complex history management + history.append((query, response)) return response, history @@ -1220,8 +1284,7 @@ def __init__(self, dim, base=10000): self._ntk_alpha_cached = 1.0 self._ntk_alpha_cached_list = [1.0] - def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): - seqlen = max_seq_len + offset + def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) self.inv_freq = 1.0 / ( @@ -1244,10 +1307,10 @@ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): cos, sin = emb.cos(), emb.sin() self._rotary_pos_emb_cache = [cos, sin] - def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): - self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) + def forward(self, max_seq_len, ntk_alpha=1.0): + self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha) cos, sin = self._rotary_pos_emb_cache - return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]] + return [cos[:, :max_seq_len], sin[:, :max_seq_len]] def _rotate_half(x): @@ -1259,21 +1322,28 @@ def _rotate_half(x): def apply_rotary_pos_emb(t, freqs): + """ Apply rotary embedding to the first rotary_dim of the iput + + Arguments: + t (tensor(batch_size, seq_len, n_head, head_dim)): + the input embedding/hidden states + freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]): + the cached cos/sin position embeddings + """ + rot_dim = freqs[0].shape[-1] cos, sin = freqs + t_float = t.float() if apply_rotary_emb_func is not None and t.is_cuda: - t_ = t.float() - cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2] - sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2] - output = apply_rotary_emb_func(t_, cos, sin).type_as(t) - return output + # apply_rotary_emb in flash_attn requires cos/sin to be of + # shape (seqlen, rotary_dim / 2) and apply rotary embedding + # to the first rotary_dim of the input + cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] + sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2] + return apply_rotary_emb_func(t_float, cos, sin).type_as(t) else: - rot_dim = freqs[0].shape[-1] - cos, sin = freqs - t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] - t_ = t_.float() - t_pass_ = t_pass_.float() - t_ = (t_ * cos) + (_rotate_half(t_) * sin) - return torch.cat((t_, t_pass_), dim=-1).type_as(t) + t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:] + t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) + return torch.cat((t_rot, t_pass), dim=-1).type_as(t) class RMSNorm(torch.nn.Module): diff --git a/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py index 4a47c7a..2a526d6 100644 --- a/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py +++ b/mftcoder_accelerate/src/model/qwen/tokenization_qwen.py @@ -27,11 +27,22 @@ # regular texts, the surface forms of special tokens need to be # as different as possible to minimize the impact EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) -SPECIAL_TOKENS = ( - ENDOFTEXT, - IMSTART, - IMEND, -) + EXTRAS +# changed to use actual index to avoid misconfiguration with vocabulary expansion +SPECIAL_START_ID = 151643 +SPECIAL_TOKENS = tuple( + enumerate( + ( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + EXTRAS + ), + start=SPECIAL_START_ID, + ) +) +SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: @@ -42,6 +53,7 @@ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: for token, rank in (line.split() for line in contents.splitlines() if line) } + class QWenTokenizer(PreTrainedTokenizer): """QWen tokenizer.""" @@ -51,20 +63,35 @@ def __init__( self, vocab_file, errors="replace", + extra_vocab_file=None, **kwargs, ): super().__init__(**kwargs) - self.errors = errors # how to handle errors in decoding + # how to handle errors in decoding UTF-8 byte sequences + # use ignore if you are in streaming inference + self.errors = errors - self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int] + self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int] self.special_tokens = { token: index - for index, token in enumerate( - SPECIAL_TOKENS, start=len(self.mergeable_ranks) - ) + for index, token in SPECIAL_TOKENS } + # try load extra vocab from file + if extra_vocab_file is not None: + used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values()) + extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file) + for token, index in extra_mergeable_ranks.items(): + if token in self.mergeable_ranks: + logger.info(f"extra token {token} exists, skipping") + continue + if index in used_ids: + logger.info(f'the index {index} for extra token {token} exists, skipping') + continue + self.mergeable_ranks[token] = index + # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this + enc = tiktoken.Encoding( "Qwen", pat_str=PAT_STR, @@ -89,7 +116,7 @@ def __init__( def __getstate__(self): # for pickle lovers state = self.__dict__.copy() - del state['tokenizer'] + del state["tokenizer"] return state def __setstate__(self, state): @@ -103,7 +130,6 @@ def __setstate__(self, state): ) self.tokenizer = enc - def __len__(self) -> int: return self.tokenizer.n_vocab @@ -126,13 +152,17 @@ def convert_tokens_to_ids( ids.append(self.mergeable_ranks.get(token)) return ids - def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + def _add_tokens( + self, + new_tokens: Union[List[str], List[AddedToken]], + special_tokens: bool = False, + ) -> int: if not special_tokens and new_tokens: - raise ValueError('Adding regular tokens is not supported') + raise ValueError("Adding regular tokens is not supported") for token in new_tokens: surface_form = token.content if isinstance(token, AddedToken) else token - if surface_form not in SPECIAL_TOKENS: - raise ValueError('Adding unknown special tokens is not supported') + if surface_form not in SPECIAL_TOKENS_SET: + raise ValueError("Adding unknown special tokens is not supported") return 0 def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: diff --git a/mftcoder_accelerate/src/model/qwen/tokenizer_config.json b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json new file mode 100644 index 0000000..9c37cac --- /dev/null +++ b/mftcoder_accelerate/src/model/qwen/tokenizer_config.json @@ -0,0 +1,10 @@ +{ + "model_max_length": 8192, + "tokenizer_class": "QWenTokenizer", + "auto_map": { + "AutoTokenizer": [ + "tokenization_qwen.QWenTokenizer", + null + ] + } +} diff --git a/mftcoder_accelerate/src/mpt/mpt_accelerate.py b/mftcoder_accelerate/src/mpt/mpt_accelerate.py new file mode 100644 index 0000000..5d187c9 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_accelerate.py @@ -0,0 +1,494 @@ +""" +# @author Chaoyu Chen +# @date 2024/6/1 +# @module mpt_accelerate.py + +Accelerate + DeepSpeed + Full-parameter + Multi-task + Pre-training/Continue Training/Finetuning + +Entry +""" + +import os +import sys +import argparse +import math +import logging +import json +import time +from tqdm.auto import tqdm +import transformers +import numpy as np +import torch +from torch import nn +from dataclasses import dataclass +from datasets import Dataset +import datasets +from torch.utils.data import DataLoader +from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + get_linear_schedule_with_warmup, + set_seed, + BitsAndBytesConfig, + get_scheduler, +) + +from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration +from accelerate.logging import get_logger +from datetime import timedelta +from accelerate.utils import InitProcessGroupKwargs +from transformers.optimization import Adafactor + +# insert src as import path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +sys.path.insert(0, parent_dir) + +from tokenizer import build_tokenizer +from data.multi_task_dataset import load_dataset_from_jsonl, compile_helper +from data.data_utils import load_dataset_from_bin +from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK +from mpt.mpt_trainer import MptTrainer +from mpt.mpt_arguments import MptTrainArgs +from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS + + +logger = get_logger(__name__) + + +def get_task_mask(args, task_id): + task_num = len(TASK2ID) + task_mask = torch.zeros(task_id.shape[0], task_num) + task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1 + + return task_mask + + +def get_attention_mask_and_position_ids(data): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + + attention_mask = torch.ones((batch_size, seq_length), device=data.device) + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data).clone() + + return attention_mask, position_ids + + +@dataclass +class DataCollatorForMFTDataset(object): + args: None + + def __call__(self, instances): + (input_ids, loss_mask, weights, task_id) = tuple( + [instance.get(key, None) for instance in instances] + for key in ("input_ids", "loss_mask", "weight", "task_id") + ) + + result_batch = {} + """ + outputs = model( + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']), + # labels=(batch['labels'], batch['loss_mask']), + position_ids=batch['position_ids']) + """ + + # if loss_mask is not None: + loss_mask = torch.tensor(np.array(loss_mask)).long() + last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1) + if self.args.use_dynamic_padding: + # get last non-padding position + max_pos = last_one_pos.max().item() + 1 + else: + max_pos = loss_mask.shape[-1] + + if self.args.tokenize_mode == "sst" and self.args.padding_mode == "pack": + # 兼容sst + pack tokenization, 最后一位是脏数据,需要去掉 + result_batch["loss_mask"] = loss_mask.float()[:, 1 : max_pos - 1].contiguous() + input_ids = torch.tensor(np.array(input_ids)).long() + result_batch["input_ids"] = input_ids[:, : max_pos - 2].contiguous() + result_batch["labels"] = input_ids[:, 1 : max_pos - 1].contiguous() + else: + result_batch["loss_mask"] = loss_mask.float()[:, 1:max_pos].contiguous() + input_ids = torch.tensor(np.array(input_ids)).long() + # print(f"shape of input_ids: {input_ids.shape}") + result_batch["input_ids"] = input_ids[:, : max_pos - 1].contiguous() + result_batch["labels"] = input_ids[:, 1:max_pos].contiguous() + + # Get the masks and position ids. + + # if you want to be compatible with non-gpt models, something you can do here + if self.args.model_type in ["antglm"]: + (result_batch["attention_mask"], result_batch["position_ids"]) = get_attention_mask_and_position_ids( + data=result_batch["input_ids"] + ) + elif self.args.model_type in ["mixtral", "mtx-qwen2", "qwen2_moe"]: + batch_size, seq_length = result_batch["input_ids"].shape + # bsz * seq_length + range_tensor = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1) + # attention_mask for padding tokens + attention_mask = (range_tensor <= last_one_pos.reshape(batch_size, 1)).long() + result_batch["attention_mask"], result_batch["position_ids"] = attention_mask, None + else: + # For decoder-only models, transformers will create them. + result_batch["attention_mask"], result_batch["position_ids"] = None, None + + if task_id is not None: + task_id = torch.tensor(np.array(task_id)) + result_batch["task_mask"] = get_task_mask(self.args, task_id) # bsz * task_num + result_batch["task_id"] = task_id + + return result_batch + + +def pprint_args(args, accelerator): + # 计算所有键的最大字符串长度 + max_key_length = max(len(str(key)) for key in vars(args).keys()) + + message = "" + message += "====" * 60 + "\n" + message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n" + message += "====" * 60 + "\n" + accelerator.print(message) + accelerator.print("GPU: {}".format(torch.cuda.current_device())) + + +def prepare_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_config", type=str, default=None) + + parser.add_argument("--data_paths", type=str, default=None) + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--tb_dir", type=str, default=None) + parser.add_argument("--pretrained_model_path", type=str, default=None) + parser.add_argument("--micro_batch_size", type=int, default=None) + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--distributed_type", type=str, default="deepspeed") + + parsed = parser.parse_args() + # get json configs + with open(parsed.train_config, "r") as f: + train_config = json.load(f) + + # parse args from cofig.json + # args = argparse.Namespace(**train_config) + args = MptTrainArgs(**train_config) + + # override args by cli arguments + if parsed.data_paths: + args.data_paths = parsed.data_paths + if parsed.output_dir: + args.output_dir = parsed.output_dir + if parsed.tb_dir: + args.tb_dir = parsed.tb_dir + if parsed.pretrained_model_path: + args.pretrained_model_path = parsed.pretrained_model_path + args.vocab_file = parsed.pretrained_model_path + if parsed.micro_batch_size: + args.per_device_train_batch_size = parsed.micro_batch_size + args.per_device_eval_batch_size = parsed.micro_batch_size + if parsed.model_type: + args.model_type = parsed.model_type + + args.distributed_type = parsed.distributed_type + + # refactor args + + args.vocab_file = args.pretrained_model_path + + args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(","))) + "]" + + # generate TASK2ID, ID2TASK + generate_task_id(args.data_paths) + + if args.weighted_loss_mode == "coba": + args.task_weights = [1.0] * len(ID2TASK) + elif args.task_weights is not None: + args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")] + assert len(args.task_weights) == len(ID2TASK), f"length of task_weights must equal to length of data_paths" + else: + args.task_weights = [1.0] * len(ID2TASK) + + return args + + +def main(): + t0 = time.time() + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["HF_HUB_OFFLINE"] = "false" + # get input args, set TASK2ID, ID2TASK, refactor args + args = prepare_args() + + # fix randomness + if args.seed is not None: + set_seed(args.seed) + + # define accelerator + init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) + + if args.distributed_type and args.distributed_type.lower() == "fsdp": + fsdp_plugin = FullyShardedDataParallelPlugin( + # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + limit_all_gathers=True, + sync_module_states=True, + use_orig_params=True, + cpu_offload=False, + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + fsdp_plugin=fsdp_plugin, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + else: + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + + # print key infos + accelerator.print("In mft_accelerate.py, sys path:", sys.path) + accelerator.print(f"transformers.__version__: {transformers.__version__}") + + # get world_size + args.world_size = accelerator.num_processes + + # backup args + pprint_args(args, accelerator) + if accelerator.is_main_process: + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "args.json"), "w") as f: + json.dump(args.dict(), f, indent=2) + + # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest + latest = None + if os.path.exists(os.path.join(args.output_dir, "latest")): + with open(os.path.join(args.output_dir, "latest"), "r") as fl: + latest = json.load(fl) + accelerator.print(f"[INFO] Existing latest: {latest}") + + if args.auto_resume and args.resume_from_checkpoint is None and latest: + args.resume_from_checkpoint = latest["latest_ckpt"] + + # logger + logging.basicConfig( + format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + # compile Cpp helper + compile_helper() + time.sleep(10) + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # get global_rank and local rank for current process + global_rank = accelerator.process_index + local_rank = accelerator.local_process_index + print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}") + + # TASK2ID, ID2TASK + # generate_task_id(args.data_paths) + + # multi task blendable dataset(sharded) + if args.load_raw_dataset: + print_rank_0("> load raw jsonl dataset") + train_dataset, valid_dataset = load_dataset_from_jsonl( + args=args, shard_data=True, world_size=args.world_size, global_rank=global_rank, local_rank=local_rank + ) + else: + print_rank_0("> load tokenized bin dataset, refer to gpt_neox indexed dataset") + train_dataset, valid_dataset, _ = load_dataset_from_bin(args=args) + + t1 = time.time() + logger.info(f"dataset loading time: {t1 - t0:.4f}") + + # cuda memory + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) + max_memory = f"{free_in_GB - 2}GB" + n_gpus = torch.cuda.device_count() + max_memory = {i: max_memory for i in range(n_gpus)} + accelerator.print("max memory: ", max_memory, n_gpus) + + # # 是否要加入新的special tokens + # num_added_toks = tokenizer.tokenizer.add_special_tokens(["", ""]) + # accelerator.print("We have added", num_added_toks, "tokens") + # accelerator.print(f"role marker tokens {tokenizer.convert_tokens_to_ids('')} {tokenizer.convert_tokens_to_ids('')}, resized tokenizer_size: {len(tokenizer)}") + + # creating model + ModelClass = MODEL_TYPES[args.model_type] + if args.model_type in SUPPORT_IN_TRANSFORMERS: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + attn_implementation=args.attn_implementation, + torch_dtype=torch.bfloat16, + ) + else: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + torch_dtype=torch.bfloat16, + ) + + # build a tokenizer for possible resizing or saving + tokenizer = build_tokenizer(args) + # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, + # i.e. the length of the tokenizer. + # 如果新增special tokens, 需要resize input embedding 和output embedding + # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) + + model.gradient_checkpointing_enable() + + if args.saving_limit is None or not isinstance(args.saving_limit, int) or args.saving_limit < 1: + # saving_limit is set automatically if needed + args.saving_limit = 2 + accelerator.print( + "[WARNING]saving_limit must be a integer greater than 1 in Full-Parameters Training, we set it to 2" + ) + + t2 = time.time() + if accelerator.is_main_process: + logging.info(f"model loading time: {t2 - t1:.4f}") + + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + if hasattr(model.config, "use_logn_attn"): + model.config.use_logn_attn = False # special for qwen model + # load balance for moe training + if hasattr(model.config, "output_router_logits"): + model.config.output_router_logits = True + model_config = model.config + accelerator.print(model.config) + + # dataloader + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_train_batch_size, + pin_memory=True, + drop_last=True, + ) + if valid_dataset: + valid_dataloader = DataLoader( + valid_dataset, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_eval_batch_size, + pin_memory=True, + drop_last=True, + ) + else: + valid_dataloader = None + + # optimizer + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED") + # from deepspeed.ops.adam import FusedAdam as Adam + # adam_optimizer = Adam + adam_optimizer = torch.optim.AdamW + elif accelerator.distributed_type == DistributedType.FSDP: + accelerator.print("DISTRIBUTED TRAINING USING FSDP") + model = accelerator.prepare(model) + adam_optimizer = torch.optim.AdamW + else: + raise ValueError("Only support DeepSpeed and FSDP") + + optimizer = adam_optimizer( + model.parameters(), + weight_decay=args.weight_decay, + lr=args.learning_rate, + betas=(0.9, 0.999), + ) + # for group in optimizer.param_groups: + # group.setdefault("initial_lr", group["lr"]) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + if isinstance(args.num_warmup_steps, float) and args.num_warmup_steps < 1.0: + args.num_warmup_steps = int(args.max_train_steps * args.num_warmup_steps) // accelerator.num_processes + accelerator.print(f"num_warmup_steps: {args.num_warmup_steps}") + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + # scheduler_specific_kwargs={"last_epoch": scheduler_last_ep} + ) + # prepare all + if accelerator.distributed_type == DistributedType.DEEPSPEED: + if valid_dataloader: + (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, valid_dataloader, optimizer, lr_scheduler + ) + else: + (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, optimizer, lr_scheduler + ) + + # prepare all except model, which is prepared before + elif accelerator.distributed_type == DistributedType.FSDP: + if valid_dataloader: + (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, valid_dataloader, lr_scheduler + ) + else: + (optimizer, train_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) + print(model.device) + accelerator.print(model) + # accelerator.print(model.config) + + # Recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterward we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # zero 3 flag + is_ds_zero_3 = False + if getattr(accelerator.state, "deepspeed_plugin", None): + is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3 + accelerator.print(f"DEEPSPEED plugin: {accelerator.state.deepspeed_plugin}") + elif getattr(accelerator.state, "fsdp_plugin", None): + accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}") + + trainer = MptTrainer( + accelerator=accelerator, + model=model, + model_config=model_config, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + tokenizer=tokenizer, + num_update_steps_per_epoch=num_update_steps_per_epoch, + total_train_dataset_size=len(train_dataset), + args=args, + ) + trainer.accelerate_train() + + +if __name__ == "__main__": + main() diff --git a/mftcoder_accelerate/src/mpt/mpt_arguments.py b/mftcoder_accelerate/src/mpt/mpt_arguments.py new file mode 100644 index 0000000..8045421 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_arguments.py @@ -0,0 +1,161 @@ +""" +# @author Chaoyu Chen +# @date 2024/6/1 + +MPT training arguments +""" + +from dataclasses import dataclass, asdict +from typing import List, Union + + +@dataclass +class MptTrainArgs: + # train data paths on shared FS + data_paths: Union[str, List[str]] + + # output dir for saving adaptors in peft or full ckpts in full-parameter training + output_dir: str + + # tensorboard dir for saving tensorboard logs + tb_dir: str + + # pretrained_model_path, on which is the model you want to train + pretrained_model_path: str + + # model type of pretrained_model_path, support llama|qwen|starcoder|baichuan|chatglm2 + model_type: str + + # load from raw jsonl file or tokenized binary file + load_raw_dataset: bool = True + + # weights of loss calculation for each task, None means equal weights + task_weights: Union[None, str] = None + + # weights of data sampling, leave it None + data_weights: Union[None, str] = None + + # hf loading model low_cpu_mem_usage + low_cpu_mem_usage: bool = True + + # train/valid/test split + data_split: str = "98,2,0" + + # padding or pack or concat + padding_mode: str = "padding" + + # sft or sst + tokenize_mode: str = "sft" + + # case3 or case4 + weighted_loss_mode: str = "case3" + + # mircro train batch size + per_device_train_batch_size: int = 8 + + # micro eval batch size, always same as micro train batch size + per_device_eval_batch_size: int = 8 + + # HF AutoTokenizer is supported, maybe more types + tokenizer_type: str = "AutoTokenizer" + + # initial lr + learning_rate: float = 5e-5 + + # minimum lr + min_lr: float = 5e-6 + + # weight decay + weight_decay: float = 0.01 + + # gradient_accumulation_steps + gradient_accumulation_steps: int = 1 + + # lr_scheduler_type + lr_scheduler_type: str = "cosine" + + # num_warmup_steps + num_warmup_steps: Union[int, float] = 0.05 + + # num_train_epochs + num_train_epochs: int = 4 + + # seed for reproducing + seed: int = 1234 + + # seq_length, context length + seq_length: int = 4096 + + # path of adaptor which is resumed from, None for not resuming training + resume_from_checkpoint: Union[None, str] = None + + # auto resume from latest ckpt if job restarted + auto_resume: bool = True + + # num of steps for logging training loss + log_interval: int = 10 + + # num of steps for saving ckpt + checkpointing_steps: int = 100 + + # num of steps for evaluation(eval_loss), better same as checkpointing steps + evaluation_steps: int = 100 + + # max train steps, if None, depends on num_train_epochs + max_train_steps: Union[None, int] = None + + # if checkpointing every epoch, maybe True in sst + epoch_checkpointing: bool = False + + # save transformers model(safetensors) + save_transformers_model: bool = False + + # shuffle before train/valid split + shuffle_before_split: bool = True + + # DDP random sampler + use_random_sampler: bool = True + + # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point + early_stopping: bool = True + early_stopping_stall_num: int = 5 + + # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota. + saving_limit: Union[None, int] = None + + # if dynamic padding + use_dynamic_padding: bool = True + + # warm-up steps for CoBa, recommand the number of valid batches + coba_warmup_steps: int = 100 + # history length of sample valid loss used to fit the slope curve in CoBa, recommand [2*coba_warmup_steps,5*coba_warmup_steps] + coba_history_length: int = 200 + # temperature for divergence factor in CoBa + coba_tau: int = 5 + # iteration interval of update per task train weight in CoBa + coba_update_interval: int = 1 + # the number of mini valid batches sampled at each updated iteration interval + coba_sample_valid_num: int = 1 + + # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} + attn_implementation: str = "flash_attention_2" + + # role markers, which are prompt template before each role: system, user and assistant + # role_markers: {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} + role_markers: Union[None, dict] = None + + distributed_type: Union[None, str] = None + + init_timeout_seconds: Union[None, int] = 3600 + + # legacy, leave them + use_xformers: bool = True + trust_remote_code: bool = True + weight_by_num_documents: bool = True + make_vocab_size_divisible_by: int = 32 + model_parallel_size: int = 1 + use_slow_tokenizer: bool = False + world_size: int = 8 + + def dict(self): + return {k: str(v) for k, v in asdict(self).items()} diff --git a/mftcoder_accelerate/src/mpt/mpt_trainer.py b/mftcoder_accelerate/src/mpt/mpt_trainer.py new file mode 100644 index 0000000..b5e2da8 --- /dev/null +++ b/mftcoder_accelerate/src/mpt/mpt_trainer.py @@ -0,0 +1,606 @@ +""" +# @author qumu +# @date 2024/6/6 +# @module mpt_trainer.py + +MPT/MCT/MFT Full-parameter Trainer +""" + +import gc +import os +import sys +import threading +import argparse +import math +import logging +import json +import time +import transformers +import numpy as np +import psutil +import shutil +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from typing import List, Optional, Tuple, Union +from tqdm.auto import tqdm +from accelerate.logging import get_logger +from accelerate import Accelerator +from transformers import set_seed + +# sys.path.append("..") +from utils.common_utils import generate_task_id, TASK2ID, ID2TASK +from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func + +logger = get_logger(__name__) + + +def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): + # create path if not exist + if not os.path.exists(save_path): + os.makedirs(save_path) + + # copy each file in files_list to save_path + for filename in files_list: + src_file = os.path.join(mode_path, filename) + + # copy only if src exists + if os.path.exists(src_file): + dest_file = os.path.join(save_path, filename) + + # copy + shutil.copy(src_file, dest_file) + print(f"Copied {filename} to {save_path}") + else: + print(f"File {filename} does not exist in {mode_path}") + + +def check_existing_ckpts(output_dir): + prefix = "step_" + + if not os.path.exists(output_dir): + return [] + # list all files and dirs + contents = os.listdir(output_dir) + + # find dirs starts with "step_" + matching_folders = [ + folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix) + ] + + return matching_folders + + +def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps): + """ + extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training + """ + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + logger.info(f"Resume from exact Epoch {starting_epoch}: completed_steps {completed_steps}") + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + completed_steps = int(training_difference.replace("step_", "")) + starting_epoch = completed_steps // num_update_steps_per_epoch + resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps + logger.info(f"Resume from Epoch {starting_epoch} + step {resume_step}: completed_steps {completed_steps}") + + return starting_epoch, completed_steps, resume_step + + +def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): + for key, value in log_dict.items(): + summary_writer.add_scalar(f"{key}", value, completed_steps) + + +def delete_ckpts_over_limits(output_dir, saving_limit, best_step): + """delete ckpts more than saving_limits except for the best_step ckpt""" + existing_ckpts = check_existing_ckpts(output_dir) + logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}") + # sorted only step num ascendingly + ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts]) + # delete the oldest steps except for the best step at present + if len(ckpt_steps) > saving_limit: + deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step] + # print(deletable_steps[:len(ckpt_steps) - saving_limit]) + for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]: + shutil.rmtree(os.path.join(output_dir, f"step_{del_step}")) + logger.info(f"Removed ckpt step_{del_step}") + + +class MptTrainer: + """ + Multitask Pre-train/Continue-train Trainer with Full-parameters training. + """ + + def __init__( + self, + accelerator: Accelerator, + model, + model_config, + train_dataloader, + valid_dataloader, + optimizer, + lr_scheduler, + tokenizer, + num_update_steps_per_epoch, + total_train_dataset_size, + args, + ): + self.accelerator = accelerator + self.model = model + # hf model config + self.model_config = model_config + self.train_dataloader = train_dataloader + self.valid_dataloader = valid_dataloader + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.tokenizer = tokenizer + self.num_update_steps_per_epoch = num_update_steps_per_epoch + self.total_train_dataset_size = total_train_dataset_size + # training arguments + self.args = args + # tensorboard writer + self.summary_writer = SummaryWriter(log_dir=args.tb_dir) + + def print(self, msg: str): + """ + accelerator print, default on main process + Args: + msg: + + Returns: + + """ + self.accelerator.print(msg) + + def touch(self, batch, num_tokens=10): + """touch first and last tokens and labels for debugging usage""" + self.print( + f"step 1 batch shape: {batch['input_ids'].shape},\n" + f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}" + f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}" + ) + self.print(f"first {num_tokens} input_ids and loss_mask") + for pt in range(1): + self.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + self.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + + @staticmethod + def format_tensor(tensor, n): + return list(map(lambda x: round(x, n), tensor.tolist())) + + def accelerate_saving_states(self, output_dir: str, completed_steps: int): + """ + Saving lora adaptor or full checkpoint using accelerator + Args: + output_dir: exact dir for saving ckpt + completed_steps: + + Returns: + + """ + self.accelerator.wait_for_everyone() + logger.info(f"[CHECKPOINT] Saving checkpoint states") + self.accelerator.save_state(output_dir) + self.accelerator.wait_for_everyone() + + # save safetensors for direct inference if needed + if self.args.save_transformers_model: + logger.info(f"[CHECKPOINT] Saving transformers(hf) model", main_process_only=True) + unwrapped_model = self.accelerator.unwrap_model(self.model) + # self.print(f"unwrapped model type {type(unwrapped_model)}") + unwrapped_model.save_pretrained( + output_dir, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, + state_dict=self.accelerator.get_state_dict(self.model), + ) + self.accelerator.wait_for_everyone() + + # tokenizer saving and bug dummy ckpt cleaning. + if self.accelerator.is_main_process: + if self.args.model_type.lower() == "deepseek": + copy_tokenizer_files( + self.args.pretrained_model_path, ["tokenizer.json", "tokenizer_config.json"], output_dir + ) + else: + self.tokenizer.save_pretrained(output_dir) + + sf = os.path.join(output_dir, "model.safetensors") + index_file = os.path.join(output_dir, "model.safetensors.index.json") + if os.path.isfile(sf) and os.path.isfile(index_file): + self.print(f"Remove bug dummy ckpt {sf}") + os.remove(sf) + + # save latest info + if self.accelerator.is_main_process: + latest = { + "latest_ckpt": output_dir, + "lr": self.optimizer.param_groups[0]["lr"], + } + with open(os.path.join(self.args.output_dir, "latest"), "w") as f: + json.dump(latest, f, indent=2) + + logger.info( + f"[CHECKPOINT][complete_steps={completed_steps}], states {output_dir} saved, latest: {latest}", + main_process_only=True, + ) + self.accelerator.wait_for_everyone() + + def accelerate_monitor( + self, + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status=None, + ): + """ + gather reduce_loss and reduce_task_loss from all N devices. + train logging and tensorboarding. + """ + # gather reduce_loss and reduce_task_loss from all N devices + reduce_losses = self.accelerator.gather(reduce_loss).detach().float() + reduce_task_losses = self.accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK)) + reduce_task_exists = self.accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK)) + + # get train loss and per-task train loss + train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps) + # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (self.args.log_interval * self.args.gradient_accumulation_steps) + train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0) + + # logging and writing tensorboard + logger.info( + f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]" + f"[train_task_loss={self.format_tensor(train_task_loss, 4)}]" + f"[gather shape={list(reduce_losses.shape)}]" + f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]", + main_process_only=True, + ) + if coba_status is not None: + if completed_steps > coba_status.coba_warmup_steps: + coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum( + coba_status.log_per_task_weight + ) + else: + coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) + logger.info( + f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True + ) + train_log_dict = {"Loss/train": train_loss} + for i in range(len(ID2TASK)): + train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i] + if coba_status is not None: + train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item() + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, train_log_dict, completed_steps) + + if coba_status is not None: + coba_status.log_per_task_weight = torch.zeros(len(ID2TASK)) + + def accelerate_evaluate( + self, + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ): + """ + evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices. + eval logging and tensorboarding. + """ + losses = [] + accumulated_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + accumulated_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + for valid_step, valid_batch in enumerate(self.valid_dataloader): + with torch.no_grad(): + outputs = self.model( + input_ids=valid_batch["input_ids"], + attention_mask=valid_batch["attention_mask"], + position_ids=valid_batch["position_ids"], + return_dict=True, + ) + + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=valid_batch["labels"], + task_mask=valid_batch["task_mask"], + task_id=valid_batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=valid_batch["loss_mask"], + task_weights=self.args.task_weights, + ) + + losses.append(self.accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size))) + accumulated_task_loss += task_loss.detach().float() + accumulated_task_exist += (task_loss != 0.0).detach().float() + + self.accelerator.wait_for_everyone() + valid_batch_num = len(losses) + gathered_size = losses[0].shape + losses = torch.cat(losses) + # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK)) + task_losses = self.accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK)) + task_exists = self.accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK)) + + try: + eval_loss = torch.mean(losses) + # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num + eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0) + if eval_loss <= min_eval_loss: + min_eval_loss = eval_loss + stall_num = 0 + best_step = completed_steps + else: + stall_num += 1 + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info( + f"[EVAL][completed_steps={completed_steps}]" + f"[eval_loss={eval_loss:.6f}][eval_task_loss={self.format_tensor(eval_task_loss, 4)}]" + f"[perplexity={perplexity:.4f}][valid_batch_num={valid_batch_num}]" + f"[gather_size={list(gathered_size)}]", + main_process_only=True, + ) + eval_log_dict = { + "Loss/valid": eval_loss, + "Perplexity/valid": perplexity, + "Epochs": round(completed_steps / self.num_update_steps_per_epoch, 2), + } + for i in range(len(ID2TASK)): + eval_log_dict[f"{ID2TASK[i]}_loss/valid"] = eval_task_loss[i] + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, eval_log_dict, completed_steps) + + return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step + + def accelerate_train(self): + # Train! + if self.args.seed is not None: + set_seed(self.args.seed) + + global_batch_size = ( + self.args.per_device_train_batch_size + * self.accelerator.num_processes + * self.args.gradient_accumulation_steps + ) + logger.info("************************************** Running training ****************************************") + logger.info(f" Num examples = {self.total_train_dataset_size}") + logger.info(f" Num Epochs = {self.args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") + logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {global_batch_size}") + logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") + logger.info(f" Total optimization(update/completed) steps = {self.args.max_train_steps}") + logger.info(f" Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_train_epochs}") + logger.info("************************************************************************************************") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process) + + # set starting_epoch, completed_steps and resume_step of train_dataloader + completed_steps = 0 + starting_epoch = 0 + resume_step = None + + if self.args.resume_from_checkpoint: + self.accelerator.load_state(self.args.resume_from_checkpoint) + self.accelerator.print(f"Resumed from checkpoint: {self.args.resume_from_checkpoint}") + path = os.path.basename(self.args.resume_from_checkpoint) + starting_epoch, completed_steps, resume_step = extract_epochs_and_steps( + path, self.num_update_steps_per_epoch, self.args.gradient_accumulation_steps + ) + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + # monitor minimum eval_loss, stalling num, and best_step + min_eval_loss = float("inf") + stall_num = 0 + best_step = None + + # monitor train loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + per_task_weight = self.args.task_weights + + if self.args.weighted_loss_mode == "coba": + self.model.eval() + eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate( + completed_steps, + 0, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + coba_status = CoBaStatus( + self.args.coba_warmup_steps, + self.args.coba_history_length, + self.args.coba_tau, + self.args.coba_update_interval, + self.args.coba_sample_valid_num, + self.valid_dataloader, + ) + coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device) + coba_status.sample_valid_batch(self.model, completed_steps) + logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + else: + coba_status = None + + # Training Loop! + for epoch in range(starting_epoch, self.args.num_train_epochs): + # set_epoch + # self.train_dataloader.set_epoch(epoch) + + # if we early stop by some ckpts not converging + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + break + + if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = self.accelerator.skip_first_batches(self.train_dataloader, resume_step) + else: + active_dataloader = self.train_dataloader + tail_num = len(active_dataloader) - len(active_dataloader) % self.args.gradient_accumulation_steps + print(f"length of dataloader: {len(active_dataloader)}") + + self.model.train() + # Inner Loop! + for step, batch in enumerate(active_dataloader): + if step == tail_num: + break + with self.accelerator.accumulate(self.model): + if step == 0: + self.touch(batch, num_tokens=10) + # forward + outputs = self.model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + position_ids=batch["position_ids"], + return_dict=True, + ) + + if ( + self.args.weighted_loss_mode == "coba" + and self.accelerator.sync_gradients + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps >= self.args.coba_warmup_steps + ): + with torch.no_grad(): + per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps) + coba_status.log_per_task_weight += per_task_weight + # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True) + + # loss + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=batch["labels"], + task_mask=batch["task_mask"], + task_id=batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=batch["loss_mask"], + task_weights=per_task_weight, + ) + + # accelerator.print(len(outputs.router_logits), outputs.router_logits[0], outputs.router_logits[-1]) + # accelerator.print(batch['attention_mask'].shape, batch['attention_mask']) + aux_loss = None + if hasattr(self.model_config, "output_router_logits") and self.model_config.output_router_logits: + if hasattr(self.model_config, "num_local_experts"): + num_experts = self.model_config.num_local_experts + elif hasattr(self.model_config, "num_experts"): + num_experts = self.model_config.num_experts + else: + raise ValueError("model has no attribute num_local_experts or num_experts") + aux_loss = load_balancing_loss_func( + outputs.router_logits, + num_experts, + self.model_config.num_experts_per_tok, + batch["attention_mask"], + ) + aux_loss = self.model_config.router_aux_loss_coef * aux_loss.to(loss.device) + loss += aux_loss # make sure to reside in the same device + + # backward + self.accelerator.backward(loss) + # print(self.lr_scheduler.state_dict(), self.accelerator.process_index) + # update(sync_gradients) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + # support args.min_lr + if self.optimizer.param_groups[0]["lr"] <= self.args.min_lr: + self.optimizer.param_groups[0]["lr"] = self.args.min_lr + + # accumulate resuce_loss and reduce_task_loss in a log_interval + if not torch.isnan(loss): + reduce_loss += loss.detach().float() + if aux_loss and not torch.isnan(aux_loss): + reduce_aux_loss += aux_loss.detach().float() + # self.print("task loss devices: ", reduce_task_loss.device, task_loss.device) + reduce_task_loss += task_loss.detach().float() + reduce_task_exist += (task_loss != 0).detach().float() + + # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done. + if self.accelerator.sync_gradients: + if ( + self.args.weighted_loss_mode == "coba" + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps >= 1 + ): + coba_status.sample_valid_batch(self.model, completed_steps) + # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + + # progress_bar.update(1) + completed_steps += 1 + # monitoring training process and logging and tensorboarding + if completed_steps % self.args.log_interval == 0: + progress_bar.update(self.args.log_interval) + if reduce_aux_loss > 0.0: + self.print(f"[INFO] aux_loss: {reduce_aux_loss/self.args.log_interval}") + self.accelerate_monitor( + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status, + ) + # reset reduce_loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + + # steps checkpointing + if self.args.checkpointing_steps and completed_steps % self.args.checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_states(output_dir, completed_steps) + + # steps evaluation + if completed_steps % self.args.evaluation_steps == 0 and self.valid_dataloader: + self.model.eval() + eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate( + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + + # delete ckpts over args.saving_limit + if self.accelerator.is_main_process and self.args.saving_limit: + delete_ckpts_over_limits(self.args.output_dir, self.args.saving_limit, best_step) + + # early stoppin when stalling more than args.early_stopping_stall_num + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + self.print(f"[WARNING] Early stopping at {completed_steps}") + break + + if completed_steps >= self.args.max_train_steps: + break + self.accelerator.wait_for_everyone() + + # epoch checkpointing + if self.args.epoch_checkpointing: + output_dir = f"epoch_{epoch + 1}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_states(output_dir, completed_steps) + + self.summary_writer.close() diff --git a/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py new file mode 100644 index 0000000..ca4347e --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/concat_sst_bin_tokenization.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +import argparse +import multiprocessing +import os +import sys +import random +import time +import tqdm +import glob +import json +import numpy as np + + +# 将父目录的父目录加入path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +grandparent_dir = os.path.dirname(parent_dir) +sys.path.append(grandparent_dir) + +from tokenizer import init_tokenizer +from pack_encoder import PackSSTBinEncoder, load_tokenizer +from data import indexed_dataset + +from threading import Semaphore +from colorama import Fore +import lm_fmt as lmd + + +def yield_from_files(files: list, semaphore): + """ + Iterator over input documents + + :param fnames: list of filenames + """ + def yielder(fname, semaphore): + with open(fname, 'r') as f: + for line in f: + semaphore.acquire() + yield json.loads(line) + + for fname in files: + semaphore.acquire() + yield from yielder(fname, semaphore) + +def yield_from_files2(fnames: list, semaphore, sample_percent): + """ + Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts / + other compressed formats. Also filters out empty documents. + + :param fnames: list of filenames + """ + def yielder(fname, semaphore): + try: + sample_interval = int(1/sample_percent) + for f in filter(lambda x: x, lmd.Reader(fname).stream_data(key=None)): + rand_value = random.randint(1, sample_interval*100) + if rand_value % sample_interval != 0: + continue + semaphore.acquire() + + #rand_value = random.randint(1, sample_interval*100) + #if rand_value % sample_interval != 0: + # yield None + + yield f + except Exception as e: + print('####Exception:', e.args) + yield None + + for fname in fnames: + semaphore.acquire() + + yield from yielder(fname, semaphore) + + +def print_example_doc(input_ids, tokenizer): + print(Fore.YELLOW + f'INPUT IDS len: {len(input_ids)}') + print(Fore.BLUE + f'INPUT IDS:\n {input_ids}\n\n') + + print(Fore.RED + f'DETOKENIZED INPUT:\n{tokenizer.decode(input_ids)}') + + +def core_process(encoded_docs, semaphore, seq_length, tokenizer, encoder, builder, output_idx_file): + """ + core of Data Pack SFT processing + """ + input_ids_key = 'input_ids' + + proc_start = time.time() + total_bytes_processed = 0 + pbar = tqdm.tqdm() + sentence_droped = 0 + loss_token_cnt = 0 + + print("PRINT BEFORE STREAM PROCESS DATA") + + print_example_count = 0 + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + + # release semaphore so `yield_from_files` can add another file to the buffer + semaphore.release() + + # add each tokenized document / sentence, + # For sft, each document has only one sample + input_ids_sentence = doc[input_ids_key][0] + if len(input_ids_sentence) < 1: + sentence_droped += 1 + continue + + builder.add_item(np.array(input_ids_sentence, dtype=builder.dtype)) + builder.end_document() + #builder.finalize_without_close(output_idx_file) + #builder.add_item_and_end_document_and_finalize(np.array(input_ids_sentence, dtype=builder.dtype), output_idx_file) + + # print the first packed sample as example + if print_example_count < 1: + print_example_doc(input_ids_sentence, tokenizer) + print_example_count += 1 + + # log progress + if i % 100 == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + pbar.set_description( + f"Processed {i} documents ({i / elapsed} docs/s, {mbs} MB/s)." + ) + if i != 0: + pbar.update(100) + + # 尾部处理 + builder.finalize(output_idx_file) + + print(Fore.RED + "\ndroped docs: {}".format(sentence_droped)) + + +def process_dataset(dataset_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent): + """ + Re-organize samples in the given data path into a Data Pack file. + """ + + # get all jsonl files and corresponding reading handler + files = glob.glob(os.path.join(dataset_path, '**/*.jsonl'), recursive=True) + + # build a semaphore object to stop `yield_from_files` from getting ahead + # of encoder.encode and hence building up memory + semaphore = Semaphore(1000 + parallel_num) + + # build sample iterator + sample_iterator = yield_from_files2(files, semaphore, sample_percent) + + # load tokenizer + # tokenizer = load_tokenizer(model_path, tokenizer_type) + tokenizer = init_tokenizer(model_path) + print('TOKEN of id=2:', tokenizer.convert_ids_to_tokens(2)) + print('ID of :', tokenizer.convert_tokens_to_ids('')) + print('TOKEN of id=0:', tokenizer.convert_ids_to_tokens(0)) + print('ID of :', tokenizer.convert_tokens_to_ids('')) + + # init encoder + encoder = PackSSTBinEncoder(seq_length, model_path) + + # create writer builder + key = "input_ids" + output_prefix = os.path.join(output_path, dataset_name) + output_bin_file = "{}_{}.bin".format( + output_prefix, key + ) + output_idx_file = "{}_{}.idx".format( + output_prefix, key + ) + builder = indexed_dataset.make_builder( + output_bin_file, + impl="mmap", + vocab_size=tokenizer.vocab_size, + ) + + if parallel_num > 1: + pool = multiprocessing.Pool(parallel_num, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, sample_iterator, chunksize=32) + else: + encoder.initializer() + encoded_docs = (encoder.encode(doc) for doc in sample_iterator) + + if dataset_name is None: + dataset_path = dataset_path[:-1] if dataset_path.endswith(os.path.sep) else dataset_path + dataset_name = dataset_path.split(os.path.sep)[-1] + + core_process(encoded_docs, semaphore, seq_length, tokenizer, encoder, builder, output_idx_file) + + +def main(data_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent): + """ + Entry + """ + + process_dataset(data_path, output_path, model_path, parallel_num, seq_length, dataset_name, sample_percent) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate a packed jsonl file in the Data Pack SFT way.") + parser.add_argument('--model-path', type=str, help='Path of a pretrained model which contains tokenizer-related files.') + parser.add_argument('--parallel', type=int, default=1, help='The num of parallel processing.') + parser.add_argument('--output-path', type=str, help='Path to store the genered result file.') + parser.add_argument('--data-path', type=str, default=None, help='Path of files to be processed') + parser.add_argument('--seq-length', type=int, default=4096, help='The max input length (i.e. the max number of tokens in a sample)') + # parser.add_argument('--eod-token-id', type=int, default=2, help='EOD token id') + # parser.add_argument('--pad-token-id', type=int, default=0, help='PAD token id') + # parser.add_argument('--tokenizer-type', type=str, choices=["LLAMATokenizer", None], default=None, help="What type of tokenizer to use. Default is None.") + parser.add_argument('--dataset-name', type=str, default=None, help='The generated result dataset name. The folder name will be token by default.') + parser.add_argument('--sample-percent', type=float, default=1.0, help='Sample percentage') + + args = parser.parse_args() + print('ARGS\n', '\n'.join([str(key) + ':' + str(value) for key,value in vars(args).items()])) + + random.seed(9999) + + main(args.data_path, args.output_path, args.model_path, args.parallel, args.seq_length, args.dataset_name, args.sample_percent) diff --git a/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py b/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py new file mode 100644 index 0000000..c922859 --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/lm_fmt.py @@ -0,0 +1,360 @@ +import os +import zstandard +import ujson as json +import time +import tarfile +import codecs +from functools import reduce +import jsonlines +import io +from zipfile import ZipFile +import gzip +from math import ceil +import mmap +import multiprocessing as mp +from pathlib import Path + +VALID_EXTENSIONS = ['openwebtext.tar.xz', '_data.xz', '.dat.zst', '.jsonl', '.jsonl.zst', '.jsonl.zst.tar', '.json.zst', '.txt', '.zip', '.tar.gz', '.json.gz', '.gz'] + +def has_valid_extension(file): + return any([file.endswith(ext) for ext in VALID_EXTENSIONS]) + +def _listdir_or_file(x): + if isinstance(x, list): + return reduce(lambda x, y: x + y, map(listdir_or_file, sorted(x))) + if os.path.isfile(x): + return [x] + elif os.path.isdir(x): + return [str(Path(x) / fn) for fn in sorted(os.listdir(x))] + else: + raise FileNotFoundError(f"{x} not found") + +def listdir_or_file(x): + return list(filter(has_valid_extension, _listdir_or_file(x))) + +def tarfile_reader(file, streaming=False): + # we need our own tarfile parser because `tarfile` doesn't work well for + # big tarfiles; it seems to be reading the entire file to get a list of + # where all the files are - but we don't need that because we just need + # to see each file once. surprisingly, `tarfile` doesn't expose any + # facilities for this. the only options are 1. load the entire tarfile + # and then query by filename or 2. extract to disk - and neither of + # these is what we want. + + offset = 0 + paxfilesize = None + while True: + hdr = file.read(512) + offset += 512 + + # https://www.gnu.org/software/tar/manual/html_node/Standard.html + # end at 135 not 136 because of \0 terminator + if hdr[124:135] == b'\0'*11: + # end of record + break + + fname = hdr[:100].split(b'\0')[0] + + # if the file is too big to fit in the size field, tarfiles will actually + # include a PaxHeader with the size in it, applicable to the immediate next file. + if paxfilesize is not None: + size = paxfilesize + paxfilesize = None + else: + size = int(hdr[124:135], 8) + + padded_size = ceil(size / 512) * 512 + + # for handling PaxHeader files (which contain extra metadata about file size) and directories + # https://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_03 + type = chr(hdr[156]) + + if type == 'x': + meta = file.read(padded_size)[:size] + def kv(x): + return x.decode('utf-8').split(' ')[1].split('=') + paxfileattrs = { + kv(x)[0]: kv(x)[1] + for x in meta.split(b'\n') if x + } + paxfilesize = int(paxfileattrs['size']) + + offset += padded_size + continue + elif type != '0' and type != '\0': + if streaming: + file.seek(padded_size, os.SEEK_CUR) + else: + file.read(padded_size) + offset += padded_size + continue + + if streaming: + # skip directory entries + if size != 0: + mmo = mmap.mmap(file.fileno(), length=offset + size, access=mmap.ACCESS_READ) + mmo.seek(offset) + yield mmo + + file.seek(padded_size, os.SEEK_CUR) + else: + yield file.read(padded_size)[:size] + offset += padded_size + +def handle_jsonl(jsonl_reader, get_meta, autojoin_paragraphs, para_joiner, key='text'): + for ob in jsonl_reader: + # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility. + if isinstance(ob, str): + assert not get_meta + yield ob + continue + + if key is None: + yield ob + continue + + text = ob[key] + + if autojoin_paragraphs and isinstance(text, list): + text = para_joiner.join(text) + + if get_meta: + yield text, (ob['meta'] if 'meta' in ob else {}) + else: + yield text + + +class Reader: + def __init__(self, in_path): + self.in_path = in_path + + def stream_data(self, get_meta=False, threaded=False, key=None): + if not threaded: + yield from self._stream_data(get_meta, key=key) + return + + q = mp.Queue(1000) + p = mp.Process(target=self._stream_data_threaded, args=(q, get_meta), kwargs={"key": key}) + p.start() + while p.is_alive(): + res = q.get() + if res is None: break + yield res + + def _stream_data_threaded(self, q, get_meta=False): + for data in self._stream_data(get_meta): + q.put(data) + q.put(None) + + def _stream_data(self, get_meta=False, key="text"): + self.f_name = "" + files = listdir_or_file(self.in_path) + if not files: + raise FileNotFoundError(f"No valid file(s) found in {self.in_path}") + for f in files: + self.f_name = f + if f == 'openwebtext.tar.xz': + assert not get_meta + + yield from self.read_owt(f) + elif 'urlsf_subset' in f and f.endswith('_data.xz'): + assert not get_meta + + yield from self.read_owt_subset(f) + elif f.endswith('.dat.zst'): + assert not get_meta + + yield from self.read_dat(f) + elif f.endswith('.jsonl'): + yield from self.read_jsonl(f, get_meta, key=key) + elif f.endswith('.jsonl.zst'): + yield from self.read_jsonl_zst(f, get_meta, key=key) + elif f.endswith('.jsonl.zst.tar'): + yield from self.read_jsonl_tar(f, get_meta, key=key) + elif f.endswith('.json.zst'): + assert not get_meta + + yield from self.read_json(f) + elif f.endswith('.txt'): + assert not get_meta + + yield from self.read_txt(f) + elif f.endswith('.zip'): + assert not get_meta + + yield from self.read_zip(f) + elif f.endswith('.tar.gz'): + assert not get_meta + + yield from self.read_tgz(f) + elif f.endswith('.json.gz'): + assert not get_meta + + yield from self.read_jsongz(f) + elif f.endswith('.gz'): + assert not get_meta + + yield from self.read_gz(f) + else: + # shouldn't be reached + print(f'Skipping {f} as streaming for that filetype is not implemented') + + def read_txt(self, file): + with open(file, 'r') as fh: + yield fh.read() + + def read_zip(self, file): + archive = ZipFile(file, 'r') + for f in archive.namelist(): + yield archive.read(f).decode('UTF-8') + + def read_tgz(self, file): + gz = gzip.open(file) + yield from (x.decode('utf-8') for x in tarfile_reader(gz, streaming=False)) + + def read_gz(self, file): + with gzip.open(file, 'rb') as f: + for line in f: + yield line.decode('utf-8') + + def read_jsongz(self, file): + for line in self.read_gz(file): + yield json.loads(line) + + def read_json(self, file): + with open(file, 'rb') as fh: + cctx = zstandard.ZstdDecompressor() + reader = cctx.stream_reader(fh) + ob = json.load(reader) + yield from ob + + def read_dat(self, file): + with open(file, 'rb') as fh: + cctx = zstandard.ZstdDecompressor() + reader = cctx.stream_reader(fh) + while True: + ln = reader.read(16).decode('UTF-8') + if not ln: + break + + ln = int(ln) + + yield reader.read(ln).decode('UTF-8') + + def read_jsonl(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n', key='text'): + with jsonlines.open(file) as rdr: + yield from handle_jsonl(rdr, get_meta, autojoin_paragraphs, para_joiner, key) + + def read_jsonl_zst(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n', key='text'): + with open(file, 'rb') as fh: + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(fh)) + rdr = jsonlines.Reader(reader) + yield from handle_jsonl(rdr, get_meta, autojoin_paragraphs, para_joiner, key) + + def read_jsonl_tar(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n', key='text'): + with open(file, 'rb') as fh: + for f in tarfile_reader(fh, streaming=True): + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(f)) + rdr = jsonlines.Reader(reader) + yield from handle_jsonl(rdr, get_meta, autojoin_paragraphs, para_joiner, key) + f.close() + + def read_owt(self, file): + tar = tarfile.open(file, encoding='utf-8') + utf8reader = codecs.getreader('utf-8') + + for name in tar.getmembers(): + fp = tar.extractfile(name) + inner_tar = tarfile.open(fileobj=fp, encoding='utf-8') + for inner_name in inner_tar.getmembers(): + inner_fp = utf8reader(inner_tar.extractfile(inner_name)) + contents = inner_fp.read() + yield contents + + def read_owt_subset(self, file): + utf8reader = codecs.getreader('utf-8') + tar = tarfile.open(file, encoding='utf-8') + for name in tar.getmembers(): + fp = utf8reader(tar.extractfile(name)) + contents = fp.read() + yield contents + + +class Archive: + def __init__(self, out_dir, compression_level=3, threads=8): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.i = 0 + + self.fh = open(self.out_dir + '/current_chunk_incomplete', 'wb') + self.cctx = zstandard.ZstdCompressor(level=compression_level, threads=threads) + self.compressor = self.cctx.stream_writer(self.fh) + + + def add_data(self, data, meta={}): + self.compressor.write(json.dumps({'text': data, 'meta': meta}).encode('UTF-8') + b'\n') + + def commit(self, archive_name='default'): + fname = self.out_dir + '/data_' + str(self.i) + '_time' + str(int(time.time())) + '_' + archive_name + '.jsonl.zst' + self.compressor.flush(zstandard.FLUSH_FRAME) + + self.fh.flush() + self.fh.close() + os.rename(self.out_dir + '/current_chunk_incomplete', fname) + self.fh = open(self.out_dir + '/current_chunk_incomplete', 'wb') + self.compressor = self.cctx.stream_writer(self.fh) + + self.i += 1 + + +class DatArchive: + def __init__(self, out_dir): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.data = [] + self.i = 0 + if os.path.exists(out_dir) and len(os.listdir(out_dir)) > 0: + self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self, archive_name=None): + # TODO: streaming + cctx = zstandard.ZstdCompressor(level=3) + + if archive_name is None: + archive_name = str(int(time.time())) + + res = b''.join(map(lambda x: ("%016d" % len(x)).encode('UTF-8') + x, map(lambda x: x.encode('UTF-8'), self.data))) + cdata = cctx.compress(res) + + with open(self.out_dir + '/data_' + str(self.i) + '_' + archive_name + '.dat.zst', 'wb') as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] + +class JSONArchive: + def __init__(self, out_dir): + self.out_dir = out_dir + os.makedirs(out_dir, exist_ok=True) + self.data = [] + self.i = 0 + if os.path.exists(out_dir) and len(os.listdir(out_dir)) > 0: + self.i = max(map(lambda x: int(x.split('_')[1].split('.')[0]), os.listdir(out_dir))) + 1 + + def add_data(self, data): + self.data.append(data) + + def commit(self): + cctx = zstandard.ZstdCompressor(level=3) + + cdata = cctx.compress(json.dumps(self.data).encode('UTF-8')) + with open(self.out_dir + '/data_' + str(self.i) + '_' + str(int(time.time())) + '.json.zst', 'wb') as fh: + fh.write(cdata) + + self.i += 1 + self.data = [] diff --git a/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py b/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py new file mode 100644 index 0000000..0678e27 --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/pack_encoder.py @@ -0,0 +1,335 @@ +from transformers import AutoTokenizer +from tokenizer import init_tokenizer + + +def load_tokenizer(model_path, tokenizer_type=None): + """ + Load tokenizer from the given + """ + + def load_tokenizer_manual(model_path, tokenizer_type): + """ + Load tokenizer by the concrete Tokenizer class instead of AutoTokenizer + """ + try: + if tokenizer_type.lower() == "LlamaTokenizer".lower(): + return LlamaTokenizer.from_pretrained(model_path) + + raise Exception(f"Unsupported tokenizer type {tokenizer_type}") + except: + raise Exception(f"Unable to load tokenizer {tokenizer_type} from the given path: {model_path}") + + def load_tokenizer_auto(model_path): + """ + Load tokenizer from the given path by HuggingFace AutoTokenizer + """ + try: + # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True) # support CodeLlama + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + return tokenizer + except: + raise Exception( + f'Unable to load tokenizer from the given path: {model_path} using auto mode.\nPlease specify the tokenizer type with the command argument "--tokenizer-type" and retry.' + ) + + # First, try to load tokenizer by huggingface AutoTokenizer, If fail, try another manual way + try: + return load_tokenizer_auto(model_path) + except Exception as e: + print(str(e)) + if tokenizer_type is not None: + try: + tokenizer = load_tokenizer_manual(model_path, tokenizer_type) + return tokenizer + except Exception as ee: + raise ee + + +class PackPFTEncoder: + """ + A sample of this format will be: + <|role_start|>system<|role_end|> content of system_1 + <|role_start|>human<|role_end|> content of human_1 + <|role_start|>bot<|role_end|> content of bot_1 + <|endoftext|> + <|role_start|>system<|role_end|> content of system_2 + <|role_start|>human<|role_end|> content of human_2 + <|role_start|>bot<|role_end|> content of bot_2 + <|endoftext|> + <|role_start|>human<|role_end|> content of human_3 + <|role_start|>bot<|role_end|> content of bot_3 + <|endoftext|> + .... + <|role_start|>human<|role_end|> content of human_n + <|role_start|>bot<|role_end|> content of bot_n + <|endoftext|> + + <|pad|><|pad|>...<|pad|> + + system part is optional, i.e. '<|role_start|>system<|role_end|> content of system_i' + """ + + def __init__(self, seq_length, eod_token_id, pad_token_id, role_start_tag, role_end_tag, mode="pft"): + self.mode = mode + self.seq_length = seq_length + self.eod_token_id = eod_token_id + self.pad_token_id = pad_token_id + self.role_start_tag = role_start_tag + self.role_end_tag = role_end_tag + + def initializer(self, model_path, tokenizer_type=None): + # Use Encoder class as a container for global data + assert model_path is not None + self.tokenizer = load_tokenizer(model_path, tokenizer_type) + + def encode(self, item): + encode_res = { + "input_ids": [], + } + + item_len = sum([len(x["content"]) for x in item["chat_rounds"]]) + for token_res in self.tokenize_chat_prompt(item): + for k, v in token_res.items(): + encode_res[k].append(v) + return encode_res, item_len + + def tokenize_chat_prompt(self, item): + # role_start_marker = self.tokenizer.encode(self.role_start_tag, add_special_tokens=False) + # role_end_marker = self.tokenizer.encode(self.role_end_tag, add_special_tokens=False) + end_marker = [self.eod_token_id] + + input_ids = [] + raw_input = "" + # loss_mask = [] + for chat_round in item["chat_rounds"]: + role = chat_round["role"].strip() + # skip system prompt + # if role == 'system': + # continue + + content = chat_round["content"] + content = content if content.endswith("\n") else f"{content}\n" + text = f"{self.role_start_tag}{role}{self.role_end_tag}{content}" + chat_input_ids = self.tokenizer.encode(text, add_special_tokens=False) + + if role != "bot": + chat_input_ids = chat_input_ids + else: + chat_input_ids = chat_input_ids + end_marker + + input_ids += chat_input_ids + + # if this sample's length is more than the specified max length, drop it + # here, we don't add padding tokens for a single sample, however, we will append padding tokens for a combinated samaple + if len(input_ids) > self.seq_length: + yield {} + else: + yield {"input_ids": input_ids} + + def padding(self, key, data): + assert len(data) <= self.seq_length, f"padding sequence: {len(data)} > {self.seq_length}" + if key == "input_ids": + return data + [self.pad_token_id] * (self.seq_length - len(data)) + + if key == "loss_mask": + return data + [0] * (self.seq_length - len(data)) + + raise Exception("Should not reach here. There must be something wrong.") + + +class PackSFTEncoder: + """ + A sample of this format will be: + <|role_start|>system<|role_end|> content of system_1 + <|role_start|>human<|role_end|> content of human_1 + <|role_start|>bot<|role_end|> content of bot_1 + <|endoftext|> + <|role_start|>system<|role_end|> content of system_2 + <|role_start|>human<|role_end|> content of human_2 + <|role_start|>bot<|role_end|> content of bot_2 + <|endoftext|> + <|role_start|>human<|role_end|> content of human_3 + <|role_start|>bot<|role_end|> content of bot_3 + <|endoftext|> + .... + <|role_start|>human<|role_end|> content of human_n + <|role_start|>bot<|role_end|> content of bot_n + <|endoftext|> + + <|pad|><|pad|>...<|pad|> + + system part is optional, i.e. '<|role_start|>system<|role_end|> content of system_i' + """ + + def __init__(self, seq_length, eod_token, role_start_tag, role_end_tag, mode="sft"): + self.mode = mode + self.seq_length = seq_length + self.eod_token = eod_token + self.role_start_tag = role_start_tag + self.role_end_tag = role_end_tag + + def initializer(self, model_path, tokenizer_type=None): + # Use Encoder class as a container for global data + assert model_path is not None + self.tokenizer = load_tokenizer( + model_path, tokenizer_type + ) # AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + def encode(self, item): + encode_res = {"input_ids": [], "raw_input": []} + + item_len = sum([len(x["content"]) for x in item["chat_rounds"]]) + for token_res in self.tokenize_chat_prompt(item): + for k, v in token_res.items(): + encode_res[k].append(v) + return encode_res, item_len + + def tokenize_chat_prompt(self, item): + role_start_marker = self.tokenizer.encode(self.role_start_tag, add_special_tokens=False) + role_end_marker = self.tokenizer.encode(self.role_end_tag, add_special_tokens=False) + end_marker = [self.tokenizer.convert_tokens_to_ids(self.eod_token)] + + input_ids = [] + raw_input = "" + # loss_mask = [] + for chat_round in item["chat_rounds"]: + role = chat_round["role"] + content = chat_round["content"] + content = content if content.endswith("\n") else f"{content}\n" + chat_input_ids = self.tokenizer.encode(content, add_special_tokens=False) + role_input_ids = self.tokenizer.encode(role, add_special_tokens=False) + role_raw_input = "" + + if role != "bot": + # chat_loss_mask = [0] * len(role_start_marker) + [0] * len(role_input_ids) + [0] * len(role_end_marker) + [0] * len(chat_input_ids) + chat_input_ids = role_start_marker + role_input_ids + role_end_marker + chat_input_ids + role_raw_input = ROLE_START_MARKER + role + ROLE_END_MARKER + content + elif role == "human": + # chat_loss_mask = [0] * len(role_start_marker) + [0] * len(role_input_ids) + [0] * len(role_end_marker) + [1] * len(chat_input_ids) + [1] * len(end_marker) + chat_input_ids = role_start_marker + role_input_ids + role_end_marker + chat_input_ids + end_marker + role_raw_input = ROLE_START_MARKER + role + ROLE_END_MARKER + content + self.eod_token + + input_ids += chat_input_ids + raw_input += role_raw_input + # loss_mask += chat_loss_mask + + # assert len(input_ids) == len(loss_mask) + + # if this sample's length is more than the specified max length, drop it + # here, we don't add padding tokens for a single sample, however, we will append padding tokens for a combinated samaple + if len(input_ids) > self.seq_length: + yield {} + else: + yield { + "input_ids": input_ids, + "raw_input": raw_input, + # "loss_mask": loss_mask + } + + def padding(self, key, data, pad_token_id): + assert len(data) <= self.seq_length, f"padding sequence: {len(data)} > {self.seq_length}" + if key == "input_ids": + return data + [pad_token_id] * (self.seq_length - len(data)) + + if key == "loss_mask": + return data + [0] * (self.seq_length - len(data)) + + raise Exception("Should not reach here. There must be something wrong.") + + +class PackSSTBinEncoder: + """ + A sample of this format will be: + content of sample_1 + content of sample_2 + ... + content of sample_n + <|pad|><|pad|>...<|pad|> + """ + + def __init__(self, seq_length, model_path): + self.seq_length = seq_length + self.model_path = model_path + + def initializer(self): + # Use Encoder class as a container for global data + assert self.model_path is not None + # self.tokenizer = load_tokenizer(model_path, tokenizer_type) #AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + # PackSSTBinEncoder.tokenizer = load_tokenizer(self.model_path, self.tokenizer_type) + PackSSTBinEncoder.tokenizer = init_tokenizer(self.model_path) + + def _encode_content(self, item, encode_res): + if "content" in item: + content = item["content"] + else: + content = item["text"] + + item_len = len(content) + + input_ids = self.tokenize_string(content) + encode_res["input_ids"].append(input_ids) + + return encode_res, item_len + + def _encode_chatml(self, item, encode_res): + input_ids = [] + item_len = 0 + one_round_content = "" + for i in range(len(item["chat_rounds"])): + chat_round = item["chat_rounds"][i] + role = chat_round["role"] + content = chat_round["content"] + content = content if content.endswith("\n") else f"{content}\n" + if role.lower() == "system": + continue + if role.lower() == "human": + one_round_content = content + else: + one_round_content += content + input_ids += self.tokenize_string(one_round_content) + item_len += len(one_round_content) + + encode_res["input_ids"].append(input_ids) + + return encode_res, item_len + + def encode(self, item): + encode_res = { + "input_ids": [], + } + + try: + if item is None: + encode_res["input_ids"].append([]) + return encode_res, 0 + + if "content" in item or "text" in item: + return self._encode_content(item, encode_res) + + if "chat_rounds" in item: + return self._encode_chatml(item, encode_res) + except Exception as e: + print("####JSON Exception", e, str(item)) + encode_res["input_ids"].append([]) + return encode_res, 0 + + raise Exception("Unsupported Format!") + + def tokenize_string(self, text): + end_marker = [PackSSTBinEncoder.tokenizer.eos_token_id] + + input_ids = [] + try: + input_ids = PackSSTBinEncoder.tokenizer.encode(text, add_special_tokens=False) + input_ids = input_ids + end_marker + return input_ids + except Exception as e: + print("####Tokenization Exception:", e, text) + return [] + except BaseException as e: + print("####Tokenization BaseException:", e, "Length of text", len(text)) + return [] + + def padding(self, data, pad_token_id): + assert len(data) <= self.seq_length, f"padding sequence: {len(data)} > {self.seq_length}" + return data + [pad_token_id] * (self.seq_length - len(data)) diff --git a/mftcoder_accelerate/src/offline_tokenization/writer.py b/mftcoder_accelerate/src/offline_tokenization/writer.py new file mode 100644 index 0000000..ab526a7 --- /dev/null +++ b/mftcoder_accelerate/src/offline_tokenization/writer.py @@ -0,0 +1,42 @@ + +import threading +import fcntl +import json + +class JSONLWriter(): + """ + A writer used to save jsonl lines into a file. + """ + def __init__(self, output_path, dataset_name): + self.output_path = output_path + self.out_file = open(output_path, 'w') + self.cache = [] + self.cache_size = 4096 + self.dataset_name = dataset_name + self.index = 0 + + def pack_into_jsonl(self, line_text): + new_item = { + "data_name": self.dataset_name, + "id": self.index, + "content": line_text + } + + return new_item + + + def add_item(self, line_text): + if len(self.cache) >= self.cache_size: + self.flush() + + item = self.pack_into_jsonl(line_text) + self.cache.append(json.dumps(item)) + self.index += 1 + + + def flush(self): + content = '\n'.join(self.cache) + fcntl.flock(self.out_file, fcntl.LOCK_EX) + self.out_file.write(f'{content}\n') + fcntl.flock(self.out_file, fcntl.LOCK_UN) + self.cache = [] diff --git a/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py index bf8434f..26f8ec1 100644 --- a/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py +++ b/mftcoder_accelerate/src/pefts/merge_base_and_lora_to_hf.py @@ -4,6 +4,7 @@ Merge base and lora adaptor """ + import os import sys import time @@ -16,13 +17,13 @@ from peft import LoraConfig, get_peft_model from peft import PeftModel -# insert src as import path +# insert src as import path current_path = os.path.abspath(__file__) parent_dir = os.path.dirname(os.path.dirname(current_path)) sys.path.insert(0, parent_dir) print("In merge_base_and_lora_to_hf.py, sys path:", sys.path) -from pefts.model_mapping import MODEL_SPECIAL_TOKENS +from tokenizer import init_tokenizer def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): @@ -42,7 +43,7 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): print(f"File {filename} does not exist in {mode_path}") -if __name__ == '__main__': +if __name__ == "__main__": # arguments parser = argparse.ArgumentParser() @@ -58,29 +59,21 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): save_path = args.merged_output_path t0 = time.time() - config = {"model_type": model_type} - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + tokenizer = init_tokenizer(args.base_model_or_path) base_model = AutoModelForCausalLM.from_pretrained( model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, + # torch_dtype=torch.float32, return_dict=True, - device_map="auto" + device_map="auto", ) print("--------------------------------------Base Model--------------------------------------------") print(base_model) print("--------------------------------------------------------------------------------------------") - # DEAL with eos_token_id and pad_token_id - eos_token = MODEL_SPECIAL_TOKENS[config['model_type']]['eos_token'] - pad_token = MODEL_SPECIAL_TOKENS[config['model_type']]['pad_token'] - base_model.config.eos_token = eos_token - base_model.config.pad_token = pad_token - base_model.config.eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) - base_model.config.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) - print(f"Finetuned eos_token: {eos_token}, eos_token_id: {tokenizer.convert_tokens_to_ids(eos_token)}") - print(f"Finetuned pad_token: {pad_token}, pad_token_id: {tokenizer.convert_tokens_to_ids(pad_token)}") print("-----------------------------------Base Model Config----------------------------------------") print(base_model.config) print("--------------------------------------------------------------------------------------------") @@ -88,6 +81,8 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): # merge, save model and tokenizer model_to_merge = PeftModel.from_pretrained(base_model, lora_adapter) merged_model = model_to_merge.merge_and_unload() + # merged_model.to(torch.bfloat16) + print("---------------------------------Merged Model Config----------------------------------------") print(merged_model.config) print("--------------------------------------------------------------------------------------------") @@ -99,8 +94,8 @@ def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): if model_type.lower() == "deepseek": copy_tokenizer_files( model_path, - ["tokenizer.model", "tokenizer.json", "tokenizer_config.json", 'special_tokens_map.json'], - save_path + ["tokenizer.model", "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"], + save_path, ) else: tokenizer.save_pretrained(save_path) diff --git a/mftcoder_accelerate/src/pefts/mft_accelerate.py b/mftcoder_accelerate/src/pefts/mft_accelerate.py index 8d3eb18..0a0d42a 100644 --- a/mftcoder_accelerate/src/pefts/mft_accelerate.py +++ b/mftcoder_accelerate/src/pefts/mft_accelerate.py @@ -1,15 +1,13 @@ """ # @author Chaoyu Chen -# @date 2023/12/11 +# @date 2024/10/24 # @module mft_accelerate.py -Accelerate + DeepSpeed/FSDP -QLoRA/LoRA/Full + MFT/MPT, accurate and efficient training +Accelerate + DeepSpeed/FSDP + QLoRA/LoRA/Full + Multi-task Finetuning Entry """ -import gc import os import sys import argparse @@ -17,6 +15,7 @@ import logging import json import time +from tqdm.auto import tqdm import transformers import numpy as np import torch @@ -26,11 +25,10 @@ import datasets from torch.utils.data import DataLoader from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig -from tqdm.auto import tqdm + from transformers import ( AutoModelForCausalLM, AutoTokenizer, - LlamaTokenizer, get_linear_schedule_with_warmup, set_seed, BitsAndBytesConfig, @@ -43,8 +41,11 @@ prepare_model_for_kbit_training, PeftModel, ) -from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin +from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration from accelerate.logging import get_logger +from datetime import timedelta +from accelerate.utils import InitProcessGroupKwargs +from transformers.optimization import Adafactor # insert src as import path current_path = os.path.abspath(__file__) @@ -56,13 +57,13 @@ from data.multi_task_dataset import load_dataset_from_jsonl, compile_helper from data.data_utils import load_dataset_from_bin from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK -from pefts.train_utils import accelerate_train -from pefts.arguments import TrainArgs -from pefts.model_mapping import MODEL_TYPES, FULL_LORA_TARGETING_MODULES, MODEL_SPECIAL_TOKENS -logger = get_logger(__name__) +from pefts.mft_trainer import MftTrainer +from pefts.mft_arguments import MftTrainArgs +from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS -SUPPORT_FA2_IN_TRANSFORMERS = ["code_llama", "llama", "deepseek", "mistral", "mixtral", "gpt_neox", "phi", "starcoder"] + +logger = get_logger(__name__) def get_task_mask(args, task_id): @@ -74,7 +75,7 @@ def get_task_mask(args, task_id): def get_attention_mask_and_position_ids(data): - """Build masks and position ids if you need to""" + """Build masks and position id for left to right model.""" # Extract batch size and sequence length. batch_size, seq_length = data.size() @@ -93,47 +94,59 @@ class DataCollatorForMFTDataset(object): args: None def __call__(self, instances): - input_ids, loss_mask, weights, task_id = tuple( - [instance[key] if key in instance else None for instance in instances] for key in - ("input_ids", "loss_mask", "weight", "task_id")) + (input_ids, loss_mask, weights, task_id) = tuple( + [instance.get(key, None) for instance in instances] + for key in ("input_ids", "loss_mask", "weight", "task_id") + ) result_batch = {} - ''' + """ outputs = model( - input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']), - # labels=(batch['labels'], batch['loss_mask']), - position_ids=batch['position_ids'], - ) - ''' + input_ids=batch['input_ids'], + attention_mask=batch['attention_mask'], + # labels=(batch['labels'], batch['loss_mask'], batch['task_mask']), + # labels=(batch['labels'], batch['loss_mask']), + position_ids=batch['position_ids']) + """ # if loss_mask is not None: loss_mask = torch.tensor(np.array(loss_mask)).long() + last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1) if self.args.use_dynamic_padding: - last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1) # get last non-padding position max_pos = last_one_pos.max().item() + 1 else: max_pos = loss_mask.shape[-1] - result_batch['loss_mask'] = loss_mask.float()[:, 1:max_pos].contiguous() - input_ids = torch.tensor(np.array(input_ids)).long() - # print(f"shape of input_ids: {input_ids.shape}") - result_batch['input_ids'] = input_ids[:, :max_pos - 1].contiguous() - result_batch['labels'] = input_ids[:, 1:max_pos].contiguous() + if self.args.tokenize_mode == "sst" and self.args.padding_mode == "pack": + # sst + pack tokenization, remove last dirty data + result_batch["loss_mask"] = loss_mask.float()[:, 1 : max_pos - 1].contiguous() + input_ids = torch.tensor(np.array(input_ids)).long() + result_batch["input_ids"] = input_ids[:, : max_pos - 2].contiguous() + result_batch["labels"] = input_ids[:, 1 : max_pos - 1].contiguous() + else: + result_batch["loss_mask"] = loss_mask.float()[:, 1:max_pos].contiguous() + input_ids = torch.tensor(np.array(input_ids)).long() + # print(f"shape of input_ids: {input_ids.shape}") + result_batch["input_ids"] = input_ids[:, : max_pos - 1].contiguous() + result_batch["labels"] = input_ids[:, 1:max_pos].contiguous() # Get the masks and position ids. - # For decoder-only models, attention_mask and position_ids should be None and transformers will create them. - result_batch['attention_mask'], result_batch['position_ids'] = None, None - - # if you want to be compatible with non-gpt(non-causal)models, something you can do here - # result_batch['attention_mask'], result_batch['position_ids'] = get_attention_mask_and_position_ids(data=result_batch['input_ids']) + if self.args.model_type in ["mixtral", "qwen2_moe"]: + batch_size, seq_length = result_batch["input_ids"].shape + # bsz * seq_length + range_tensor = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1) + # attention_mask for padding tokens + attention_mask = (range_tensor <= last_one_pos.reshape(batch_size, 1)).long() + result_batch["attention_mask"], result_batch["position_ids"] = attention_mask, None + else: + # For decoder-only models, transformers will create them. + result_batch["attention_mask"], result_batch["position_ids"] = None, None if task_id is not None: task_id = torch.tensor(np.array(task_id)) - result_batch['task_mask'] = get_task_mask(self.args, task_id) # bsz * task_num - result_batch['task_id'] = task_id + result_batch["task_mask"] = get_task_mask(self.args, task_id) # bsz * task_num + result_batch["task_id"] = task_id return result_batch @@ -144,7 +157,7 @@ def pprint_args(args, accelerator): message = "" message += "====" * 60 + "\n" - message += '\n'.join([f'{k:<{max_key_length}} : {v}' for k, v in vars(args).items()]) + "\n" + message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n" message += "====" * 60 + "\n" accelerator.print(message) accelerator.print("GPU: {}".format(torch.cuda.current_device())) @@ -164,12 +177,12 @@ def prepare_args(): parsed = parser.parse_args() # get json configs - with open(parsed.train_config, 'r') as f: + with open(parsed.train_config, "r") as f: train_config = json.load(f) # parse args from cofig.json # args = argparse.Namespace(**train_config) - args = TrainArgs(**train_config) + args = MftTrainArgs(**train_config) # override args by cli arguments if parsed.data_paths: @@ -190,26 +203,25 @@ def prepare_args(): args.distributed_type = parsed.distributed_type # refactor args - args.eos_token = MODEL_SPECIAL_TOKENS[args.model_type]['eos_token'] - args.pad_token = MODEL_SPECIAL_TOKENS[args.model_type]['pad_token'] - if args.peft_type == 'qlora' and args.quantization != '4bit' and args.quantization != '8bit': - print(f"[WARNING]peft_type is qlora but quantization is not 4bit or 8bit, setting it to 4bit") - args.quantization = '4bit' + if args.peft_type == "qlora": + print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'") + args.quantization = "4bit" + else: + args.quantization = None args.vocab_file = args.pretrained_model_path - args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(','))) + "]" + args.data_weights = "[" + ",".join(["1."] * len(args.data_paths[1:-1].split(","))) + "]" # generate TASK2ID, ID2TASK generate_task_id(args.data_paths) - if args.weighted_loss_mode == 'selfpaced': + if args.weighted_loss_mode == "coba": args.task_weights = [1.0] * len(ID2TASK) elif args.task_weights is not None: args.task_weights = [float(wt) for wt in args.task_weights[1:-1].split(",")] - assert len(args.task_weights) == len( - ID2TASK), f"length of task_weights, is not equal to the length of data_paths" + assert len(args.task_weights) == len(ID2TASK), f"length of task_weights must equal to length of data_paths" else: args.task_weights = [1.0] * len(ID2TASK) @@ -219,31 +231,46 @@ def prepare_args(): def main(): t0 = time.time() os.environ["TOKENIZERS_PARALLELISM"] = "false" - print(f"transformers.__version__: {transformers.__version__}") - + os.environ["HF_HUB_OFFLINE"] = "false" # get input args, set TASK2ID, ID2TASK, refactor args args = prepare_args() + # fix randomness + if args.seed is not None: + set_seed(args.seed) + # define accelerator + init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) + if args.distributed_type and args.distributed_type.lower() == "fsdp": fsdp_plugin = FullyShardedDataParallelPlugin( # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), limit_all_gathers=True, sync_module_states=True, - cpu_offload=False + use_orig_params=True, + cpu_offload=False, + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + fsdp_plugin=fsdp_plugin, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], ) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, fsdp_plugin=fsdp_plugin) else: - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + + # print key infos + accelerator.print("In mft_accelerate.py, sys path:", sys.path) + accelerator.print(f"transformers.__version__: {transformers.__version__}") # get world_size args.world_size = accelerator.num_processes - # fix randomness - if args.seed is not None: - set_seed(args.seed) - # backup args pprint_args(args, accelerator) if accelerator.is_main_process: @@ -252,9 +279,26 @@ def main(): with open(os.path.join(args.output_dir, "args.json"), "w") as f: json.dump(args.dict(), f, indent=2) + # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest + latest = None + if os.path.exists(os.path.join(args.output_dir, "latest")): + with open(os.path.join(args.output_dir, "latest"), "r") as fl: + latest = json.load(fl) + accelerator.print(f"[INFO] Existing latest: {latest}") + + if args.auto_resume and args.resume_from_checkpoint is None and latest: + if args.peft_type: + args.resume_from_checkpoint = latest["latest_ckpt"] + else: + args.resume_from_checkpoint = latest["latest_ckpt"] + args.pretrained_model_path = args.resume_from_checkpoint + args.learning_rate = latest["lr"] + elif args.resume_from_checkpoint and (not args.peft_type): + args.pretrained_model_path = args.resume_from_checkpoint + # logger logging.basicConfig( - format="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", + format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) @@ -272,35 +316,37 @@ def main(): # get global_rank and local rank for current process global_rank = accelerator.process_index local_rank = accelerator.local_process_index - print(f'world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}') + print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}") # TASK2ID, ID2TASK # generate_task_id(args.data_paths) # multi task blendable dataset(sharded) if args.load_raw_dataset: - print_rank_0('> load raw jsonl dataset') + print_rank_0("> load raw jsonl dataset") train_dataset, valid_dataset = load_dataset_from_jsonl( - args=args, - shard_data=True, - world_size=args.world_size, - global_rank=global_rank, - local_rank=local_rank + args=args, shard_data=True, world_size=args.world_size, global_rank=global_rank, local_rank=local_rank ) else: - print_rank_0('> load tokenized bin dataset, refer to gpt_neox indexed dataset') + print_rank_0("> load tokenized bin dataset, refer to gpt_neox indexed dataset") train_dataset, valid_dataset, _ = load_dataset_from_bin(args=args) t1 = time.time() logger.info(f"dataset loading time: {t1 - t0:.4f}") # cuda memory - free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3) + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) max_memory = f"{free_in_GB - 2}GB" n_gpus = torch.cuda.device_count() max_memory = {i: max_memory for i in range(n_gpus)} accelerator.print("max memory: ", max_memory, n_gpus) + # target_modules, default all-linear for all linear layers + if args.target_modules: + target_modules = args.target_modules + else: + target_modules = "all-linear" + # peft config if args.peft_type: peft_config = LoraConfig( @@ -309,7 +355,8 @@ def main(): r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, - target_modules=args.target_modules if args.target_modules else FULL_LORA_TARGETING_MODULES[args.model_type] + target_modules=target_modules, + bias="lora_only", ) # # 是否要加入新的special tokens @@ -319,38 +366,40 @@ def main(): # creating base model ModelClass = MODEL_TYPES[args.model_type] - if args.model_type in SUPPORT_FA2_IN_TRANSFORMERS: - accelerator.print(f"[INFO] Model Type {args.model_type} is supported FA2 by Transformers and we use it") + if args.model_type in SUPPORT_IN_TRANSFORMERS: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers") model = ModelClass.from_pretrained( args.pretrained_model_path, attn_implementation=args.attn_implementation, - # trust_remote_code=True, - load_in_8bit=(args.quantization == '8bit'), - load_in_4bit=(args.quantization == '4bit'), torch_dtype=torch.bfloat16, - # low_cpu_mem_usage=args.low_cpu_mem_usage, # not for zero3 - # use_safetensors=False, - quantization_config=BitsAndBytesConfig( - load_in_4bit=(args.quantization == '4bit'), - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - ) if args.quantization == '4bit' else None, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), ) else: - accelerator.print(f"[INFO] Model Type {args.model_type} is NOT supported officially by Transformers " - f"and we use published modeling_xxx.py(may be modified by us)") + accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code") model = ModelClass.from_pretrained( args.pretrained_model_path, - load_in_8bit=(args.quantization == '8bit'), - load_in_4bit=(args.quantization == '4bit'), torch_dtype=torch.bfloat16, - quantization_config=BitsAndBytesConfig( - load_in_4bit=(args.quantization == '4bit'), - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - ) if args.quantization == '4bit' else None, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), ) # build a tokenizer for possible resizing or saving @@ -360,29 +409,29 @@ def main(): # 如果新增special tokens, 需要resize input embedding 和output embedding # model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) - accelerator.print("load in 8bit: ", args.quantization == '8bit') - accelerator.print("load in 4bit: ", args.quantization == '4bit') - if args.peft_type: - if args.peft_type == 'lora': - model.gradient_checkpointing_enable() - # args.saving_limit = None - - elif args.peft_type == 'qlora': - # prepare base model for 8bit or 4bit model(cast non-8bit or non-4bit layers to fp32) - model = prepare_model_for_kbit_training(model) - logging.info(f"device map: {model.hf_device_map}") - # args.saving_limit = None + accelerator.print("Model load_in_4bit: ", args.quantization == "4bit") + + if args.peft_type == "lora": + model.gradient_checkpointing_enable() + elif args.peft_type == "qlora": + # prepare base model for 4bit model(cast non-4bit layers to fp32) + model = prepare_model_for_kbit_training(model) + # logging.info(f"device map: {model.hf_device_map}") else: model.gradient_checkpointing_enable() - assert (args.saving_limit is not None and isinstance(args.saving_limit, int)), "saving_limit must be a integer in Full Training" + if args.saving_limit is None or not isinstance(args.saving_limit, int) or args.saving_limit < 1: + # saving_limit is set automatically if needed + args.saving_limit = 2 + accelerator.print( + "[WARNING]saving_limit must be a integer greater than 1 in Full-Parameters Training, we set it to 2" + ) - # Potentially load in the lora from a previous save + # Load PeftModel from a previous save or create a new PeftModel if args.peft_type: if not args.resume_from_checkpoint: model = get_peft_model(model, peft_config) else: - - accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") + accelerator.print(f"[INFO] Resumed from checkpoint: {args.resume_from_checkpoint}") # accelerator.load_state(args.resume_from_checkpoint) model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True) @@ -393,38 +442,56 @@ def main(): logging.info(f"model loading time: {t2 - t1:.4f}") model.config.use_cache = False # silence the warnings. Please re-enable for inference! - model.config.use_logn_attn = False # special for qwen model + if hasattr(model.config, "use_logn_attn"): + model.config.use_logn_attn = False # special for qwen model + # load balance for moe training + if hasattr(model.config, "output_router_logits"): + model.config.output_router_logits = True + model_config = model.config accelerator.print(model.config) # dataloader train_dataloader = DataLoader( - train_dataset, shuffle=True, collate_fn=DataCollatorForMFTDataset(args), - batch_size=args.per_device_train_batch_size, pin_memory=True, drop_last=True - ) - valid_dataloader = DataLoader( - valid_dataset, collate_fn=DataCollatorForMFTDataset(args), batch_size=args.per_device_eval_batch_size, - pin_memory=True, drop_last=True + train_dataset, + shuffle=True, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_train_batch_size, + pin_memory=True, + drop_last=True, ) + if valid_dataset: + valid_dataloader = DataLoader( + valid_dataset, + collate_fn=DataCollatorForMFTDataset(args), + batch_size=args.per_device_eval_batch_size, + pin_memory=True, + drop_last=True, + ) + else: + valid_dataloader = None + + # optimizer if accelerator.distributed_type == DistributedType.DEEPSPEED: accelerator.print("DISTRIBUTED TRAINING USING DEEPSPEED") - from deepspeed.ops.adam import FusedAdam as Adam - adam_optimizer = Adam + # from deepspeed.ops.adam import FusedAdam as Adam + # adam_optimizer = Adam + adam_optimizer = torch.optim.AdamW elif accelerator.distributed_type == DistributedType.FSDP: accelerator.print("DISTRIBUTED TRAINING USING FSDP") if args.peft_type and getattr(accelerator.state, "fsdp_plugin", None) is not None: from peft.utils.other import fsdp_auto_wrap_policy + accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model) model = accelerator.prepare(model) adam_optimizer = torch.optim.AdamW else: - accelerator.print(f"DISTRIBUTED TRAINING USING {accelerator.distributed_type}") - adam_optimizer = torch.optim.AdamW + raise ValueError("Only support DeepSpeed and FSDP") optimizer = adam_optimizer( model.parameters(), weight_decay=args.weight_decay, lr=args.learning_rate, - betas=(0.9, 0.95), + betas=(0.9, 0.999), ) # Scheduler and math around the number of training steps. @@ -433,31 +500,42 @@ def main(): if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True - + if isinstance(args.num_warmup_steps, float) and args.num_warmup_steps < 1.0: + args.num_warmup_steps = int(args.max_train_steps * args.num_warmup_steps) // accelerator.num_processes + accelerator.print(f"num_warmup_steps: {args.num_warmup_steps}") lr_scheduler = get_scheduler( name=args.lr_scheduler_type, optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + # scheduler_specific_kwargs={"last_epoch": scheduler_last_ep} ) + # prepare all if accelerator.distributed_type == DistributedType.DEEPSPEED: - model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare( - model, train_dataloader, valid_dataloader, optimizer, lr_scheduler - ) + if valid_dataloader: + (model, train_dataloader, valid_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, valid_dataloader, optimizer, lr_scheduler + ) + else: + (model, train_dataloader, optimizer, lr_scheduler) = accelerator.prepare( + model, train_dataloader, optimizer, lr_scheduler + ) + + # prepare all except model, which is prepared before elif accelerator.distributed_type == DistributedType.FSDP: - optimizer, train_dataloader, valid_dataloader, lr_scheduler = accelerator.prepare( - optimizer, train_dataloader, valid_dataloader, lr_scheduler - ) - else: - # may be not suitable for all DistributedType, expected to be ok with simple multi-gpu - model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare( - model, train_dataloader, valid_dataloader, optimizer, lr_scheduler - ) + if valid_dataloader: + (optimizer, train_dataloader, valid_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, valid_dataloader, lr_scheduler + ) + else: + (optimizer, train_dataloader, lr_scheduler) = accelerator.prepare( + optimizer, train_dataloader, lr_scheduler + ) print(model.device) accelerator.print(model) # accelerator.print(model.config) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. + # Recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch @@ -472,17 +550,21 @@ def main(): elif getattr(accelerator.state, "fsdp_plugin", None): accelerator.print(f"FSDP plugin: {accelerator.state.fsdp_plugin}") - # Train! - accelerate_train(accelerator, - model, - train_dataloader, - valid_dataloader, - optimizer, - lr_scheduler, - tokenizer, - num_update_steps_per_epoch, - len(train_dataset), - args) + trainer = MftTrainer( + accelerator=accelerator, + model=model, + model_config=model_config, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + tokenizer=tokenizer, + num_update_steps_per_epoch=num_update_steps_per_epoch, + total_train_dataset_size=len(train_dataset), + args=args, + ) + trainer.accelerate_train() + logger.info(f"Training Finished!") if __name__ == "__main__": diff --git a/mftcoder_accelerate/src/pefts/arguments.py b/mftcoder_accelerate/src/pefts/mft_arguments.py similarity index 79% rename from mftcoder_accelerate/src/pefts/arguments.py rename to mftcoder_accelerate/src/pefts/mft_arguments.py index 5317e5b..9fee1cd 100644 --- a/mftcoder_accelerate/src/pefts/arguments.py +++ b/mftcoder_accelerate/src/pefts/mft_arguments.py @@ -2,15 +2,15 @@ # @author Chaoyu Chen # @date 2023/10/19 -accelerate + deepspeed zero stage2 + Data Parallelism -MFT Training +training arguments """ + from dataclasses import dataclass, asdict from typing import List, Union @dataclass -class TrainArgs: +class MftTrainArgs: # train data paths on shared FS data_paths: Union[str, List[str]] @@ -47,14 +47,14 @@ class TrainArgs: # sft or sst tokenize_mode: str = "sft" - # case3 or case4 + # mft loss mode weighted_loss_mode: str = "case3" # lora or qlora or None(for full-parameter training) - peft_type: str = "qlora" + peft_type: Union[None, str] = "qlora" - # if qlora, 4bit or 8bit, 4bit is suggested - quantization: str = "4bit" + # if qlora, 4bit will be set, else None + quantization: Union[None, str] = "4bit" # lora rank, the bigger, the more trainalbe parameters lora_rank: int = 96 @@ -66,7 +66,7 @@ class TrainArgs: lora_dropout: float = 0.05 # lora targeting modules - target_modules: Union[None, List[str]] = None + target_modules: Union[None, str, List[str]] = None # mircro train batch size per_device_train_batch_size: int = 8 @@ -84,7 +84,7 @@ class TrainArgs: min_lr: float = 5e-6 # weight decay - weight_decay: float = 0.1 + weight_decay: float = 0.01 # gradient_accumulation_steps gradient_accumulation_steps: int = 1 @@ -93,7 +93,7 @@ class TrainArgs: lr_scheduler_type: str = "cosine" # num_warmup_steps - num_warmup_steps: int = 300 + num_warmup_steps: Union[int, float] = 0.05 # num_train_epochs num_train_epochs: int = 4 @@ -107,6 +107,9 @@ class TrainArgs: # path of adaptor which is resumed from, None for not resuming training resume_from_checkpoint: Union[None, str] = None + # auto resume from latest ckpt if job restarted + auto_resume: bool = True + # num of steps for logging training loss log_interval: int = 10 @@ -128,24 +131,26 @@ class TrainArgs: # DDP random sampler use_random_sampler: bool = True - # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point + # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point early_stopping: bool = True early_stopping_stall_num: int = 5 # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota. saving_limit: Union[None, int] = None - # if dynamic padding + # if dynamic padding use_dynamic_padding: bool = True - # interval of update per task train weight in selfpaced - selfpaced_interval: int = 1 - # history length of sample valid loss used to fit the slope curve in selfpaced - selfpaced_history_length: int = 100 - # the number of mini valid batches sampled at each interval - selfpaced_sample_valid_num: int = 1 - # scale factor before softmax - selfpaced_scale_factor: int = 50 + # warm-up steps for CoBa, recommand the number of valid batches + coba_warmup_steps: int = 100 + # history length of sample valid loss used to fit the slope curve in CoBa, recommand [2*coba_warmup_steps,5*coba_warmup_steps] + coba_history_length: int = 200 + # temperature for divergence factor in CoBa + coba_tau: int = 5 + # iteration interval of update per task train weight in CoBa + coba_update_interval: int = 1 + # the number of mini valid batches sampled at each updated iteration interval + coba_sample_valid_num: int = 1 # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} attn_implementation: str = "flash_attention_2" @@ -154,7 +159,10 @@ class TrainArgs: # role_markers: {"system": "### System:\n", "user": "### Instruction:\n", "assistant": "### Response:\n"} role_markers: Union[None, dict] = None - distributed_type: Union[None, str] = "deepspeed" + distributed_type: Union[None, str] = None + + init_timeout_seconds: Union[None, int] = 3600 + # legacy, leave them use_xformers: bool = True trust_remote_code: bool = True diff --git a/mftcoder_accelerate/src/pefts/mft_trainer.py b/mftcoder_accelerate/src/pefts/mft_trainer.py new file mode 100644 index 0000000..a2b00fb --- /dev/null +++ b/mftcoder_accelerate/src/pefts/mft_trainer.py @@ -0,0 +1,606 @@ +""" +# @author qumu +# @date 2024/4/12 +# @module trainer.py + +Accelerate + DeepSpeed/FSDP +QLoRA/LoRA/Full + SFT/MFT + +Trainer +""" + +import gc +import os +import sys +import threading +import argparse +import math +import logging +import json +import time +import transformers +import numpy as np +import psutil +import shutil +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from typing import List, Optional, Tuple, Union +from tqdm.auto import tqdm +from accelerate.logging import get_logger +from accelerate import Accelerator +from transformers import set_seed + +# sys.path.append("..") +from utils.common_utils import generate_task_id, TASK2ID, ID2TASK +from utils.loss_utils import loss_func_mft, CoBaStatus, load_balancing_loss_func + +logger = get_logger(__name__) + + +def copy_tokenizer_files(mode_path: str, files_list: List[str], save_path: str): + # create path if not exist + if not os.path.exists(save_path): + os.makedirs(save_path) + + # copy each file in files_list to save_path + for filename in files_list: + src_file = os.path.join(mode_path, filename) + + # copy only if src exists + if os.path.exists(src_file): + dest_file = os.path.join(save_path, filename) + + # copy + shutil.copy(src_file, dest_file) + print(f"Copied {filename} to {save_path}") + else: + print(f"File {filename} does not exist in {mode_path}") + + +def check_existing_ckpts(output_dir): + prefix = "step_" + + if not os.path.exists(output_dir): + return [] + # list all files and dirs + contents = os.listdir(output_dir) + + # find dirs starts with "step_" + matching_folders = [ + folder for folder in contents if os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix) + ] + + return matching_folders + + +def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps): + """ + extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training + """ + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + logger.info(f"Resume from exact Epoch {starting_epoch}: completed_steps {completed_steps}") + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + completed_steps = int(training_difference.replace("step_", "")) + starting_epoch = completed_steps // num_update_steps_per_epoch + resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps + logger.info(f"Resume from Epoch {starting_epoch} + step {resume_step}: completed_steps {completed_steps}") + + return starting_epoch, completed_steps, resume_step + + +def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): + for key, value in log_dict.items(): + summary_writer.add_scalar(f"{key}", value, completed_steps) + + +def delete_ckpts_over_limits(output_dir, saving_limit, best_step): + """delete ckpts more than saving_limits except for the best_step ckpt""" + existing_ckpts = check_existing_ckpts(output_dir) + logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}") + # sorted only step num ascendingly + ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts]) + # delete the oldest steps except for the best step at present + if len(ckpt_steps) > saving_limit: + deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step] + # print(deletable_steps[:len(ckpt_steps) - saving_limit]) + for del_step in deletable_steps[: len(ckpt_steps) - saving_limit]: + shutil.rmtree(os.path.join(output_dir, f"step_{del_step}")) + logger.info(f"Removed ckpt step_{del_step}") + + +class MftTrainer: + """ + Multitask FineTuing Trainer, supporting MFT/SFT/ContinueTrain with Lora/Qlora/Full-parameters. + """ + + def __init__( + self, + accelerator: Accelerator, + model, + model_config, + train_dataloader, + valid_dataloader, + optimizer, + lr_scheduler, + tokenizer, + num_update_steps_per_epoch, + total_train_dataset_size, + args, + ): + self.accelerator = accelerator + self.model = model + # hf model config + self.model_config = model_config + self.train_dataloader = train_dataloader + self.valid_dataloader = valid_dataloader + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.tokenizer = tokenizer + self.num_update_steps_per_epoch = num_update_steps_per_epoch + self.total_train_dataset_size = total_train_dataset_size + # training arguments + self.args = args + # tensorboard writer + self.summary_writer = SummaryWriter(log_dir=args.tb_dir) + + def print(self, msg: str): + """ + accelerator print, default on main process + Args: + msg: + + Returns: + + """ + self.accelerator.print(msg) + + def touch(self, batch, num_tokens=10): + """touch first and last tokens and labels for debugging usage""" + self.print( + f"step 1 batch shape: {batch['input_ids'].shape},\n" + f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}" + f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}" + ) + self.print(f"first {num_tokens} input_ids and loss_mask") + for pt in range(1): + self.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + self.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") + + @staticmethod + def format_tensor(tensor, n): + return list(map(lambda x: round(x, n), tensor.tolist())) + + def accelerate_saving_checkpoint(self, output_dir: str, completed_steps: int): + """ + Saving lora adaptor or full checkpoint using accelerator + Args: + output_dir: exact dir for saving ckpt + completed_steps: + + Returns: + + """ + self.accelerator.wait_for_everyone() + + logger.info(f"[CHECKPOINT] Saving checkpoint", main_process_only=True) + unwrapped_model = self.accelerator.unwrap_model(self.model) + # self.print(f"unwrapped model type {type(unwrapped_model)}") + unwrapped_model.save_pretrained( + output_dir, + is_main_process=self.accelerator.is_main_process, + save_function=self.accelerator.save, + state_dict=self.accelerator.get_state_dict(self.model), + ) + self.accelerator.wait_for_everyone() + # for full-parameter training, save whole ckpt and tokenizer together because it does not need a merge. + if not self.args.peft_type and self.accelerator.is_main_process: + if self.args.model_type.lower() == "deepseek": + copy_tokenizer_files( + self.args.pretrained_model_path, ["tokenizer.json", "tokenizer_config.json"], output_dir + ) + else: + self.tokenizer.save_pretrained(output_dir) + + sf = os.path.join(output_dir, "model.safetensors") + index_file = os.path.join(output_dir, "model.safetensors.index.json") + if os.path.isfile(sf) and os.path.isfile(index_file): + self.print(f"Remove bug dummy ckpt {sf}") + os.remove(sf) + + if self.accelerator.is_main_process: + latest = { + "latest_ckpt": output_dir, + "lr": self.optimizer.param_groups[0]["lr"], + } + with open(os.path.join(self.args.output_dir, "latest"), "w") as f: + json.dump(latest, f, indent=2) + + logger.info( + f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved, latest: {latest}", + main_process_only=True, + ) + self.accelerator.wait_for_everyone() + + def accelerate_monitor( + self, + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status=None, + ): + """ + gather reduce_loss and reduce_task_loss from all N devices. + train logging and tensorboarding. + """ + # gather reduce_loss and reduce_task_loss from all N devices + reduce_losses = self.accelerator.gather(reduce_loss).detach().float() + reduce_task_losses = self.accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK)) + reduce_task_exists = self.accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK)) + + # get train loss and per-task train loss + train_loss = torch.mean(reduce_losses) / (self.args.log_interval * self.args.gradient_accumulation_steps) + # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (self.args.log_interval * self.args.gradient_accumulation_steps) + train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0) + + # logging and writing tensorboard + logger.info( + f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]" + f"[train_task_loss={self.format_tensor(train_task_loss, 4)}]" + f"[gather shape={list(reduce_losses.shape)}]" + f"[lr={self.lr_scheduler.get_lr()[0]:.4e}, {self.optimizer.param_groups[0]['lr']:.4e}]", + main_process_only=True, + ) + if coba_status is not None: + if completed_steps > coba_status.coba_warmup_steps: + coba_status.log_per_task_weight = coba_status.log_per_task_weight / torch.sum( + coba_status.log_per_task_weight + ) + else: + coba_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) + logger.info( + f"[TRAIN][per_task_train_weight={coba_status.log_per_task_weight}]", main_process_only=True + ) + train_log_dict = {"Loss/train": train_loss} + for i in range(len(ID2TASK)): + train_log_dict[f"{ID2TASK[i]}_loss/train"] = train_task_loss[i] + if coba_status is not None: + train_log_dict[f"{ID2TASK[i]}_coba_weight/train"] = coba_status.log_per_task_weight[i].item() + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, train_log_dict, completed_steps) + + if coba_status is not None: + coba_status.log_per_task_weight = torch.zeros(len(ID2TASK)) + + def accelerate_evaluate( + self, + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ): + """ + evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices. + eval logging and tensorboarding. + """ + losses = [] + accumulated_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + accumulated_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + for valid_step, valid_batch in enumerate(self.valid_dataloader): + with torch.no_grad(): + outputs = self.model( + input_ids=valid_batch["input_ids"], + attention_mask=valid_batch["attention_mask"], + position_ids=valid_batch["position_ids"], + return_dict=True, + ) + + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=valid_batch["labels"], + task_mask=valid_batch["task_mask"], + task_id=valid_batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=valid_batch["loss_mask"], + task_weights=self.args.task_weights, + ) + + losses.append(self.accelerator.gather(loss.repeat(self.args.per_device_eval_batch_size))) + accumulated_task_loss += task_loss.detach().float() + accumulated_task_exist += (task_loss != 0.0).detach().float() + + self.accelerator.wait_for_everyone() + valid_batch_num = len(losses) + gathered_size = losses[0].shape + losses = torch.cat(losses) + # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK)) + task_losses = self.accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK)) + task_exists = self.accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK)) + + try: + eval_loss = torch.mean(losses) + # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num + eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0) + if eval_loss <= min_eval_loss: + min_eval_loss = eval_loss + stall_num = 0 + best_step = completed_steps + else: + stall_num += 1 + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info( + f"[EVAL][completed_steps={completed_steps}]" + f"[eval_loss={eval_loss:.6f}][eval_task_loss={self.format_tensor(eval_task_loss, 4)}]" + f"[perplexity={perplexity:.4f}][valid_batch_num={valid_batch_num}]" + f"[gather_size={list(gathered_size)}]", + main_process_only=True, + ) + eval_log_dict = { + "Loss/valid": eval_loss, + "Perplexity/valid": perplexity, + "Epochs": round(completed_steps / self.num_update_steps_per_epoch, 2), + } + for i in range(len(ID2TASK)): + eval_log_dict[f"{ID2TASK[i]}_loss/valid"] = eval_task_loss[i] + + if self.accelerator.is_main_process: + write_tensorboard(self.summary_writer, eval_log_dict, completed_steps) + + return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step + + def accelerate_train(self): + # Train! + if self.args.seed is not None: + set_seed(self.args.seed) + + global_batch_size = ( + self.args.per_device_train_batch_size + * self.accelerator.num_processes + * self.args.gradient_accumulation_steps + ) + logger.info("************************************** Running training ****************************************") + logger.info(f" Num examples = {self.total_train_dataset_size}") + logger.info(f" Num Epochs = {self.args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") + logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {global_batch_size}") + logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") + logger.info(f" Total optimization(update/completed) steps = {self.args.max_train_steps}") + logger.info(f" Complete/optimize steps per Epoch = {self.args.max_train_steps // self.args.num_train_epochs}") + logger.info("************************************************************************************************") + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(self.args.max_train_steps), disable=not self.accelerator.is_local_main_process) + + # set starting_epoch, completed_steps and resume_step of train_dataloader + completed_steps = 0 + starting_epoch = 0 + resume_step = None + + if self.args.resume_from_checkpoint: + path = os.path.basename(self.args.resume_from_checkpoint) + starting_epoch, completed_steps, resume_step = extract_epochs_and_steps( + path, self.num_update_steps_per_epoch, self.args.gradient_accumulation_steps + ) + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + # monitor minimum eval_loss, stalling num, and best_step + min_eval_loss = float("inf") + stall_num = 0 + best_step = None + + # monitor train loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + per_task_weight = self.args.task_weights + + if self.args.weighted_loss_mode == "coba": + self.model.eval() + eval_loss, eval_task_loss, _, _, _ = self.accelerate_evaluate( + completed_steps, + 0, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + coba_status = CoBaStatus( + self.args.coba_warmup_steps, + self.args.coba_history_length, + self.args.coba_tau, + self.args.coba_update_interval, + self.args.coba_sample_valid_num, + self.valid_dataloader, + ) + coba_status.valid_task_loss_begining = eval_task_loss.clone().to(self.model.device) + coba_status.sample_valid_batch(self.model, completed_steps) + logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + else: + coba_status = None + + # Training Loop! + for epoch in range(starting_epoch, self.args.num_train_epochs): + # set_epoch + # self.train_dataloader.set_epoch(epoch) + + # if we early stop by some ckpts not converging + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + break + + if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = self.accelerator.skip_first_batches(self.train_dataloader, resume_step) + else: + active_dataloader = self.train_dataloader + tail_num = len(active_dataloader) - len(active_dataloader) % self.args.gradient_accumulation_steps + print(f"length of dataloader: {len(active_dataloader)}") + + self.model.train() + # Inner Loop! + for step, batch in enumerate(active_dataloader): + if step == tail_num: + break + with self.accelerator.accumulate(self.model): + if step == 0: + self.touch(batch, num_tokens=10) + # forward + outputs = self.model( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + position_ids=batch["position_ids"], + return_dict=True, + ) + + if ( + self.args.weighted_loss_mode == "coba" + and self.accelerator.sync_gradients + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps >= self.args.coba_warmup_steps + ): + with torch.no_grad(): + per_task_weight = coba_status.compute_per_task_weight(completed_steps=completed_steps) + coba_status.log_per_task_weight += per_task_weight + # logger.info(f'per_task_weight: {per_task_weight}', main_process_only=True) + + # loss + loss, task_loss, _ = loss_func_mft( + outputs=outputs, + labels=batch["labels"], + task_mask=batch["task_mask"], + task_id=batch["task_id"], + weighted_loss_mode=self.args.weighted_loss_mode, + loss_mask=batch["loss_mask"], + task_weights=per_task_weight, + ) + + # accelerator.print(len(outputs.router_logits), outputs.router_logits[0], outputs.router_logits[-1]) + # accelerator.print(batch['attention_mask'].shape, batch['attention_mask']) + aux_loss = None + if hasattr(self.model_config, "output_router_logits") and self.model_config.output_router_logits: + if hasattr(self.model_config, "num_local_experts"): + num_experts = self.model_config.num_local_experts + elif hasattr(self.model_config, "num_experts"): + num_experts = self.model_config.num_experts + else: + raise ValueError("model has no attribute num_local_experts or num_experts") + aux_loss = load_balancing_loss_func( + outputs.router_logits, + num_experts, + self.model_config.num_experts_per_tok, + batch["attention_mask"], + ) + aux_loss = self.model_config.router_aux_loss_coef * aux_loss.to(loss.device) + loss += aux_loss # make sure to reside in the same device + + # backward + self.accelerator.backward(loss) + # print(self.lr_scheduler.state_dict(), self.accelerator.process_index) + # update(sync_gradients) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + # support args.min_lr + if self.optimizer.param_groups[0]["lr"] <= self.args.min_lr: + self.optimizer.param_groups[0]["lr"] = self.args.min_lr + + # accumulate resuce_loss and reduce_task_loss in a log_interval + if not torch.isnan(loss): + reduce_loss += loss.detach().float() + if aux_loss and not torch.isnan(aux_loss): + reduce_aux_loss += aux_loss.detach().float() + # self.print("task loss devices: ", reduce_task_loss.device, task_loss.device) + reduce_task_loss += task_loss.detach().float() + reduce_task_exist += (task_loss != 0).detach().float() + + # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done. + if self.accelerator.sync_gradients: + if ( + self.args.weighted_loss_mode == "coba" + and completed_steps % self.args.coba_update_interval == 0 + and completed_steps >= 1 + ): + coba_status.sample_valid_batch(self.model, completed_steps) + # logger.info(f"valid_task_loss: {coba_status.valid_task_loss_accumulated}", main_process_only=True) + + # progress_bar.update(1) + completed_steps += 1 + # monitoring training process and logging and tensorboarding + if completed_steps % self.args.log_interval == 0: + progress_bar.update(self.args.log_interval) + if reduce_aux_loss > 0.0: + self.print(f"[INFO] aux_loss: {reduce_aux_loss/self.args.log_interval}") + self.accelerate_monitor( + reduce_loss, + reduce_task_loss, + reduce_task_exist, + completed_steps, + coba_status, + ) + # reset reduce_loss + reduce_loss = torch.tensor(0.0).to(self.model.device) + reduce_aux_loss = torch.tensor(0.0).to(self.model.device) + reduce_task_loss = torch.zeros(len(ID2TASK)).to(self.model.device) + reduce_task_exist = torch.zeros(len(ID2TASK)).to(self.model.device) + + # steps checkpointing + if self.args.checkpointing_steps and completed_steps % self.args.checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_checkpoint(output_dir, completed_steps) + + # steps evaluation + if completed_steps % self.args.evaluation_steps == 0 and self.valid_dataloader: + self.model.eval() + eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = self.accelerate_evaluate( + completed_steps, + step, + min_eval_loss, + stall_num, + best_step, + ) + self.model.train() + + # delete ckpts over args.saving_limit + if self.accelerator.is_main_process and self.args.saving_limit: + delete_ckpts_over_limits(self.args.output_dir, self.args.saving_limit, best_step) + + # early stoppin when stalling more than args.early_stopping_stall_num + if self.args.early_stopping and stall_num == self.args.early_stopping_stall_num: + self.print(f"[WARNING] Early stopping at {completed_steps}") + break + + if completed_steps >= self.args.max_train_steps: + break + self.accelerator.wait_for_everyone() + + # epoch checkpointing + if self.args.epoch_checkpointing: + output_dir = f"epoch_{epoch + 1}" + if self.args.output_dir is not None: + output_dir = os.path.join(self.args.output_dir, output_dir) + self.accelerate_saving_checkpoint(output_dir, completed_steps) + + self.summary_writer.close() + + # final save + # output_dir = f"final_step_{completed_steps}" + # if self.args.output_dir is not None: + # output_dir = os.path.join(self.args.output_dir, output_dir) + # self.accelerate_saving_checkpoint(output_dir, completed_steps) diff --git a/mftcoder_accelerate/src/pefts/model_mapping.py b/mftcoder_accelerate/src/pefts/model_mapping.py deleted file mode 100644 index 0474a6d..0000000 --- a/mftcoder_accelerate/src/pefts/model_mapping.py +++ /dev/null @@ -1,141 +0,0 @@ -""" - # @author Chaoyu Chen - # @date 2023/12/11 - - Manage supported models and their special token used in training. - Default targeting modules for LoRA/QLora - 4.36 is stable now -""" -# Models that Transformers support FA2 -from transformers import ( - AutoConfig, - AutoTokenizer, - AutoModelForCausalLM, - GPTNeoXForCausalLM, - GPTBigCodeForCausalLM, - LlamaForCausalLM, - MistralForCausalLM, - MixtralForCausalLM, - PhiForCausalLM, -) - -# Models that Transformers not support FA2, supported by publisher or ourself -from model.aquila2.modeling_aquila import AquilaForCausalLM -from model.baichuan2.modeling_baichuan import BaichuanForCausalLM -from model.qwen.modeling_qwen import QWenLMHeadModel -from model.chatglm2.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration2 -from model.chatglm3.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration3 - -# from model.phi.modeling_mixformer_sequential import MixFormerSequentialForCausalLM - -MODEL_TYPES = { - "aquila2": AquilaForCausalLM, - "baichuan": BaichuanForCausalLM, - 'chatglm2': ChatGLMForConditionalGeneration2, - 'chatglm3': ChatGLMForConditionalGeneration3, - "code_llama": LlamaForCausalLM, - "deepseek": LlamaForCausalLM, - "gpt_neox": GPTNeoXForCausalLM, - "llama": LlamaForCausalLM, - "mistral": MistralForCausalLM, - "mixtral": MixtralForCausalLM, - 'phi': PhiForCausalLM, - 'qwen': QWenLMHeadModel, - "starcoder": GPTBigCodeForCausalLM, -} - -FULL_LORA_TARGETING_MODULES = { - "aquila": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "baichuan": ["W_pack", "o_proj", "gate_proj", "down_proj", "up_proj"], - "chatglm2": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - "chatglm3": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], - "deepseek": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "code_llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "gpt_neox": ["query_key_value", 'dense', 'dense_h_to_4h', 'dense_4h_to_h'], - "llama": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "mistral": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"], - "mixtral": ["q_proj", "k_proj", "v_proj", "o_proj"], - "phi": ["query_key_value", 'dense', 'fc1', 'fc2'], - "qwen": ["c_proj", "c_attn", "w1", "w2"], - "starcoder": ["c_proj", "c_attn", "q_attn", "c_fc"], -} - -MODEL_SPECIAL_TOKENS = { - "gpt_neox": { - - "eos_token": "<|endoftext|>", - "pad_token": "<|pad|>", - - }, - "llama": { - - "eos_token": "", - "pad_token": "", - - }, - "code_llama": { - - "eos_token": "", - "pad_token": "", - - }, - "baichuan": { - - "eos_token": "", - "pad_token": "", - - }, - "starcoder": { - - "eos_token": "<|endoftext|>", - "pad_token": "", - - }, - "qwen": { - - "eos_token": "<|endoftext|>", - "pad_token": "<|extra_1|>", - - }, - "chatglm2": { - - "eos_token": "", - "pad_token": "", - - }, - "chatglm3": { - - "eos_token": "", - "pad_token": "", - - }, - "phi": { - - "eos_token": "<|endoftext|>", - "pad_token": "<|endoftext|>", - }, - "aquila": { - - "eos_token": "", - "pad_token": "<|endoftext|>", - - }, - "deepseek": { - - "eos_token": "<|end▁of▁sentence|>", - "pad_token": "<|end▁of▁sentence|>", - - }, - "mixtral": { - - "eos_token": "", - "pad_token": "", - - }, - "mistral": { - - "eos_token": "", - "pad_token": "", - - }, -} diff --git a/mftcoder_accelerate/src/pefts/train_utils.py b/mftcoder_accelerate/src/pefts/train_utils.py deleted file mode 100644 index fb0ca1c..0000000 --- a/mftcoder_accelerate/src/pefts/train_utils.py +++ /dev/null @@ -1,416 +0,0 @@ -""" -# @author Chaoyu Chen -# @date 2023/10/19 -# @module train_utils.py - -Accelerate + DeepSpeed zero stage2 + DistributedDataParallel -QLoRA/LoRA/Full + MFT/MPT, resource and parameters efficient training - -training functions -""" - -import gc -import os -import sys -import threading -import argparse -import math -import logging -import json -import time -import transformers -import numpy as np -import psutil -import shutil -import torch -from torch import nn -from tqdm.auto import tqdm - -sys.path.append("..") -from utils.common_utils import generate_task_id, TASK2ID, ID2TASK -from utils.auto_accelerate_utils import loss_func_mft, SelfpacedStatus -from torch.utils.tensorboard import SummaryWriter -from accelerate.logging import get_logger - -logger = get_logger(__name__) - - -def check_existing_ckpts(output_dir): - prefix = "step_" - - if not os.path.exists(output_dir): - return [] - # 列出目录中的所有文件和文件夹 - contents = os.listdir(output_dir) - - # 使用列表推导式筛选以"step_"开头的文件夹 - matching_folders = [folder for folder in contents if - os.path.isdir(os.path.join(output_dir, folder)) and folder.startswith(prefix)] - - return matching_folders - - -def extract_epochs_and_steps(path, num_update_steps_per_epoch, gradient_accumulation_steps): - """ - extract starting_epoch, completed_steps, resume_step of train_dataloader for resumed training - """ - # Extract `epoch_{i}` or `step_{i}` - training_difference = os.path.splitext(path)[0] - - if "epoch" in training_difference: - starting_epoch = int(training_difference.replace("epoch_", "")) + 1 - resume_step = None - completed_steps = starting_epoch * num_update_steps_per_epoch - print(f"resume from epoch {starting_epoch} and completed_steps {completed_steps}") - else: - # need to multiply `gradient_accumulation_steps` to reflect real steps - completed_steps = int(training_difference.replace("step_", "")) - starting_epoch = completed_steps // num_update_steps_per_epoch - resume_step = (completed_steps % num_update_steps_per_epoch) * gradient_accumulation_steps - print(f"resume from epoch {starting_epoch} resusme step {resume_step} and completed_steps {completed_steps}") - - return starting_epoch, completed_steps, resume_step - - -def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): - for key, value in log_dict.items(): - summary_writer.add_scalar(f'{key}', value, completed_steps) - - -def accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir: str, completed_steps: int, args): - """ - Saving lora adaptor or full checkpoint using accelerator - """ - accelerator.wait_for_everyone() - - logger.info( - f"[CHECKPOINT] Saving checkpoint", - main_process_only=True - ) - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - output_dir, - is_main_process=accelerator.is_main_process, - save_function=accelerator.save, - state_dict=accelerator.get_state_dict(model) - ) - # for full-parameter training, save whole ckpt and tokenizer together because it does not need a merge. - if not args.peft_type and accelerator.is_main_process: - tokenizer.save_pretrained(output_dir) - - logger.info( - f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved", - main_process_only=True - ) - accelerator.wait_for_everyone() - - -def accelerate_monitor(accelerator, model, reduce_loss, reduce_task_loss, reduce_task_exist, args, completed_steps, - lr_scheduler, optimizer, summary_writer, selfpaced_status=None): - """ - gather reduce_loss and reduce_task_loss from all N devices. - train logging and tensorboarding. - """ - # gather reduce_loss and reduce_task_loss from all N devices - reduce_losses = accelerator.gather(reduce_loss).detach().float() - reduce_task_losses = accelerator.gather(reduce_task_loss).reshape(-1, len(ID2TASK)) - reduce_task_exists = accelerator.gather(reduce_task_exist).reshape(-1, len(ID2TASK)) - # get train loss and per-task train loss - train_loss = torch.mean(reduce_losses) / (args.log_interval * args.gradient_accumulation_steps) - # train_task_loss = torch.mean(reduce_task_losses, dim=0) / (args.log_interval * args.gradient_accumulation_steps) - train_task_loss = torch.sum(reduce_task_losses, dim=0) / torch.sum(reduce_task_exists, dim=0) - - # logging and tensorboard - logger.info( - f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}][train_task_loss={train_task_loss}]" - f"[gather shape={reduce_losses.shape}][lr={lr_scheduler.get_lr()[0]:.4e}, {optimizer.param_groups[0]['lr']:.4e}]", - main_process_only=True) - if selfpaced_status is not None: - if completed_steps > selfpaced_status.selfpaced_history_length: - selfpaced_status.log_per_task_weight = selfpaced_status.log_per_task_weight / torch.sum(selfpaced_status.log_per_task_weight) - else: - selfpaced_status.log_per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) - logger.info(f"[TRAIN][per_task_train_weight={selfpaced_status.log_per_task_weight}]", main_process_only=True) - train_log_dict = {"training_loss": train_loss} - for i in range(len(ID2TASK)): - train_log_dict[f"{ID2TASK[i]}_train_loss"] = train_task_loss[i] - if selfpaced_status is not None: - train_log_dict[f"{ID2TASK[i]}_train_selfpaced_weight"] = selfpaced_status.log_per_task_weight[i].item() - - if accelerator.is_main_process: - write_tensorboard(summary_writer, train_log_dict, completed_steps) - - if selfpaced_status is not None: - selfpaced_status.log_per_task_weight = torch.zeros(len(ID2TASK)) - - -def accelerate_evaluate(accelerator, model, valid_dataloader, args, completed_steps, step, min_eval_loss, stall_num, - best_step, summary_writer): - """ - evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices. - eval logging and tensorboarding. - """ - losses = [] - accumulated_task_loss = torch.zeros(len(ID2TASK)).to(model.device) - accumulated_task_exist = torch.zeros(len(ID2TASK)).to(model.device) - for valid_step, valid_batch in enumerate(valid_dataloader): - with torch.no_grad(): - outputs = model( - input_ids=valid_batch['input_ids'], - attention_mask=valid_batch['attention_mask'], - position_ids=valid_batch['position_ids'], - return_dict=True, - ) - - loss, task_loss, _ = loss_func_mft( - outputs=outputs, - labels=valid_batch['labels'], - task_mask=valid_batch['task_mask'], - task_id=valid_batch['task_id'], - weighted_loss_mode=args.weighted_loss_mode, - loss_mask=valid_batch['loss_mask'], - task_weights=args.task_weights - ) - - losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) - accumulated_task_loss += task_loss.detach().float() - accumulated_task_exist += (task_loss != 0.0).detach().float() - - accelerator.wait_for_everyone() - valid_batch_num = len(losses) - gathered_size = losses[0].shape - losses = torch.cat(losses) - # task_losses = torch.cat(task_losses).reshape(-1, len(ID2TASK)) - task_losses = accelerator.gather(accumulated_task_loss).reshape(-1, len(ID2TASK)) - task_exists = accelerator.gather(accumulated_task_exist).reshape(-1, len(ID2TASK)) - - try: - eval_loss = torch.mean(losses) - # eval_task_loss = torch.mean(task_losses, dim=0) / valid_batch_num - eval_task_loss = torch.sum(task_losses, dim=0) / torch.sum(task_exists, dim=0) - if eval_loss <= min_eval_loss: - min_eval_loss = eval_loss - stall_num = 0 - best_step = completed_steps - else: - stall_num += 1 - perplexity = math.exp(eval_loss) - except OverflowError: - perplexity = float("inf") - - logger.info(f"[EVAL][global_steps={step + 1}][completed_steps={completed_steps}]" - f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]" - f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}]" - f"[eval_task_loss={eval_task_loss}]", - main_process_only=True) - eval_log_dict = {"valid_loss": eval_loss, "perplexity": perplexity} - for i in range(len(ID2TASK)): - eval_log_dict[f"{ID2TASK[i]}_valid_loss"] = eval_task_loss[i] - - if accelerator.is_main_process: - write_tensorboard(summary_writer, eval_log_dict, completed_steps) - - return eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step - - -def delete_ckpts_over_limits(output_dir, saving_limit, best_step): - """delete ckpts more than saving_limits except for the best_step ckpt""" - existing_ckpts = check_existing_ckpts(output_dir) - logger.info(f"Existing step ckpts folders: {existing_ckpts}, best step ckpt: step_{best_step}") - # sorted only step num ascendingly - ckpt_steps = sorted([int(ckpt.replace("step_", "")) for ckpt in existing_ckpts]) - # delete the oldest steps except for the best step at present - if len(ckpt_steps) > saving_limit: - deletable_steps = [ckpt_step for ckpt_step in ckpt_steps if ckpt_step != best_step] - # print(deletable_steps[:len(ckpt_steps) - saving_limit]) - for del_step in deletable_steps[:len(ckpt_steps) - saving_limit]: - shutil.rmtree(os.path.join(output_dir, f"step_{del_step}")) - logger.info(f"Removed ckpt step_{del_step}") - - -def touch_print(accelerator, batch, num_tokens=10): - """touch first and last tokens and labels for debugging usage""" - accelerator.print(f"step 1 batch shape: {batch['input_ids'].shape},\n" - f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}" - f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}") - accelerator.print(f"first {num_tokens} input_ids and loss_mask") - for pt in range(1): - accelerator.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") - accelerator.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") - - -def accelerate_train(accelerator, model, train_dataloader, valid_dataloader, optimizer, lr_scheduler, tokenizer, - num_update_steps_per_epoch, total_train_dataset_size, args): - # tensorboard writer - summary_writer = SummaryWriter(log_dir=args.tb_dir) - # Train! - total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - logger.info("**************************************** Running training ****************************************") - logger.info(f" Num examples = {total_train_dataset_size}") - logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization(update/completed) steps = {args.max_train_steps}") - logger.info(f" Complete/Optimization steps per Epoch = {args.max_train_steps // args.num_train_epochs}") - logger.info("***************************************************************************************************") - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - - # set starting_epoch, completed_steps and resume_step of train_dataloader - completed_steps = 0 - starting_epoch = 0 - resume_step = None - - if args.resume_from_checkpoint: - path = os.path.basename(args.resume_from_checkpoint) - starting_epoch, completed_steps, resume_step = extract_epochs_and_steps( - path, num_update_steps_per_epoch, args.gradient_accumulation_steps - ) - - # update the progress_bar if load from checkpoint - progress_bar.update(completed_steps) - - # monitor minimum eval_loss, stalling num, and best_step - min_eval_loss = float('inf') - stall_num = 0 - best_step = None - - # monitor train loss - reduce_loss = 0 - reduce_task_loss = torch.zeros(len(ID2TASK)).to(model.device) - reduce_task_exist = torch.zeros(len(ID2TASK)).to(model.device) - per_task_weight = args.task_weights - - if args.weighted_loss_mode == "selfpaced": - selfpaced_status = SelfpacedStatus(args.selfpaced_scale_factor, args.selfpaced_interval, args.selfpaced_history_length, args.selfpaced_sample_valid_num, valid_dataloader) - selfpaced_status.sample_valid_batch(model, completed_steps) - selfpaced_status.valid_iterator = iter(selfpaced_status.valid_dataloader) - else: - selfpaced_status = None - - # Training Loop! - for epoch in range(starting_epoch, args.num_train_epochs): - if args.early_stopping and stall_num == args.early_stopping_stall_num: - break - - if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: - # We skip the first `n` batches in the dataloader when resuming from a checkpoint - active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) - else: - active_dataloader = train_dataloader - tail_num = len(active_dataloader) - len(active_dataloader) % args.gradient_accumulation_steps - print(f"length of dataloader: {len(active_dataloader)}") - - model.train() - # Inner Loop! - for step, batch in enumerate(active_dataloader): - if step == tail_num: - break - with accelerator.accumulate(model): - if step == 0: - touch_print(accelerator, batch, num_tokens=10) - # forward - outputs = model( - input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - position_ids=batch['position_ids'], - return_dict=True - ) - - if args.weighted_loss_mode == 'selfpaced' and step % args.gradient_accumulation_steps == 0 and completed_steps % args.selfpaced_interval == 0 and completed_steps >= args.selfpaced_history_length: - per_task_weight = selfpaced_status.compute_per_task_weight(completed_steps=completed_steps) - selfpaced_status.log_per_task_weight += per_task_weight - - # loss - loss, task_loss, _ = loss_func_mft( - outputs=outputs, - labels=batch['labels'], - task_mask=batch['task_mask'], - task_id=batch['task_id'], - weighted_loss_mode=args.weighted_loss_mode, - loss_mask=batch['loss_mask'], - task_weights=per_task_weight - ) - - # backward - accelerator.backward(loss) - - # update(sync_gradients) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - # support args.min_lr - if optimizer.param_groups[0]['lr'] <= args.min_lr: - optimizer.param_groups[0]['lr'] = args.min_lr - - # accumulate resuce_loss and reduce_task_loss in a log_interval - if not torch.isnan(loss): - reduce_loss += loss.detach().float() - # accelerator.print("task loss devices: ", reduce_task_loss.device, task_loss.device) - reduce_task_loss += task_loss.detach().float() - reduce_task_exist += (task_loss != 0).detach().float() - - # If the accelerator has performed an optimization step behind the scenes, thus a completed_step done. - if accelerator.sync_gradients: - if args.weighted_loss_mode == 'selfpaced' and completed_steps % args.selfpaced_interval == 0 and completed_steps >= 1: - selfpaced_status.sample_valid_batch(model, completed_steps) - - # progress_bar.update(1) - completed_steps += 1 - # monitoring training process and logging and tensorboarding - if completed_steps % args.log_interval == 0: - progress_bar.update(args.log_interval) - accelerate_monitor( - accelerator, model, reduce_loss, reduce_task_loss, reduce_task_exist, args, completed_steps, - lr_scheduler, optimizer, summary_writer, selfpaced_status - ) - # reset reduce_loss - reduce_loss = 0 - reduce_task_loss = torch.zeros(len(ID2TASK)).to(model.device) - reduce_task_exist = torch.zeros(len(ID2TASK)).to(model.device) - - # steps checkpointing - if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0: - output_dir = f"step_{completed_steps}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args) - - # steps evaluation - if completed_steps % args.evaluation_steps == 0: - model.eval() - eval_loss, eval_task_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate( - accelerator, model, valid_dataloader, args, completed_steps, step, - min_eval_loss, stall_num, best_step, summary_writer - ) - model.train() - - # delete ckpts over args.saving_limit - if accelerator.is_main_process and args.saving_limit: - delete_ckpts_over_limits(args.output_dir, args.saving_limit, best_step) - - # early stoppin when stalling more than args.early_stopping_stall_num - if args.early_stopping and stall_num == args.early_stopping_stall_num: - accelerator.print(f"[WARNING] Early stopping at {completed_steps}") - break - - if completed_steps >= args.max_train_steps: - break - accelerator.wait_for_everyone() - - # epoch checkpointing - if args.epoch_checkpointing: - output_dir = f"epoch_{epoch}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args) - - summary_writer.close() - - # final save - output_dir = f"final_step_{completed_steps}" - if args.output_dir is not None: - output_dir = os.path.join(args.output_dir, output_dir) - accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args) diff --git a/mftcoder_accelerate/src/run_offline_tokenization.sh b/mftcoder_accelerate/src/run_offline_tokenization.sh new file mode 100644 index 0000000..ed916da --- /dev/null +++ b/mftcoder_accelerate/src/run_offline_tokenization.sh @@ -0,0 +1,13 @@ +MODEL_PATH= +DATA_PATH= +DATASET_NAME= +OUTPUT_PATH= + +python offline_tokenization/concat_sst_bin_tokenization.py \ +--model-path ${MODEL_PATH} \ +--data-path ${DATA_PATH} \ +--dataset-name ${DATASET_NAME} \ +--output-path ${OUTPUT_PATH} \ +--parallel 16 \ +--seq-length 4096 \ +--sample-percent 1.0 diff --git a/mftcoder_accelerate/src/tokenizer/__init__.py b/mftcoder_accelerate/src/tokenizer/__init__.py index 12ec210..20e88bb 100644 --- a/mftcoder_accelerate/src/tokenizer/__init__.py +++ b/mftcoder_accelerate/src/tokenizer/__init__.py @@ -1 +1,3 @@ from .tokenizer import build_tokenizer +from .tokenizer import init_tokenizer +from .chat_template import MFTCoder_template \ No newline at end of file diff --git a/mftcoder_accelerate/src/tokenizer/chat_template.py b/mftcoder_accelerate/src/tokenizer/chat_template.py index e04a3f9..3d2ad03 100644 --- a/mftcoder_accelerate/src/tokenizer/chat_template.py +++ b/mftcoder_accelerate/src/tokenizer/chat_template.py @@ -4,25 +4,17 @@ # store possible chat_template for tokenizers to prepare input string # -------------------------------------------------- Import ------------------------------------------------------------ -from transformers import ( - AutoTokenizer -) - -# ----------------------------------------------- func and class ------------------------------------------------------- -instruction_template = ( - "{% for message in messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ '[INST] ' + message['content'] + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ message['content'] + eos_token}}" - "{% else %}" - "{{ raise_exception('Only user and assistant roles are supported!') }}" - "{% endif %}" - "{% endfor %}" -) +""" +Usage: +tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) +messages = [ + {"role": "system", "content": "Be smart"}, + {"role": "human", "content": "Hello, how are you?"}, + {"role": "bot", "content": "I'm doing great. How can I help you today?"}, + {"role": "human", "content": "I'd like to show off how chat templating works!"}, +] +prompts = tokenizer.apply_chat_template(message, chat_template=MFTCoder_template, tokenize=False, add_generation_prompt=True) +""" MFTCoder_template = ( "{% if messages[0]['role'] == 'system' %}" @@ -42,17 +34,17 @@ "{% set content = '' %}" "{% endif %}" "{% if message['role'] == 'user' or message['role'] == 'human' %}" - "{{ content + 'user\n' + message['content'] + '\n' }}" + "{{ content + 'human\n' + message['content'] + '\n' }}" "{% elif message['role'] == 'assistant' or message['role'] == 'bot' %}" - "{{ 'assistant\n' + message['content'] + '\n' + eos_token + '\n'}}" + "{{ 'bot\n' + message['content'] + '\n' + eos_token + '\n'}}" "{% else %}" "{{ raise_exception('Only user/human and assistant/bot roles are supported!') }}" "{% endif %}" "{% endfor %}" "{% if add_generation_prompt %}" - "{{ 'assistant\n' }}" + "{{ 'bot\n' }}" "{% endif %}" ) -if __name__ == '__main__': +if __name__ == "__main__": pass diff --git a/mftcoder_accelerate/src/tokenizer/tokenizer.py b/mftcoder_accelerate/src/tokenizer/tokenizer.py index 5765ffd..bc3ab56 100644 --- a/mftcoder_accelerate/src/tokenizer/tokenizer.py +++ b/mftcoder_accelerate/src/tokenizer/tokenizer.py @@ -1,35 +1,75 @@ """ # @author Chaoyu Chen # @date 2023/6/19 - -Build tokenizer """ - import numpy as np from typing import List, Union from utils.common_utils import print_rank_0 -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoConfig +from tokenizer.chat_template import MFTCoder_template + + +def init_tokenizer(path): + """ + Init a Huggingface tokenizer, parsing eos_token from the tokenizer_config then config. + Set pad_token same as eos_token for easy life. + :param path: model path or tokenizer path + :return: Tokenizer (TokenizerFast is preferred) + """ + # tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False, legacy=False) + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + config, unused_kwargs = AutoConfig.from_pretrained(path, trust_remote_code=True, return_unused_kwargs=True) + + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id: + print(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer") + eos_token_id = tokenizer.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(eos_token_id) + elif hasattr(tokenizer, "eos_token") and tokenizer.eos_token: + print(f"Initial eos_token {tokenizer.eos_token} from tokenizer") + eos_token = tokenizer.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) + elif hasattr(config, "eos_token_id") and config.eos_token_id: + print(f"Initial eos_token_id {config.eos_token_id} from config.json") + eos_token_id = config.eos_token_id + eos_token = tokenizer.convert_ids_to_tokens(config.eos_token_id) + elif hasattr(config, "eos_token") and config.eos_token: + print(f"Initial eos_token {config.eos_token} from config.json") + eos_token = config.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(config.eos_token) + else: + raise ValueError( + "No available eos_token or eos_token_id, please provide eos_token by params or eos_token_id by config.json" + ) + try: + tokenizer.eos_token = eos_token + tokenizer.eos_token_id = eos_token_id + # set pad_token to be same as eos_token, it is ok because is will be masked out. + tokenizer.pad_token = eos_token + tokenizer.pad_token_id = eos_token_id + except: + print(f"[WARNING]Cannot set tokenizer.eos_token") + + tokenizer.add_bos_token = False + tokenizer.add_eos_token = False + tokenizer.chat_template = MFTCoder_template + print_rank_0(f"Tokenizer: {type(tokenizer)}") + print_rank_0(f"Length of tokenizer: {len(tokenizer)}") + print_rank_0(f"build_tokenizer pad_token_id: {tokenizer.pad_token_id}, eos_token_id: {tokenizer.eos_token_id}") + print_rank_0(f"build_tokenizer pad_token : {tokenizer.pad_token}, eos_token: {tokenizer.eos_token}") + + return tokenizer def build_tokenizer(args): """Initialize tokenizer.""" - print_rank_0("> building {} tokenizer ...".format(args.tokenizer_type)) + print_rank_0(f"> building {args.tokenizer_type} tokenizer ...") # Select and instantiate the tokenizer. if args.tokenizer_type.lower() == "AutoTokenizer".lower(): assert args.pretrained_model_path is not None - tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True) - - tokenizer.eod_id = tokenizer.convert_tokens_to_ids(args.eos_token) - tokenizer.pad_id = tokenizer.convert_tokens_to_ids(args.pad_token) - - print_rank_0(f"Tokenizer: {tokenizer}\nLength of tokenizer: {len(tokenizer)}") - print_rank_0(f"build_tokenizer PAD id: {tokenizer.pad_id}, EOD id: {tokenizer.eod_id}") - print_rank_0(f"build_tokenizer PAD token : {args.pad_token}, EOD token: {args.eos_token}") + tokenizer = init_tokenizer(args.pretrained_model_path) else: - raise NotImplementedError( - "{} tokenizer is not " "implemented.".format(args.tokenizer_type) - ) + raise NotImplementedError(f"{args.tokenizer_type} tokenizer is not implemented.") # Add vocab size. args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args) @@ -38,7 +78,7 @@ def build_tokenizer(args): def _vocab_size_with_padding(orig_vocab_size, args): - """Pad vocab size so that it is divisible by model parallel size and + """Pad vocab size thus it is divisible by model parallel size and still having GPU friendly size.""" after = orig_vocab_size diff --git a/mftcoder_accelerate/src/utils/__init__.py b/mftcoder_accelerate/src/utils/__init__.py index 0cf9434..0bd6cec 100644 --- a/mftcoder_accelerate/src/utils/__init__.py +++ b/mftcoder_accelerate/src/utils/__init__.py @@ -1,2 +1,2 @@ from .common_utils import * -from .auto_accelerate_utils import * \ No newline at end of file +from .loss_utils import * diff --git a/mftcoder_accelerate/src/utils/agd.py b/mftcoder_accelerate/src/utils/agd.py index bb654a9..11929e3 100644 --- a/mftcoder_accelerate/src/utils/agd.py +++ b/mftcoder_accelerate/src/utils/agd.py @@ -83,22 +83,14 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: if len(state) == 0: state["step"] = 0 # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) if group["amsgrad"]: # Maintains max of all exp. moving avg. of sq. grad. values - state["max_exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) if group["win"]: - state["z"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["z"] = torch.zeros_like(p, memory_format=torch.preserve_format) exp_avg, exp_avg_sq = ( state["exp_avg"], @@ -116,8 +108,7 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: update = ( exp_avg * (1 / bias_correction1) if state["step"] == 1 - else exp_avg * (1 / bias_correction1) - - exp_avg_old * (1 / bias_correction1_old) + else exp_avg * (1 / bias_correction1) - exp_avg_old * (1 / bias_correction1_old) ) exp_avg_sq.mul_(beta2).addcmul_(update, update, value=1 - beta2) diff --git a/mftcoder_accelerate/src/utils/auto_accelerate_utils.py b/mftcoder_accelerate/src/utils/auto_accelerate_utils.py deleted file mode 100644 index c35ca9b..0000000 --- a/mftcoder_accelerate/src/utils/auto_accelerate_utils.py +++ /dev/null @@ -1,224 +0,0 @@ -import sys -import torch -from utils.common_utils import print_rank_0, TASK2ID, ID2TASK -from torch.nn import CrossEntropyLoss -import torch.nn.functional as F -from dataclasses import dataclass -import numpy as np - - -def get_task_mask(task_id): - task_num = len(TASK2ID) - task_mask = torch.zeros(task_id.shape[0], task_num) - task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1 - - return task_mask - - -def get_task_loss(task_losses, task_id): # TODO - # fix task order - task_loss_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device) - # count task samples - task_num_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device) - for i in range(len(task_id)): - task_num_per_batch[task_id[i][0]] += 1 - task_loss_per_batch[task_id[i][0]] = task_losses[task_id[i][0]] - - return task_loss_per_batch, task_num_per_batch - - -def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_mask=None, task_weights=None): - """ - loss function for MFT loss - :param outputs: - :param labels: - :param task_mask: - :param task_id: - :param weighted_loss_mode: - :param loss_mask: - :return: - """ - # task_id shape: [[1], [2], [4], [3], ..., [1]] - weighted = weighted_loss_mode - lm_logits = outputs["logits"] - labels = labels.to(device=lm_logits.device) - task_mask = task_mask.to(device=lm_logits.device) - task_id = task_id.to(device=lm_logits.device) - shift_logits = lm_logits.contiguous() - labels = labels.contiguous() - if task_weights is None: - task_weights = torch.ones(len(ID2TASK)).to(device=lm_logits.device) / len(ID2TASK) - - bsz, seq_len = labels.shape - # loss_mask = None - if loss_mask is None: - ineffective_tokens_per_sample = (labels == -100).sum(dim=1) - effective_tokens_per_sample = - (ineffective_tokens_per_sample - seq_len) - effective_tokens = bsz * seq_len - ineffective_tokens_per_sample.sum() - loss_fct = CrossEntropyLoss(reduction='none', ignore_index=-100) - else: - loss_mask = loss_mask.to(device=lm_logits.device) - loss_fct = CrossEntropyLoss(reduction='none') - losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) # [B * L, 1] - losses = losses.contiguous().view(bsz, -1) - token_losses = losses.clone().detach().float() if loss_mask is None else losses.clone().detach().float() * loss_mask # [B, L] - task_mask_trans = torch.transpose(task_mask, 0, 1) - unique_id = torch.unique(task_id) - if weighted_loss_mode == "case3" or weighted_loss_mode == "case4" or weighted_loss_mode == "selfpaced": - loss = 0.0 - weights_sum = 0.0 - for i, w in enumerate(unique_id): - row_idx = torch.squeeze(task_id) == w.item() - task_weight = float(task_weights[w.item()]) - weights_sum += task_weight - if weighted_loss_mode == "case3" or weighted_loss_mode == "selfpaced": - if loss_mask is None: - loss += torch.sum(losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) * task_weight - else: - loss += torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) * task_weight - elif weighted_loss_mode == "case4": - if loss_mask is None: - loss += torch.mean(torch.sum(losses, dim=1)[row_idx] / effective_tokens_per_sample[row_idx]) * task_weight - else: - loss += torch.mean(torch.sum(losses * loss_mask, dim=1)[row_idx] / torch.sum(loss_mask, dim=1)[row_idx]) * task_weight - - # loss /= len(unique_id) - loss /= weights_sum - - elif weighted_loss_mode == "case2": - if loss_mask is None: - loss = torch.mean(torch.sum(losses, dim=1) / effective_tokens_per_sample) - else: - loss = torch.mean(torch.sum(losses * loss_mask, dim=1) / torch.sum(loss_mask, dim=1)) - elif weighted_loss_mode == "case1": - # flatten losses & loss_mask tensor - if loss_mask is None: - losses = losses.view(-1) - loss = torch.sum(losses) / effective_tokens - else: - loss_mask = loss_mask.view(-1) - losses = losses.view(-1) - loss = torch.sum(losses * loss_mask) / loss_mask.sum() - - # fix task order - task_loss = torch.zeros(len(ID2TASK)).to(device=task_id.device) - task_num = torch.zeros(len(ID2TASK)).to(device=task_id.device) - for i, w in enumerate(unique_id): - row_idx = torch.squeeze(task_id) == w.item() - if loss_mask is None: - task_loss[w] = torch.sum(token_losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) - task_num[w] = len(effective_tokens_per_sample[row_idx]) - else: - task_loss[w] = torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) - - return loss, task_loss, task_num - - -class MFTLossStatus: - def __init__(self): - super(MFTLossStatus, self).__init__() - - -class SelfpacedStatus(MFTLossStatus): - def __init__(self, - selfpaced_scale_factor=50, - selfpaced_interval=1, - selfpaced_history_length=100, - selfpaced_sample_valid_num=1, - valid_dataloader=None - ): - - super(SelfpacedStatus, self).__init__() - self.selfpaced_scale_factor = selfpaced_scale_factor - self.selfpaced_interval = selfpaced_interval - self.selfpaced_history_length = selfpaced_history_length - self.selfpaced_sample_valid_num = selfpaced_sample_valid_num - self.valid_dataloader = valid_dataloader - self.valid_dataloader_length = len(valid_dataloader) - self.valid_iterator = iter(valid_dataloader) - self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK)) - self.history_task_valid_loss = torch.zeros((selfpaced_history_length, len(ID2TASK))) - self.log_per_task_weight = torch.zeros(len(ID2TASK)) - - def selfpaced_evaluate(self, model, v_batch, per_task_weight=None, selfpaced_status=None): - model.eval() - with torch.no_grad(): - valid_outputs = model( - input_ids=v_batch['input_ids'], - attention_mask=v_batch['attention_mask'], - position_ids=v_batch['position_ids'] - ) - - _, valid_task_loss, valid_task_num = loss_func_mft( - outputs=valid_outputs, - labels=v_batch['labels'], - task_mask=v_batch['task_mask'], - task_id=v_batch['task_id'], - weighted_loss_mode='selfpaced', - loss_mask=v_batch['loss_mask'], - task_weights=None - ) - - torch.distributed.all_reduce(valid_task_loss, op=torch.distributed.ReduceOp.SUM) - valid_task_loss /= torch.distributed.get_world_size() - model.train() - return valid_task_loss - - def compute_per_task_weight(self, completed_steps=None): - task_slope_fitting = torch.ones(len(ID2TASK)) - history_steps = torch.arange(completed_steps - self.selfpaced_history_length, completed_steps, 1) # DEBUG: step < 0 - transpose_history_task_valid_loss = self.history_task_valid_loss.transpose(0, 1) - for i in range(len(ID2TASK)): - per_history_task_valid_loss = transpose_history_task_valid_loss[i] - task_slope_fitting[i] = self.fit_window_point(history_steps, per_history_task_valid_loss, - history=self.selfpaced_history_length, type='slope') - slope_sum_abs = torch.sum(torch.abs(task_slope_fitting)) - - if slope_sum_abs == 0: - per_task_weight = torch.ones(len(ID2TASK)) / len(ID2TASK) - else: - # print_rank_0(f"[step={completed_steps}][slope sum abs={slope_sum_abs}]") - normalize_slope = len(ID2TASK) * task_slope_fitting / slope_sum_abs - print_rank_0(f'normalize_slope: {normalize_slope}') - score = F.softmax(normalize_slope, dim=-1) * (-1 * normalize_slope) - print_rank_0(f'score: {score}') - per_task_weight = F.softmax(self.selfpaced_scale_factor * score, dim=-1) - print_rank_0(f'per_task_weight: {per_task_weight}') - - return per_task_weight - - def fit_window_point(self, x, y, history=10, type='slope'): - - nonzero_index = torch.squeeze(torch.nonzero(y), dim=1) - y = torch.index_select(y, 0, nonzero_index) - x = torch.index_select(x, 0, nonzero_index) - - ws = torch.flip(1 ** torch.arange(len(y)), dims=[0]) - ws = ws.float() - - if len(y) >= 2: - if type == 'slope': - X = torch.stack((x, torch.ones_like(x))).T - X = X.float() - else: - X = torch.stack((x ** 2, x, torch.ones_like(x))).T - w = torch.linalg.solve(X.T @ (ws[:, None] * X), X.T @ (ws * y)) - - result = w[0] - else: - result = 0.0 - - return result - - def sample_valid_batch(self, model, completed_steps): - self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK)) - for i in range(self.selfpaced_sample_valid_num): - if (self.selfpaced_sample_valid_num * completed_steps // self.selfpaced_interval + i) % self.valid_dataloader_length == 0: - self.valid_iterator = iter(self.valid_dataloader) - v_batch = next(self.valid_iterator) - valid_task_loss = self.selfpaced_evaluate(model, v_batch) - self.valid_task_loss_accumulated += valid_task_loss.detach().cpu() - self.valid_task_loss_accumulated /= self.selfpaced_sample_valid_num - self.history_task_valid_loss = torch.cat((self.history_task_valid_loss, torch.unsqueeze(self.valid_task_loss_accumulated, dim=0))) - if len(self.history_task_valid_loss) > self.selfpaced_history_length: - self.history_task_valid_loss = self.history_task_valid_loss[len(self.history_task_valid_loss) - self.selfpaced_history_length:] diff --git a/mftcoder_accelerate/src/utils/common_utils.py b/mftcoder_accelerate/src/utils/common_utils.py index 48d75e1..7b6ea30 100644 --- a/mftcoder_accelerate/src/utils/common_utils.py +++ b/mftcoder_accelerate/src/utils/common_utils.py @@ -1,10 +1,29 @@ import os import math import torch +from packaging import version +import importlib TASK2ID = {} ID2TASK = {} + +def is_flash_attn_2_available(): + + # Let's add an extra check to see if cuda is available + + if not torch.cuda.is_available(): + return False + + if torch.version.cuda: + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") + elif torch.version.hip: + # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4") + else: + return False + + def print_rank_0(*message): """If distributed is initialized print only on rank 0.""" if torch.distributed.is_initialized(): @@ -92,30 +111,24 @@ def get_tflops_new(args, batch_size, seq_len, step_time): L = args.num_hidden_layers h = args.hidden_size V = args.vocab_size - flops = (96 * batch_size * sl * L * h * h * (1 + sl / (6 * h) + V / (16 * L * h)) / step_time) + flops = 96 * batch_size * sl * L * h * h * (1 + sl / (6 * h) + V / (16 * L * h)) / step_time return human_readable_flops(flops) -def get_tflops_megatron(total_model_param, hidden_size, num_hidden_layers, - batch_size_per_device, seq_len, step_time): +def get_tflops_megatron(total_model_param, hidden_size, num_hidden_layers, batch_size_per_device, seq_len, step_time): ff = total_model_param * 6 attn = seq_len * hidden_size * num_hidden_layers * 60 - flops = ( - batch_size_per_device - * seq_len - * (ff + attn) - / step_time - ) + flops = batch_size_per_device * seq_len * (ff + attn) / step_time return human_readable_flops(flops) def generate_task_id(data_paths): - data_prefixes = list(data_paths[1:-1].split(',')) + data_prefixes = list(data_paths[1:-1].split(",")) print("data paths: ") print(data_prefixes) for i, prefix in enumerate(data_prefixes): - task_name = prefix.split('/')[-1] + task_name = prefix.split("/")[-1] TASK2ID[task_name] = i ID2TASK[i] = task_name diff --git a/mftcoder_accelerate/src/utils/loss_utils.py b/mftcoder_accelerate/src/utils/loss_utils.py new file mode 100644 index 0000000..5ca7c73 --- /dev/null +++ b/mftcoder_accelerate/src/utils/loss_utils.py @@ -0,0 +1,365 @@ +import sys +import torch +from utils.common_utils import print_rank_0, TASK2ID, ID2TASK +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +from dataclasses import dataclass +import numpy as np +from typing import List, Optional, Tuple, Union + + +def get_task_mask(task_id): + task_num = len(TASK2ID) + task_mask = torch.zeros(task_id.shape[0], task_num) + task_mask[torch.arange(task_id.size(0)).unsqueeze(1), task_id] = 1 + + return task_mask + + +def get_task_loss(task_losses, task_id): # TODO + # fix task order + task_loss_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device) + # count task samples + task_num_per_batch = torch.zeros(len(ID2TASK)).to(device=task_id.device) + for i in range(len(task_id)): + task_num_per_batch[task_id[i][0]] += 1 + task_loss_per_batch[task_id[i][0]] = task_losses[task_id[i][0]] + + return task_loss_per_batch, task_num_per_batch + + +def loss_func_mft(outputs, labels, task_mask, task_id, weighted_loss_mode, loss_mask=None, task_weights=None): + """ + loss function for MFT loss + :param outputs: + :param labels: + :param task_mask: + :param task_id: + :param weighted_loss_mode: + :param loss_mask: + :return: + """ + # task_id shape: [[1], [2], [4], [3], ..., [1]] + weighted = weighted_loss_mode + lm_logits = outputs["logits"] + labels = labels.to(device=lm_logits.device) + task_mask = task_mask.to(device=lm_logits.device) + task_id = task_id.to(device=lm_logits.device) + shift_logits = lm_logits.contiguous() + labels = labels.contiguous() + if task_weights is None: + task_weights = torch.ones(len(ID2TASK)).to(device=lm_logits.device) / len(ID2TASK) + + bsz, seq_len = labels.shape + # loss_mask = None + if loss_mask is None: + ineffective_tokens_per_sample = (labels == -100).sum(dim=1) + effective_tokens_per_sample = -(ineffective_tokens_per_sample - seq_len) + effective_tokens = bsz * seq_len - ineffective_tokens_per_sample.sum() + loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100) + else: + loss_mask = loss_mask.to(device=lm_logits.device) + loss_fct = CrossEntropyLoss(reduction="none") + losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) # [B * L, 1] + losses = losses.contiguous().view(bsz, -1) + token_losses = ( + losses.clone().detach().float() if loss_mask is None else losses.clone().detach().float() * loss_mask + ) # [B, L] + task_mask_trans = torch.transpose(task_mask, 0, 1) + unique_id = torch.unique(task_id) + if weighted_loss_mode == "case3" or weighted_loss_mode == "case4" or weighted_loss_mode == "coba": + loss = 0.0 + weights_sum = 0.0 + for i, w in enumerate(unique_id): + row_idx = torch.squeeze(task_id) == w.item() + task_weight = float(task_weights[w.item()]) + weights_sum += task_weight + if weighted_loss_mode == "case3" or weighted_loss_mode == "coba": + if loss_mask is None: + loss += ( + torch.sum(losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) * task_weight + ) + else: + loss += torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) * task_weight + elif weighted_loss_mode == "case4": + if loss_mask is None: + loss += ( + torch.mean(torch.sum(losses, dim=1)[row_idx] / effective_tokens_per_sample[row_idx]) + * task_weight + ) + else: + loss += ( + torch.mean(torch.sum(losses * loss_mask, dim=1)[row_idx] / torch.sum(loss_mask, dim=1)[row_idx]) + * task_weight + ) + + # loss /= len(unique_id) + loss /= weights_sum + + elif weighted_loss_mode == "case2": + if loss_mask is None: + loss = torch.mean(torch.sum(losses, dim=1) / effective_tokens_per_sample) + else: + loss = torch.mean(torch.sum(losses * loss_mask, dim=1) / torch.sum(loss_mask, dim=1)) + elif weighted_loss_mode == "case1": + # flatten losses & loss_mask tensor + if loss_mask is None: + # losses = losses.view(-1) + loss = torch.sum(losses.view(-1)) / effective_tokens + else: + # loss_mask = loss_mask.view(-1) + # losses = losses.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.view(-1).sum() + + # fix task order + task_loss = torch.zeros(len(ID2TASK)).to(device=task_id.device) + task_num = torch.zeros(len(ID2TASK)).to(device=task_id.device) + for i, w in enumerate(unique_id): + row_idx = torch.squeeze(task_id) == w.item() + if loss_mask is None: + task_loss[w] = torch.sum(token_losses[row_idx, :]) / torch.sum(effective_tokens_per_sample[row_idx]) + task_num[w] = len(effective_tokens_per_sample[row_idx]) + else: + task_loss[w] = torch.sum((losses * loss_mask)[row_idx, :]) / torch.sum(loss_mask[row_idx, :]) + + return loss, task_loss, task_num + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class MFTLossStatus: + def __init__(self): + super(MFTLossStatus, self).__init__() + + +class CoBaStatus(MFTLossStatus): + def __init__( + self, + coba_warmup_steps=100, + coba_history_length=200, + coba_tau=5, + coba_update_interval=1, + coba_sample_valid_num=1, + valid_dataloader=None, + ): + + super(CoBaStatus, self).__init__() + self.coba_warmup_steps = coba_warmup_steps + self.coba_history_length = coba_history_length + self.coba_tau = coba_tau + self.coba_update_interval = coba_update_interval + self.coba_sample_valid_num = coba_sample_valid_num + self.valid_dataloader = valid_dataloader + self.valid_dataloader_length = len(valid_dataloader) + self.valid_iterator = iter(valid_dataloader) + self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK)) + self.history_task_valid_loss = None + self.per_task_slope_list = None + self.total_slope_list = None + self.minimum_weight = 1 / (len(ID2TASK) * 10) + self.valid_task_loss_begining = torch.ones(len(ID2TASK), dtype=torch.float64) + self.log_per_task_weight = torch.zeros(len(ID2TASK)) + + def coba_evaluate(self, model, v_batch, per_task_weight=None, coba_status=None): + model.eval() + with torch.no_grad(): + valid_outputs = model( + input_ids=v_batch["input_ids"], + attention_mask=v_batch["attention_mask"], + position_ids=v_batch["position_ids"], + ) + + _, valid_task_loss, valid_task_num = loss_func_mft( + outputs=valid_outputs, + labels=v_batch["labels"], + task_mask=v_batch["task_mask"], + task_id=v_batch["task_id"], + weighted_loss_mode="coba", + loss_mask=v_batch["loss_mask"], + task_weights=None, + ) + + task_exist = (valid_task_loss != 0.0).float() + torch.distributed.all_reduce(valid_task_loss, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(task_exist, op=torch.distributed.ReduceOp.SUM) + valid_task_loss /= task_exist.clamp_(1.0) + valid_task_loss /= self.valid_task_loss_begining + model.train() + return valid_task_loss + + def compute_per_task_weight(self, completed_steps=None): + task_num = len(ID2TASK) + task_slope_fitting = torch.ones(task_num, dtype=torch.float64) + start_step = max(0, completed_steps // self.coba_update_interval - self.coba_history_length) + history_steps = torch.arange(start_step, completed_steps, 1) + for i in range(task_num): + per_task_history_valid_loss = self.history_task_valid_loss[i][-len(history_steps):] + task_slope_fitting[i] = self.fit_window_slope( + history_steps, per_task_history_valid_loss, type="slope" + ) + history_total_valid_loss, index = torch.max(self.history_task_valid_loss[:, -len(history_steps):], dim=0) + total_slope_fitting = self.fit_window_slope( + history_steps, history_total_valid_loss, type="slope" + ) + if completed_steps == self.coba_warmup_steps: + self.per_task_slope_list = task_slope_fitting.unsqueeze(1) + self.total_slope_list = total_slope_fitting.unsqueeze(0) + else: + self.per_task_slope_list = torch.cat((self.per_task_slope_list, task_slope_fitting.unsqueeze(1)), dim=-1) + self.total_slope_list = torch.cat((self.total_slope_list, total_slope_fitting.unsqueeze(0)), dim=0) + + # Relative Convergence Score + normalize_task_slope = task_num * task_slope_fitting / task_slope_fitting.abs().sum() + rcs = F.softmax(normalize_task_slope, dim=-1) + + # Absolute Convergence Score + history_per_task_slope_list = self.per_task_slope_list[:, start_step:] + reverse_normailize_iter_slope = -len(history_per_task_slope_list[0]) * history_per_task_slope_list \ + / history_per_task_slope_list.abs().sum(dim=-1, keepdim=True) + + flatten_rn_iter_slope = reverse_normailize_iter_slope.T.reshape(-1) + current_step_rn_slope = flatten_rn_iter_slope[-task_num:] + acs = F.softmax(current_step_rn_slope, dim=-1) + + # Divergence Factor + normalize_total_iter_slope = - len(self.total_slope_list) * self.total_slope_list \ + / self.total_slope_list.abs().sum() + divergence_factor = F.softmax(normalize_total_iter_slope * self.coba_tau, dim=-1)[-1] \ + * len(self.total_slope_list) + + weight_logits = divergence_factor * rcs + (1 - divergence_factor) * acs + per_task_weight = F.softmax(weight_logits * task_num, dim=-1) + + if len((per_task_weight < self.minimum_weight).nonzero().squeeze(0)) > 0: + per_task_weight = per_task_weight * (1 - self.minimum_weight * task_num) + per_task_weight += self.minimum_weight + + return per_task_weight + + def fit_window_slope(self, x, y, type="slope"): + + y = y[y != 0] + x = x[:len(y)] + + nonzero_index = torch.squeeze(torch.nonzero(y), dim=1) + y = torch.index_select(y, 0, nonzero_index) + x = torch.index_select(x, 0, nonzero_index) + + ws = torch.flip(1 ** torch.arange(len(y)), dims=[0]) + ws = ws.double() + + if len(y) >= 2: + if type == "slope": + X = torch.stack((x, torch.ones_like(x, dtype=torch.float64))).T + X = X.double() + else: + X = torch.stack((x ** 2, x, torch.ones_like(x, dtype=torch.float64))).T + + # implementation for numpy + # X_np = X.T @ (ws[:, None] * X) + # Y_np = X.T @ (ws * y) + # w = torch.from_numpy(np.linalg.solve(X_np.numpy(), Y_np.numpy())) + + # implementation for torch + w = torch.linalg.solve(X.T @ (ws[:, None] * X), X.T @ (ws * y)) + + result = w[0] + else: + result = 0.0 + + return result + + def sample_valid_batch(self, model, completed_steps): + self.valid_task_loss_accumulated = torch.zeros(len(ID2TASK), dtype=torch.float64) + for i in range(self.coba_sample_valid_num): + if ( + self.coba_sample_valid_num * completed_steps // self.coba_update_interval + i + ) % self.valid_dataloader_length == 0: + self.valid_iterator = iter(self.valid_dataloader) + v_batch = next(self.valid_iterator) + else: + v_batch = next(self.valid_iterator) + valid_task_loss = self.coba_evaluate(model, v_batch) + self.valid_task_loss_accumulated += valid_task_loss.detach().cpu().double() + + self.valid_task_loss_accumulated /= self.coba_sample_valid_num + if self.history_task_valid_loss is None and completed_steps >= 1: + self.history_task_valid_loss = self.valid_task_loss_accumulated.unsqueeze(1) + elif self.history_task_valid_loss is not None: + self.history_task_valid_loss = torch.cat( + (self.history_task_valid_loss, self.valid_task_loss_accumulated.unsqueeze(1)), dim=-1 + ) diff --git a/mftcoder_accelerate/src/utils/model_mapping.py b/mftcoder_accelerate/src/utils/model_mapping.py new file mode 100644 index 0000000..8592e86 --- /dev/null +++ b/mftcoder_accelerate/src/utils/model_mapping.py @@ -0,0 +1,67 @@ +""" + @author qumu + transformers==4.40 is stable now +""" + +# Models that Transformers support Code and FA2 when flash_attn>=2.1.0 +from transformers import ( + GPTNeoXForCausalLM, + GPTBigCodeForCausalLM, + LlamaForCausalLM, + MistralForCausalLM, + MixtralForCausalLM, + PhiForCausalLM, + GemmaForCausalLM, + Qwen2ForCausalLM, + Qwen2MoeForCausalLM, + Starcoder2ForCausalLM, +) + +# model in local model dir and support transformers FA2 +from model.deepseek_v2.modeling_deepseek import DeepseekV2ForCausalLM + +# model in local model and self-contained +from model.aquila2.modeling_aquila import AquilaForCausalLM +from model.baichuan2.modeling_baichuan import BaichuanForCausalLM +from model.qwen.modeling_qwen import QWenLMHeadModel +from model.chatglm2.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration2 +from model.chatglm3.modeling_chatglm import ChatGLMForConditionalGeneration as ChatGLMForConditionalGeneration3 + +# from model.phi.modeling_mixformer_sequential import MixFormerSequentialForCausalLM + +MODEL_TYPES = { + "aquila2": AquilaForCausalLM, + "baichuan": BaichuanForCausalLM, + "chatglm2": ChatGLMForConditionalGeneration2, + "chatglm3": ChatGLMForConditionalGeneration3, + "code_llama": LlamaForCausalLM, + "deepseek": LlamaForCausalLM, + "gpt_neox": GPTNeoXForCausalLM, + "llama": LlamaForCausalLM, + "mistral": MistralForCausalLM, + "mixtral": MixtralForCausalLM, + "phi": PhiForCausalLM, + "qwen": QWenLMHeadModel, + "starcoder": GPTBigCodeForCausalLM, + "qwen2": Qwen2ForCausalLM, + "gemma": GemmaForCausalLM, + "qwen2_moe": Qwen2MoeForCausalLM, + "starcoder2": Starcoder2ForCausalLM, + "deepseek_v2": DeepseekV2ForCausalLM, +} + +SUPPORT_IN_TRANSFORMERS = [ + "code_llama", + "llama", + "deepseek", + "mistral", + "mixtral", + "gpt_neox", + "phi", + "starcoder", + "qwen2", + "qwen2_moe", + "gemma", + "starcoder2", + "deepseek_v2", +] diff --git a/mftcoder_accelerate/src/xxpo/custom_callbacks.py b/mftcoder_accelerate/src/xxpo/custom_callbacks.py new file mode 100644 index 0000000..f38fa70 --- /dev/null +++ b/mftcoder_accelerate/src/xxpo/custom_callbacks.py @@ -0,0 +1,99 @@ +""" +Customized Callbacks to use with the Trainer class and customize the training loop. +""" + +import copy +import dataclasses +import json +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import numpy as np +from tqdm.auto import tqdm + +from transformers.trainer_utils import IntervalStrategy, has_length +from transformers.training_args import TrainingArguments +from transformers.utils import logging +from transformers import TrainerCallback + +logger = logging.get_logger(__name__) + + +class CustomProgressCallback(TrainerCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation. + """ + + def __init__(self): + self.training_bar = None + self.prediction_bar = None + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True) + self.current_step = 0 + + def on_step_end(self, args, state, control, **kwargs): + if state.is_world_process_zero and state.global_step % args.logging_steps == 0: + self.training_bar.update(args.logging_steps) + self.current_step = state.global_step + # pass + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + # if state.is_world_process_zero and has_length(eval_dataloader): + # if self.prediction_bar is None: + # self.prediction_bar = tqdm( + # total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True + # ) + # self.prediction_bar.update(1) + pass + + def on_evaluate(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_predict(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_world_process_zero and self.training_bar is not None: + # avoid modifying the logs object as it is shared between callbacks + logs = copy.deepcopy(logs) + # _ = logs.pop("total_flos", None) + # round numbers so that it looks better in console + if "epoch" in logs: + logs["epoch"] = round(logs["epoch"], 2) + # self.training_bar.write(str(logs)) + logger.info(logs) + + def on_train_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.close() + self.training_bar = None + + +class PrinterCallback(TrainerCallback): + """ + A bare [`TrainerCallback`] that just prints the logs. + """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + + +class LogCallback(TrainerCallback): + """ + A bare [`TrainerCallback`] that just prints the logs. + """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + logger.info(logs) \ No newline at end of file diff --git a/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py b/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py new file mode 100644 index 0000000..4c93520 --- /dev/null +++ b/mftcoder_accelerate/src/xxpo/xxpo_accelerate.py @@ -0,0 +1,484 @@ +""" +# @author qumu +# @date 2023/12/11 +# @module mft_accelerate.py + +Accelerate + DeepSpeed/FSDP + QLoRA/LoRA/Full + DPO/RPO/ORPO + +Entry +""" + +import os +import sys +import argparse +import math +import logging +import json +import time +from datetime import timedelta +from tqdm.auto import tqdm +from dataclasses import dataclass +from typing import Dict, Optional, Union, List + +import datasets +from datasets import Dataset, load_dataset, concatenate_datasets + +import torch +from torch.utils.data import DataLoader +from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig + +import transformers +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + get_linear_schedule_with_warmup, + set_seed, + BitsAndBytesConfig, + get_scheduler, +) +from peft import ( + LoraConfig, + TaskType, + get_peft_model, + prepare_model_for_kbit_training, + PeftModel, +) +from accelerate import Accelerator, DistributedType, FullyShardedDataParallelPlugin, DataLoaderConfiguration +from accelerate.logging import get_logger +from accelerate.utils import InitProcessGroupKwargs + +# insert src as import path +current_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(os.path.dirname(current_path)) +sys.path.insert(0, parent_dir) + +from tokenizer import build_tokenizer + +from utils.common_utils import print_rank_0, generate_task_id, TASK2ID, ID2TASK +from utils.model_mapping import MODEL_TYPES, SUPPORT_IN_TRANSFORMERS + +logger = get_logger(__name__) + + +from trl import ( + DPOConfig, + DPOTrainer, + ORPOConfig, + ORPOTrainer, + ModelConfig, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from transformers.trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) + +from xxpo.xxpo_arguments import XXPOTrainArgs +from xxpo.custom_callbacks import CustomProgressCallback +from xxpo.custom_callbacks import LogCallback + + +def pprint_args(args, accelerator): + # 计算所有键的最大字符串长度 + max_key_length = max(len(str(key)) for key in vars(args).keys()) + + message = "" + message += "====" * 60 + "\n" + message += "\n".join([f"{k:<{max_key_length}} : {v}" for k, v in vars(args).items()]) + "\n" + message += "====" * 60 + "\n" + accelerator.print(message) + accelerator.print("GPU: {}".format(torch.cuda.current_device())) + + +def prepare_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_config", type=str, default=None) + + parser.add_argument("--data_paths", type=str, default=None) + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--tb_dir", type=str, default=None) + parser.add_argument("--pretrained_model_path", type=str, default=None) + parser.add_argument("--micro_batch_size", type=int, default=None) + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--distributed_type", type=str, default="deepspeed") + + parsed = parser.parse_args() + # get json configs + with open(parsed.train_config, "r") as f: + train_config = json.load(f) + + # parse args from cofig.json + args = XXPOTrainArgs(**train_config) + + # override args by cli arguments + if parsed.data_paths: + args.data_paths = parsed.data_paths + if parsed.output_dir: + args.output_dir = parsed.output_dir + if parsed.tb_dir: + args.tb_dir = parsed.tb_dir + if parsed.pretrained_model_path: + args.pretrained_model_path = parsed.pretrained_model_path + args.vocab_file = parsed.pretrained_model_path + if parsed.micro_batch_size: + args.per_device_train_batch_size = parsed.micro_batch_size + args.per_device_eval_batch_size = parsed.micro_batch_size + if parsed.model_type: + args.model_type = parsed.model_type + + args.distributed_type = parsed.distributed_type + + # refactor args + + if args.peft_type == "qlora": + print_rank_0(f"[INFO] args.peft_type is set 'qlora', setting quantization to '4bit'") + args.quantization = "4bit" + else: + args.quantization = None + + args.vocab_file = args.pretrained_model_path + + return args + + +def get_model(args, accelerator): + ModelClass = MODEL_TYPES[args.model_type] + if args.model_type in SUPPORT_IN_TRANSFORMERS: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported by Transformers") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + attn_implementation=args.attn_implementation, + torch_dtype=torch.bfloat16, + # device_map=get_kbit_device_map() if args.quantization == "4bit" else None, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), + ) + else: + accelerator.print(f"[INFO] Model Type {args.model_type} is supported in our local model dir for remote code") + model = ModelClass.from_pretrained( + args.pretrained_model_path, + torch_dtype=torch.bfloat16, + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=(args.quantization == "4bit"), + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_quant_storage=torch.bfloat16, + ) + if args.quantization == "4bit" + else None + ), + ) + + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + + return model + + +def chatml_to_dpo_format( + data_file: str, + tokenizer, + sanity_check: bool = False, + cache_dir: Optional[str] = None, + num_proc=16, +) -> Dataset: + """Load the standard-paired dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'chosen': List[dict], chatml + 'rejected': List[dict], chatml + } + """ + + dataset = load_dataset( + "json", + split="train", + data_files=data_file, + cache_dir=cache_dir, + verification_mode="no_checks", + ) + original_columns = dataset.column_names + + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 100))) + + def process(samples): + samples["prompt"] = [ + tokenizer.apply_chat_template(chosen[:-1], tokenize=False, add_generation_prompt=True) + for chosen in samples["chosen"] + ] + samples["chosen"] = [chosen[-1]["content"] + tokenizer.eos_token for chosen in samples["chosen"]] + samples["rejected"] = [rejected[-1]["content"] + tokenizer.eos_token for rejected in samples["rejected"]] + return samples + + return dataset.map( + process, + batched=True, + num_proc=num_proc, + # remove_columns=original_columns, + ) + + +def main(): + t0 = time.time() + # os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["HF_HUB_OFFLINE"] = "false" + # get input args, set TASK2ID, ID2TASK, refactor args + args = prepare_args() + + # fix randomness + if args.seed is not None: + set_seed(args.seed) + + # define accelerator + init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.init_timeout_seconds)) + + if args.distributed_type and args.distributed_type.lower() == "fsdp": + fsdp_plugin = FullyShardedDataParallelPlugin( + # state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + # optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + limit_all_gathers=True, + sync_module_states=True, + use_orig_params=True, + cpu_offload=False, + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + fsdp_plugin=fsdp_plugin, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + else: + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True), + kwargs_handlers=[init_process_kwargs], + ) + + # print key infos + accelerator.print("In dpo_accelerate.py, sys path:", sys.path) + accelerator.print(f"transformers.__version__: {transformers.__version__}") + + # get world_size + args.world_size = accelerator.num_processes + + # backup args + pprint_args(args, accelerator) + if accelerator.is_main_process: + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + with open(os.path.join(args.output_dir, "args.json"), "w") as f: + json.dump(args.dict(), f, indent=2) + + # deal with autoresume, args.resume_from_checkpoint prior to auto_resume from latest + + # logger + logging.basicConfig( + format="[%(asctime)s][%(levelname)s][%(name)s]%(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # get global_rank and local rank for current process + global_rank = accelerator.process_index + local_rank = accelerator.local_process_index + print(f"world_size: {args.world_size}, global_rank: {global_rank}, local_rank: {local_rank}") + + # 1. dataset + + # build tokenizer + tokenizer = build_tokenizer(args) + # tokenizer.chat_template = MFTCoder_template + + # Load the dpo dataset + all_datasets = [] + # print(args.data_paths, type(args.data_paths)) + if isinstance(args.data_paths, str): + args.data_paths = list(args.data_paths[1:-1].split(",")) + # print(f"DATA_PATHS: {args.data_paths}") + for data_file in args.data_paths: + ds = chatml_to_dpo_format(data_file=data_file, tokenizer=tokenizer, sanity_check=args.sanity_check) + all_datasets.append(ds) + + all_dataset = concatenate_datasets(all_datasets) + # all_dataset = all_dataset.filter( + # lambda x: len(x["prompt"]) + len(x["chosen"]) <= args.max_length + # and len(x["prompt"]) + len(x["rejected"]) <= args.max_length + # ) + accelerator.print(f"Length of all_dataset: {len(all_dataset)}") + + # split train/eval dataset + splits = [float(s) for s in args.data_split.split(",")][:2] + print(f"data splits: {splits}") + + all_dataset = all_dataset.train_test_split(test_size=splits[1] / sum(splits), shuffle=True, seed=args.seed) + all_dataset.flatten_indices() + + train_dataset, eval_dataset = all_dataset["train"], all_dataset["test"] + accelerator.print(f"Length of train_dataset: {len(train_dataset)}\nLength of eval_dataset: {len(eval_dataset)}") + print(eval_dataset[0]) + t1 = time.time() + logger.info(f"dataset loading time: {t1 - t0:.4f}") + + # cuda memory + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) + max_memory = f"{free_in_GB - 2}GB" + n_gpus = torch.cuda.device_count() + max_memory = {i: max_memory for i in range(n_gpus)} + accelerator.print("max memory: ", max_memory, n_gpus) + + # target_modules, default all-linear for all linear layers + if args.target_modules: + target_modules = args.target_modules + else: + target_modules = "all-linear" + + # peft config + if args.peft_type: + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=target_modules, + bias="lora_only", + ) + else: + peft_config = None + + # creating base model + model = get_model(args, accelerator) + if args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + accelerator.print("Model load_in_4bit: ", args.quantization == "4bit") + + model.config.use_cache = False # silence the warnings. Please re-enable for inference! + if hasattr(model.config, "use_logn_attn"): + model.config.use_logn_attn = False # special for qwen model + # load balance for moe training + if hasattr(model.config, "output_router_logits"): + model.config.output_router_logits = True + model_config = model.config + accelerator.print(model.config) + + t2 = time.time() + if accelerator.is_main_process: + logging.info(f"model loading time: {t2 - t1:.4f}") + + # 4. initialize training arguments: + if args.xxpo == "dpo": + ConfigClass = DPOConfig + elif args.xxpo == "orpo": + ConfigClass = ORPOConfig + logging.info(f"{args.xxpo} Used.") + + training_args = ConfigClass( + beta=args.beta, + rpo_alpha=args.rpo_alpha, + per_device_train_batch_size=args.per_device_train_batch_size, + per_device_eval_batch_size=args.per_device_eval_batch_size, + max_steps=args.max_steps, + num_train_epochs=args.num_train_epochs, + logging_steps=args.logging_steps, + save_strategy="steps", + eval_strategy="steps", + save_steps=args.save_steps, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_checkpointing=args.gradient_checkpointing, + learning_rate=args.learning_rate, + eval_steps=args.eval_steps, + output_dir=args.output_dir, + report_to="tensorboard", + logging_dir=args.tb_dir, + max_prompt_length=args.max_prompt_length, + max_length=args.max_length, + lr_scheduler_type=args.lr_scheduler_type, + warmup_steps=args.warmup_steps, + optim=args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="", + gradient_checkpointing_kwargs=dict(use_reentrant=args.gradient_checkpointing_use_reentrant), + seed=args.seed, + dataset_num_proc=args.dataset_num_proc, + disable_tqdm=args.disable_tqdm, + save_only_model=args.save_only_model, + save_total_limit=args.saving_limit, + ) + + # 5. initialize the DPO trainer + if not args.peft_type and args.xxpo == "dpo": + model_ref = get_model(args, accelerator) + model_ref.config.use_cache = False # silence the warnings. Please re-enable for inference! + else: + model_ref = None + + if args.xxpo == "dpo": + xxpo_trainer = DPOTrainer( + model, + ref_model=model_ref, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + ) + elif args.xxpo == "orpo": + xxpo_trainer = ORPOTrainer( + model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + ) + + # callbacks + if args.disable_tqdm: + xxpo_trainer.remove_callback(PrinterCallback) + xxpo_trainer.add_callback(LogCallback) + else: + xxpo_trainer.remove_callback(ProgressCallback) + xxpo_trainer.add_callback(CustomProgressCallback) + + # 6. train + xxpo_trainer.train() + + # 7. save + output_dir = os.path.join(args.output_dir, "epoch_final") + xxpo_trainer.save_model(output_dir) + # dpo_trainer.model.save_pretrained(output_dir) + logger.info(f"Training Finished!") + + +if __name__ == "__main__": + main() diff --git a/mftcoder_accelerate/src/xxpo/xxpo_arguments.py b/mftcoder_accelerate/src/xxpo/xxpo_arguments.py new file mode 100644 index 0000000..2b4c876 --- /dev/null +++ b/mftcoder_accelerate/src/xxpo/xxpo_arguments.py @@ -0,0 +1,170 @@ +""" +# @author Chaoyu Chen +# @date 2023/10/19 + +training arguments +""" + +from dataclasses import dataclass, asdict +from typing import List, Union + + +@dataclass +class XXPOTrainArgs: + # train data paths on shared FS + data_paths: Union[str, List[str]] + + # output dir for saving adaptors in peft or full ckpts in full-parameter training + output_dir: str + + # tensorboard dir for saving tensorboard logs + tb_dir: str + + # pretrained_model_path, on which is the model you want to train + pretrained_model_path: str + + # model type of pretrained_model_path, support llama|qwen|starcoder|baichuan|chatglm2 + model_type: str + + # train/valid/test split + data_split: str = "98,2,0" + + # lora or qlora or None(for full-parameter training) + peft_type: Union[None, str] = "qlora" + + # if qlora, 4bit will be set, else None + quantization: Union[None, str] = "4bit" + + # lora rank, the bigger, the more trainalbe parameters + lora_rank: int = 96 + + # lora alpha + lora_alpha: int = 32 + + # lora dropout + lora_dropout: float = 0.05 + + # lora targeting modules + target_modules: Union[None, str, List[str]] = None + + # dpo or orpo + xxpo: str = "dpo" + + # dpo/orpo beta + beta: float = 0.1 + + rpo_alpha: Union[None, float] = None + + # mircro train batch size + per_device_train_batch_size: int = 8 + + # micro eval batch size, always same as micro train batch size + per_device_eval_batch_size: int = 8 + + # HF AutoTokenizer is supported, maybe more types + tokenizer_type: str = "AutoTokenizer" + + # initial lr + learning_rate: float = 5e-5 + + # minimum lr + min_lr: float = 5e-6 + + # weight decay + weight_decay: float = 0.01 + + # gradient_accumulation_steps + gradient_accumulation_steps: int = 1 + + # lr_scheduler_type + lr_scheduler_type: str = "cosine" + + # optimizer_type + optimizer_type: str = "adamw_torch" + # optimizer_type: str = "paged_adamw_32bit" + + # gradient_checkpointing + gradient_checkpointing: bool = True + gradient_checkpointing_use_reentrant: bool = False + + # num of warmup_steps + warmup_steps: Union[int, float] = 0.05 + + # num_train_epochs + num_train_epochs: int = 4 + + # seed for reproducing + seed: int = 1234 + + # seq_length, context length + seq_length: int = 4096 + + save_only_model: bool = True + + # path of adaptor which is resumed from, None for not resuming training + resume_from_checkpoint: Union[None, str] = None + + # auto resume from latest ckpt if job restarted + auto_resume: bool = True + + # num of steps for logging training loss + logging_steps: int = 10 + + # num of steps for saving ckpt + save_steps: int = 100 + + # num of steps for evaluation(eval_loss), better same as checkpointing steps + eval_steps: int = 100 + + # max train steps, if None, depends on num_train_epochs + max_steps: int = -1 + + # if checkpointing every epoch, maybe True in sst + epoch_checkpointing: bool = False + + # shuffle before train/valid split + shuffle_before_split: bool = True + + # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point + early_stopping: bool = True + early_stopping_stall_num: int = 5 + + # limit num for saving ckpts, None for no limits. Used for full-parameter training to avoid exceeding disk quota. + saving_limit: Union[None, int] = None + + # ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} + attn_implementation: str = "flash_attention_2" + + # tokenizer chat template, if None, will use MFTCoder template + chat_template: Union[None, str] = None + + distributed_type: Union[None, str] = None + + init_timeout_seconds: Union[None, int] = 3600 + + make_vocab_size_divisible_by: int = 32 + model_parallel_size: int = 1 + use_slow_tokenizer: bool = False + world_size: int = 8 + + # max prompt string length and whole str length + max_prompt_length: Union[None, int] = 2048 + max_length: Union[None, int] = 4096 + + # num of process processing dataset + dataset_num_proc: int = 1 + + # model_dtype[float16, bfloat16, float] for loading + dtype: str = "bfloat16" + + # instrumentation + disable_tqdm: bool = False + sanity_check: bool = False + + # debug argument for distributed training + # "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + # "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + ignore_bias_buffers: bool = True + + def dict(self): + return {k: str(v) for k, v in asdict(self).items()} diff --git a/requirements.txt b/requirements.txt index ee5577f..189518b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ numpy==1.23.5 -pandas==1.5.3 -torch==2.0.1 +pandas==2.2.1 +torch==2.1.0 tensorboard==2.11.0 -deepspeed==0.9.3 -transformers==4.36.0 -accelerate==0.23.0 -peft==0.7.0 -BitsAndBytes==0.40.2 -xformers==0.0.21 +deepspeed==0.14.0 +transformers==4.44.2 +accelerate==0.31.0 +peft==0.10.0 +BitsAndBytes==0.43.0 +xformers==0.0.22.post7 +datasets +ftfy packaging einops sentencepiece