fine-tune llama2 chat with lora - Wed, Sep 13, 2023
1. Dataset
Official guide: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
1.1 Data format
- Single-turn conversation
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>
{{ user_message }} [/INST]
- Multi-turn conversation
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>
{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
1.2 Dataset format
Data format | Loading script | Example
---|---|---
CSV & TSV | csv | load_dataset("csv", data_files="my_file.csv")
Text files | text | load_dataset("text", data_files="my_file.txt")
JSON & JSON Lines | json | load_dataset("json", data_files="my_file.jsonl")
Pickled DataFrames | pandas | load_dataset("pandas", data_files="my_dataframe.pkl")
Format of each record in the jsonl dataset:
{ "text": "text-for-model-to-predict" }
2. Training
2.1 Required packages
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
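These come from the datasets, transformers, peft and trl libraries; bitsandbytes is also needed for the 4-bit quantization config below, and accelerate for device_map support.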
2.2 Loading the dataset
dataset = load_dataset(
'json',
data_files=dataset_name,
split='train'
)
- json is the format used to load the dataset
- data_files is the dataset file name
- split='train' means the whole file is used as the training split
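A quick sanity check that the file was parsed as expected (assuming each record carries the text field described in section 1.2):
print(dataset)             # number of rows and detected features
print(dataset[0]["text"])  # the first training example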
2.3 Loading the pretrained model
2.3.1 Quantization parameters
Quantization shrinks the numeric type of the weights to speed up training and reduce model size; the trade-off is lower precision and therefore potentially lower accuracy. Model weights are stored as floating-point values: F64, F32, F16, F8, NF4 (4-bit NormalFloat), or FP4 (plain 4-bit float). NF4 generally performs better.
- bnb_4bit_compute_dtype: the data type used for computation on the quantized model (default torch.float32)
- bnb_4bit_use_double_quant: applies a second quantization after the first, saving an extra 0.4 bits per parameter.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # fp4
bnb_4bit_compute_dtype=torch.float16,
# bnb_4bit_use_double_quant=True,
)
2.3.2 Loading the base model
device_map = "auto"  # auto | balanced | balanced_low_0 | sequential
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
local_files_only=True,
quantization_config=bnb_config,
device_map=device_map,
trust_remote_code=True,
# use_auth_token=True
)
# https://huggingface.co/transformers/v2.9.1/main_classes/model.html
# fine-tuning updates the weights, so the KV cache should not be used
base_model.config.use_cache = False
# More info: https://github.com/huggingface/transformers/pull/24906
base_model.config.pretraining_tp = 1
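As a rough check that 4-bit loading took effect, the model's memory footprint can be printed (get_memory_footprint is a standard transformers method; the size mentioned in the comment is only indicative):
# A 7B model loaded in 4-bit should occupy on the order of 4 GB instead of ~13 GB in fp16.
print(f"model size: {base_model.get_memory_footprint() / 1024**3:.2f} GB")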
2.3.3 Loading the tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'
2.3.4 LoRA configuration
- lora_alpha: scaling factor for the LoRA update
- lora_dropout: dropout probability applied to the LoRA layers
- r: rank of the low-rank update matrices
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
task_type="CAUSAL_LM",
)
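SFTTrainer applies this LoRA config by itself when peft_config is passed, so the wrapping below is not required; it is only an optional sketch (using PEFT's get_peft_model) to show how small the trainable parameter count becomes:
from peft import get_peft_model

# For illustration only: inject the LoRA adapters and count trainable parameters.
# Only the low-rank A/B matrices are trainable; the 4-bit base weights stay frozen.
preview = get_peft_model(base_model, peft_config)
preview.print_trainable_parameters()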
2.3.5 Training arguments
training_args = TrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
logging_steps=10,
max_steps=500
)
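With per_device_train_batch_size=4 and gradient_accumulation_steps=4, gradients are accumulated over 4 forward passes, so the effective batch size is 16 per device.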
max_seq_length = 1024  # LLaMA 2's context window is 4K; 1024 is used here to limit memory usage
trainer = SFTTrainer(
model=base_model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_args,
)
trainer.train()
trainer.model.save_pretrained(output_dir)
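save_pretrained here writes only the LoRA adapter weights, not a full copy of the model. A minimal sketch of loading the adapter back for a quick test in a fresh session (reload base_model and tokenizer as in 2.3.2 and 2.3.3 first; the prompt and generation settings are illustrative):
from peft import PeftModel

# Re-attach the trained adapter to the (still quantized) base model.
ft_model = PeftModel.from_pretrained(base_model, output_dir)

# The LLaMA tokenizer adds the BOS token (<s>) automatically.
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nHello! [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(ft_model.device)
outputs = ft_model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))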