A 7B parameter GPT-like model fine-tuned with DiscoPOP on publicly available, synthetic datasets.
DiscoPOP-zephyr-7b-gemma
This model is a fine-tuned version of HuggingFaceH4/zephyr-7b-gemma-sft-v0.1 on the argilla/dpo-mix-7k dataset.
This model comes from the paper "Discovering Preference Optimization Algorithms with and for Large Language Models".
Read the accompanying blog post here!
The codebase that produced it is here: https://github.com/SakanaAI/DiscoPOP
Model description
This model is trained identically to HuggingFaceH4/zephyr-7b-gemma-v0.1, except that it uses DiscoPOP instead of Direct Preference Optimization (DPO).
DiscoPOP is our discovered preference optimization algorithm, defined as follows:
```python
import torch
import torch.nn.functional as F

# Loss method on the preference-optimization trainer; `self.beta` is the
# trainer's beta hyperparameter.
def log_ratio_modulated_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> torch.FloatTensor:
    # Log-ratios of chosen vs. rejected completions under the policy and reference models
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    # Modulate the mixing coefficient based on the log ratio magnitudes
    log_ratio_modulation = torch.sigmoid(logits)
    logistic_component = -F.logsigmoid(self.beta * logits)
    exp_component = torch.exp(-self.beta * logits)
    # Blend between logistic and exponential component based on log ratio modulation
    losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation
    return losses
```
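Intuitively, when the policy already separates chosen from rejected (large positive logits), the sigmoid gate shifts weight toward the exponential term; otherwise the loss behaves like the logistic DPO loss. Below is a minimal standalone sketch of the same math for experimentation; the `beta=0.05` default is an illustrative assumption, not necessarily the value used in training.

```python
import torch
import torch.nn.functional as F

def discopop_loss(
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    beta: float = 0.05,  # illustrative assumption; the trainer config holds the real value
) -> torch.FloatTensor:
    # Same math as log_ratio_modulated_loss above, detached from the trainer class.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    gate = torch.sigmoid(logits)             # mixing coefficient from log-ratio magnitude
    logistic = -F.logsigmoid(beta * logits)  # DPO-style component
    exponential = torch.exp(-beta * logits)
    return logistic * (1 - gate) + exponential * gate

# Toy check with fabricated per-sequence log-probabilities:
torch.manual_seed(0)
chosen, rejected = torch.randn(4) - 50, torch.randn(4) - 52
ref_chosen, ref_rejected = torch.randn(4) - 51, torch.randn(4) - 51
print(discopop_loss(chosen, rejected, ref_chosen, ref_rejected))
```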
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 5e-07
- train_batch_size: 2
- eval_batch_size: 4
- seed: 42
- distributed_type: multi-GPU
- num_devices: 8
- gradient_accumulation_steps: 8
- total_train_batch_size: 128
- total_eval_batch_size: 32
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 2
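The total batch sizes above follow directly from the per-device settings; a quick check:

```python
train_batch_size = 2   # per device
eval_batch_size = 4    # per device
num_devices = 8
gradient_accumulation_steps = 8

total_train = train_batch_size * num_devices * gradient_accumulation_steps
total_eval = eval_batch_size * num_devices  # no gradient accumulation at eval time
print(total_train, total_eval)  # 128 32
```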
Framework versions
- Transformers 4.40.1
- Pytorch 2.1.2+cu121
- Datasets 2.19.0
- Tokenizers 0.19.1
Zephyr 7B Gemma Model Card
Zephyr is a series of language models that are trained to act as helpful assistants. Zephyr 7B Gemma is the third model in the series, and is a fine-tuned version of google/gemma-7b that was trained on a mix of publicly available, synthetic datasets using Direct Preference Optimization (DPO). You can reproduce the training of this model with the recipe provided in the Alignment Handbook.
Model description
- Model type: A 7B parameter GPT-like model fine-tuned on a mix of publicly available, synthetic datasets.
- Language(s) (NLP): Primarily English
- License: Gemma Terms of Use
- Finetuned from model: google/gemma-7b
Model sources
- Repository: https://github.com/huggingface/alignment-handbook
- Demo: https://hugging-face.cn/spaces/HuggingFaceH4/zephyr-7b-gemma-chat
Performance

| Model | MT Bench ⬇️ | IFEval |
|---|---|---|
| zephyr-7b-gemma-v0.1 | 7.81 | 28.76 |
| zephyr-7b-beta | 7.34 | 43.81 |
| google/gemma-7b-it | 6.38 | 38.01 |

| Model | AGIEval | GPT4All | TruthfulQA | BigBench | Average ⬇️ |
|---|---|---|---|---|---|
| zephyr-7b-beta | 37.52 | 71.77 | 55.26 | 39.77 | 51.08 |
| zephyr-7b-gemma-v0.1 | 34.22 | 66.37 | 52.19 | 37.10 | 47.47 |
| mlabonne/Gemmalpaca-7B | 21.6 | 40.87 | 44.85 | 30.49 | 34.45 |
| google/gemma-7b-it | 21.33 | 40.84 | 41.70 | 30.25 | 33.53 |
Details for AGIEval, GPT4All, TruthfulQA, BigBench

### AGIEval

| Task | Version | Metric | Value | | Stderr |
|------|------:|--------|----:|---|-----:|
| agieval_aqua_rat | 0 | acc | 21.65 | ± | 2.59 |
| | | acc_norm | 25.20 | ± | 2.73 |
| agieval_logiqa_en | 0 | acc | 34.72 | ± | 1.87 |
| | | acc_norm | 35.94 | ± | 1.88 |
| agieval_lsat_ar | 0 | acc | 19.57 | ± | 2.62 |
| | | acc_norm | 21.74 | ± | 2.73 |
| agieval_lsat_lr | 0 | acc | 30.59 | ± | 2.04 |
| | | acc_norm | 32.55 | ± | 2.08 |
| agieval_lsat_rc | 0 | acc | 49.07 | ± | 3.05 |
| | | acc_norm | 42.75 | ± | 3.02 |
| agieval_sat_en | 0 | acc | 54.85 | ± | 3.48 |
| | | acc_norm | 53.40 | ± | 3.48 |
| agieval_sat_en_without_passage | 0 | acc | 37.38 | ± | 3.38 |
| | | acc_norm | 33.98 | ± | 3.31 |
| agieval_sat_math | 0 | acc | 30.91 | ± | 3.12 |
| | | acc_norm | 28.18 | ± | 3.04 |

Average: 34.22%

### GPT4All

| Task | Version | Metric | Value | | Stderr |
|------|------:|--------|----:|---|-----:|
| arc_challenge | 0 | acc | 49.15 | ± | 1.46 |
| | | acc_norm | 52.47 | ± | 1.46 |
| arc_easy | 0 | acc | 77.44 | ± | 0.86 |
| | | acc_norm | 74.75 | ± | 0.89 |
| boolq | 1 | acc | 79.69 | ± | 0.70 |
| hellaswag | 0 | acc | 60.59 | ± | 0.49 |
| | | acc_norm | 78.00 | ± | 0.41 |
| openbookqa | 0 | acc | 29.20 | ± | 2.04 |
| | | acc_norm | 37.80 | ± | 2.17 |
| piqa | 0 | acc | 76.82 | ± | 0.98 |
| | | acc_norm | 77.80 | ± | 0.97 |
| winogrande | 0 | acc | 64.09 | ± | 1.35 |

Average: 66.37%

### TruthfulQA

| Task | Version | Metric | Value | | Stderr |
|------|------:|------|----:|---|-----:|
| truthfulqa_mc | 1 | mc1 | 35.74 | ± | 1.68 |
| | | mc2 | 52.19 | ± | 1.59 |

Average: 52.19%

### BigBench

| Task | Version | Metric | Value | | Stderr |
|------|------:|------|----:|---|-----:|
| bigbench_causal_judgement | 0 | multiple_choice_grade | 53.68 | ± | 3.63 |
| bigbench_date_understanding | 0 | multiple_choice_grade | 59.89 | ± | 2.55 |
| bigbench_disambiguation_qa | 0 | multiple_choice_grade | 30.23 | ± | 2.86 |
| bigbench_geometric_shapes | 0 | multiple_choice_grade | 11.42 | ± | 1.68 |
| | | exact_str_match | 0.00 | ± | 0.00 |
| bigbench_logical_deduction_five_objects | 0 | multiple_choice_grade | 28.40 | ± | 2.02 |
| bigbench_logical_deduction_seven_objects | 0 | multiple_choice_grade | 19.14 | ± | 1.49 |
| bigbench_logical_deduction_three_objects | 0 | multiple_choice_grade | 44.67 | ± | 2.88 |
| bigbench_movie_recommendation | 0 | multiple_choice_grade | 26.80 | ± | 1.98 |
| bigbench_navigate | 0 | multiple_choice_grade | 50.00 | ± | 1.58 |
| bigbench_reasoning_about_colored_objects | 0 | multiple_choice_grade | 52.75 | ± | 1.12 |
| bigbench_ruin_names | 0 | multiple_choice_grade | 33.04 | ± | 2.22 |
| bigbench_salient_translation_error_detection | 0 | multiple_choice_grade | 33.37 | ± | 1.49 |
| bigbench_snarks | 0 | multiple_choice_grade | 48.62 | ± | 3.73 |
| bigbench_sports_understanding | 0 | multiple_choice_grade | 58.11 | ± | 1.57 |
| bigbench_temporal_sequences | 0 | multiple_choice_grade | 37.20 | ± | 1.53 |
| bigbench_tracking_shuffled_objects_five_objects | 0 | multiple_choice_grade | 20.08 | ± | 1.13 |
| bigbench_tracking_shuffled_objects_seven_objects | 0 | multiple_choice_grade | 15.77 | ± | 0.87 |
| bigbench_tracking_shuffled_objects_three_objects | 0 | multiple_choice_grade | 44.67 | ± | 2.88 |

Average: 37.1%

Intended uses & limitations
This model was initially fine-tuned on the DEITA 10K dataset, which contains a diverse range of synthetic dialogues generated by ChatGPT.
We then further aligned the model with 🤗 TRL's DPOTrainer on the argilla/dpo-mix-7k dataset, which contains 7k prompts and model completions that are ranked by GPT-4. As a result, the model can be used for chat, and you can check out our demo to test its capabilities.
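If you want to look at that preference data yourself, here is a quick sketch (assuming the dataset's standard train split on the Hub; the column names are printed rather than assumed, since the exact schema is defined by the dataset itself):

```python
from datasets import load_dataset

# ~7k preference pairs ranked by GPT-4.
dpo_mix = load_dataset("argilla/dpo-mix-7k", split="train")
print(dpo_mix)               # number of rows and features
print(dpo_mix.column_names)  # inspect the actual schema
```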
Here's how you can run the model using the pipeline() function from 🤗 Transformers:
```python
# pip install transformers>=4.38.2
# pip install accelerate

import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="HuggingFaceH4/zephyr-7b-gemma-v0.1",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
messages = [
    {
        "role": "system",
        "content": "",  # Model not yet trained to follow system prompts
    },
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
outputs = pipe(
    messages,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    stop_sequence="<|im_end|>",
)
print(outputs[0]["generated_text"][-1]["content"])
# It is not possible for a human to eat a helicopter in one sitting, as a
# helicopter is a large and inedible machine. Helicopters are made of metal,
# plastic, and other materials that are not meant to be consumed by humans.
# Eating a helicopter would be extremely dangerous and would likely cause
# serious health problems, including choking, suffocation, and poisoning. It is
# important to only eat food that is safe and intended for human consumption.
```
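If you prefer to control prompt construction yourself, the same pipeline object exposes the tokenizer's chat template; a minimal sketch (note that with a string input, generated_text contains the prompt plus the completion):

```python
# Build the prompt explicitly with the chat template, then generate from the string.
prompt = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
out = pipe(prompt, max_new_tokens=128, do_sample=True, temperature=0.7)
print(out[0]["generated_text"])  # includes the prompt followed by the reply
```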
Bias, Risks, and Limitations
Zephyr 7B Gemma has not been aligned to human preferences for safety within the RLHF phase or deployed with in-the-loop filtering of responses like ChatGPT, so the model can produce problematic outputs (especially when prompted to do so). It is also unknown what the size and composition of the corpus used to train the base model (google/gemma-7b) was, although it likely included a mix of Web data and technical sources like books and code. See the StarCoder2 model card for an example of this.
Training and evaluation data
This model is a fine-tuned version of HuggingFaceH4/zephyr-7b-gemma-sft-v0.1 on the argilla/dpo-mix-7k dataset.
It achieves the following results on the evaluation set:
- Loss: 0.4695
- Rewards/chosen: -3.3746
- Rewards/rejected: -4.9715
- Rewards/accuracies: 0.7188
- Rewards/margins: 1.5970
- Logps/rejected: -459.4853
- Logps/chosen: -429.9115
- Logits/rejected: 86.4684
- Logits/chosen: 92.8200
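These numbers are internally consistent with the usual DPO-style bookkeeping, where the margin is simply the chosen reward minus the rejected reward; a quick check on the values above:

```python
rewards_chosen = -3.3746
rewards_rejected = -4.9715
print(round(rewards_chosen - rewards_rejected, 4))  # 1.5969, matching Rewards/margins up to rounding
```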
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 5e-07
- train_batch_size: 2
- eval_batch_size: 4
- seed: 42
- distributed_type: multi-GPU
- num_devices: 8
- gradient_accumulation_steps: 8
- total_train_batch_size: 128
- total_eval_batch_size: 32
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 2
Training results

| Training Loss | Epoch | Step | Validation Loss | Rewards/chosen | Rewards/rejected | Rewards/accuracies | Rewards/margins | Logps/rejected | Logps/chosen | Logits/rejected | Logits/chosen |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.1923 | 1.9 | 100 | 0.4736 | -3.4575 | -4.9556 | 0.75 | 1.4980 | -459.1662 | -431.5707 | 86.3863 | 92.7360 |
Framework versions
- Transformers 4.39.0.dev0
- Pytorch 2.1.2+cu121
- Datasets 2.14.6
- Tokenizers 0.15.1
Citation Information
If you find this model useful in your own work, please consider citing the Zephyr technical report:
```bibtex
@misc{tunstall2023zephyr,
  title={Zephyr: Direct Distillation of LM Alignment},
  author={Lewis Tunstall and Edward Beeching and Nathan Lambert and Nazneen Rajani and Kashif Rasul and Younes Belkada and Shengyi Huang and Leandro von Werra and Clémentine Fourrier and Nathan Habib and Nathan Sarrazin and Omar Sanseviero and Alexander M. Rush and Thomas Wolf},
  year={2023},
  eprint={2310.16944},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```
You may also wish to cite the creators of this model:
```bibtex
@misc{zephyr_7b_gemma,
  author = {Lewis Tunstall and Philipp Schmid},
  title = {Zephyr 7B Gemma},
  year = {2024},
  publisher = {Hugging Face},
  journal = {Hugging Face repository},
  howpublished = {\url{https://hugging-face.cn/HuggingFaceH4/zephyr-7b-gemma-v0.1}}
}
```
Open LLM Leaderboard Evaluation Results
Detailed results can be found here.

| Metric | Value |
|---|---|
| Avg. | 62.41 |
| AI2 Reasoning Challenge (25-Shot) | 58.45 |
| HellaSwag (10-Shot) | 83.48 |
| MMLU (5-Shot) | 60.68 |
| TruthfulQA (0-shot) | 52.07 |
| Winogrande (5-shot) | 74.19 |
| GSM8k (5-shot) | 45.56 |
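As a quick sanity check, the reported average is the mean of the six benchmark scores:

```python
scores = [58.45, 83.48, 60.68, 52.07, 74.19, 45.56]
print(sum(scores) / len(scores))  # ≈ 62.405, which rounds to the reported 62.41
```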