kaisawind's blog
  • 关于
  • 所有帖子

fine-tune llama2 chat with lora - Wed, Sep 13, 2023

fine-tune llama2 chat with lora

1. 数据集

官方说明 https://huggingface.co/blog/llama2#how-to-prompt-llama-2

1.1 数据格式

  • 单轮对话
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>

{{ user_message }} [/INST]
  • 多轮对话
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>

{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]

1.2 数据集格式

Data formatLoading scriptExample
CSV & TSVcsvload_dataset(“csv”, data_files=“my_file.csv”)
Text filestextload_dataset(“text”, data_files=“my_file.txt”)
JSON & JSON Linesjsonload_dataset(“json”, data_files=“my_file.jsonl”)
Pickled DataFramespandasload_dataset(“pandas”, data_files=“my_dataframe.pkl”)

jsonl数据集中每条数据的格式

{ "text": "text-for-model-to-predict" }

2. 训练

2.1 必要包

from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer

2.2 加载数据集

dataset = load_dataset(
    'json', 
    data_files=dataset_name,
    split='train'
)
  • json是加载数据集的格式
  • data_files是数据集的文件名
  • split=‘train’表示数据集全为训练集

2.3 加载原始模型

2.3.1 量化参数

量化是通过减小类型,加快训练速度,减小模型大小的一种方法。缺点是会降低精度,降低准确率。 模型中数据存储的是浮点类型,F64,F32,F16,F8,NF4(规范化浮点数 4), FP4(纯 FP4). NF4性能更好。

  • bnb_4bit_compute_dtype计算量化时用的数据类型(默认torch.float32)
  • bnb_4bit_use_double_quant 第一个量化之后使用第二个量化来为每个参数节省额外的 0.4 位.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4", # fp4
    bnb_4bit_compute_dtype=torch.float16,
    # bnb_4bit_use_double_quant=True,
)

2.3.2 基础模型加载

device_map = "auto" # auto | balanced  or balanced_low_0 | sequential | 

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    local_files_only=True,
    quantization_config=bnb_config,
    device_map=device_map,
    trust_remote_code=True,
    # use_auth_token=True
)
# https://huggingface.co/transformers/v2.9.1/main_classes/model.html
# fine tuning should update params, should not use cache
base_model.config.use_cache = False
# More info: https://github.com/huggingface/transformers/pull/24906
base_model.config.pretraining_tp = 1 

2.3.3 分词器加载

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

2.3.4 Lora配置

  • lora_alpha 缩放系数
  • lora_dropout Dropout 系数
  • r 秩大小
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

2.3.5 训练参数

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=10,
    max_steps=500
)

max_seq_length = 1024 #4K

trainer = SFTTrainer(
    model=base_model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_args,
)

trainer.train()

trainer.model.save_pretrained(output_dir)


辽ICP备2021007608号 | © 2025 | kaisawind

Facebook Twitter GitHub