transformers Trainer训练模型实践踩坑(含多图推理data加载、多机多卡、多loss日志等)
发布时间:
概述
使用开源Qwen3-VL-8B-instruct模型+打分头,在一个有标注小规模数据集上训练分类头
数据集构建
数据集组织使用的是webdataset。
多图推理在processor时遇到了问题(上一篇blog提到)
最终做法:
message = [
{
"role": "user",
"content": [
{"type": "image", "image": images[0]},
{"type": "image", "image": images[1]},
{
"type": "text",
"text": prompt,
},
],
}
]
text_input = processor.apply_chat_template(
message,
tokenize=False,
)
image_inputs, _ = process_vision_info(message, image_patch_size=16)
inputs = processor(
images=image_inputs,
text=[text_input],
padding=True,
padding_side="left",
return_tensors="pt",
).to(torch.bfloat16)
process_vision_info能够提取messages图像,看似是多此一举(为什么不直接用images?),实际上它还有自动将图像resize到patch大小的倍数的功能。
padding_side=left是为了最后一个token的hidden_state对应整个序列的特征
processor输出的结果是一个BatchFeatures,包含
input_ids (1,L),L为tokenize后(包含文本和图像)的序列长度。以及标识图像开始结尾的一些special token(但是图像token实际上是占位符,输出就可以发现都是相同的一个token id,因为此时还没有真正的对图像进行tokenize,这一过程在model.forward()内部才会正式执行)。attention_mask (1, L)不过多解释,注意力maskpixel_values: (num_images * size0, size1)图像真正的像素值,但是shape并不严格等于图像大小(甚至想去甚远),我还没搞明白是为什么。值得注意的是,在Qwen3VL中,同一个sample中的多张图像会堆叠在第1维,而不是第0维或者另开一维。image_grid_thw: (num_images, 3)表示图像是如何被分patch的,模型内部会根据这个向量来分patch,tokenize,并对应占位符。
而当我们需要进行批次聚合时就会发现,模型目标的批次输入格式是不同的:
input_ids (B,L)attention_mask (B,L)pixel_values: (B, num_images * size0, size1)image_grid_thw: (B*num_images, 3)
坑点1:image_grid_thw的坑点就在于它没有第0维 B,而是直接乘以num_images。 因为在Qwen模型内部代码中硬编码要取这个变量的第1维(3这一维),所以它总共只能有两维。
坑点2:如果只传一个sample给processor,那么它生成的pixel_values没有第一维的1。
这也就意味着,我们在写批次聚合函数collate_fn时,pixel_values要用torch.stack(),而另外三个是torch.concat()。
关于labels要放在哪里,我们后面来说。
聚合函数sample:
# ==================== 自定义collate_fn ====================
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
自定义collate_fn,聚合输入
"""
first = features[0]
batch = {}
for k, v in first.items():
if v is not None and isinstance(v, dict):
batch[k] = collate_fn([f[k] for f in features])
elif v is not None and not isinstance(v, str):
if k in ["image_grid_thw", "input_ids", "attention_mask"]:
batch[k] = torch.concatenate([f[k] for f in features], dim=0)
else:
batch[k] = torch.stack([f[k] for f in features], dim=0)
else:
batch[k] = [f[k] for f in features]
return batch
多机多卡数据加载,训练
这部分其实也是慢慢摸索出来的
首先作为一个新手踩了一个大坑:用Trainer分布式训练,尽管它已经包装好了,但运行训练脚本必须使用torchrun命令而不是python直接执行,后者进行的DDP是错误的(尽管也能训,每张卡都在用,但是并没有实际利用到多卡并行加速)。
我的TrainingArguments详细配置:
training_args = TrainingArguments(
output_dir=str(save_dir),
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
logging_steps=100,
save_steps=2000,
save_strategy="steps",
eval_steps=2000,
eval_strategy="steps",
metric_for_best_model="loss",
greater_is_better=False,
report_to="wandb",
max_steps=num_steps,
learning_rate=args.lr,
ddp_find_unused_parameters=False, # 重要,不然多卡加载会有问题
optim="adamw_torch",
lr_scheduler_type="polynomial",
lr_scheduler_kwargs={"power": 3.0},
warmup_steps=500,
save_total_limit=5,
weight_decay=0.02,
label_names=["score_labels", "type_labels"],
)
if args.ckpt_path is not None:
ckpt_path = Path(args.ckpt_path) / "model.safetensors"
checkpoint = load_file(str(ckpt_path))
qwen_model.load_state_dict(state_dict=checkpoint, strict=True)
print(f"Loaded model from {str(ckpt_path)}")
else:
print("No checkpoint path provided, starting from the beginning")
# 创建训练器
trainer = ScorerTrainer(
model=qwen_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
collate_fn=collate_fn,
)
由于我使用了webdataset,为了实现分布式数据读取,需要指定分片方式:
dataset = WebDataset(
tar_files,
shardshuffle=False,
workersplitter=wds.split_by_worker,
nodesplitter=wds.split_by_node,
)
我使用了wds.WebLoader作为加载器,而不是使用默认的加载器,在Trainer的子类中需要重写方法:
def get_train_dataloader(self) -> WebLoader:
dataloader_params = {
"batch_size": self.args.per_device_train_batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
dataloader = self.accelerator.prepare(
WebLoader(self.train_dataset, **dataloader_params)
)
return dataloader
def get_eval_dataloader(self, eval_dataset: WebDataset) -> WebLoader:
dataloader_params = {
"batch_size": self.args.per_device_eval_batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"sampler": self._get_eval_sampler(eval_dataset),
}
dataloader = self.accelerator.prepare(
WebLoader(eval_dataset, **dataloader_params)
)
return dataloader
其中必须使用accelerator.prepare()来准备数据加载器(不然无法正常多卡加载)。
然后碰上一个更坑的事:webdataset多卡必须按照shard切分,当验证集比较小(小于卡数时)就爆了
- (解决方案是把验证集的resampled打开)
多loss日志
默认的Trainer是不支持多loss打日志的,并且按照他那个写法,train和eval的labels是完全割裂的,非常难用。
train的labels加载和eval的labels加载颇有自相矛盾的意味
- 两个不同的函数
compute_loss和compute_metrics,前者只支持返回单个loss,后者可以看多个,但是传入的参数很奇怪 train_loop中会从inputspop("labels")字段,根据这个字段是否存在判断loss如何计算- 而eval中有一个独立的参数
label_names,要求用户指明labels的种类并且放在inputs的第一层,而用户可重写的compute_metrics函数接收到的居然是一个numpy,甚至无法重载它的聚合函数 - 在eval中重写就变得非常困难
- 先前的做法是重写
compute_loss_func,将log代码注入在这,但是这会和eval的部分冲突
最终解法:子类重写compute_loss和compute_metrics。
在model.forward中用一个类同时将loss封装起来。(我有两个task,接了两个输出头,两个loss加权为总loss)
@dataclass
class ScorerOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
loss_1: Optional[torch.FloatTensor] = None
loss_2: Optional[torch.FloatTensor] = None
outputs1: torch.FloatTensor = None
outputs2: torch.FloatTensor = None
在TrainingArguments中配置好label_names,在eval 调用compute_metrics时打日志。好处是eval的时候不会乱gather,训练过程就只能看到总loss了
至于训练过程,Tricky的办法是同时做封装,compute_loss可以注入log代码,通过self.state.global_step来算步数,compute_metric可以直接返回字典。
关键在于compute_loss内区分当前是在train还是在eval(eval的时候只在compute_metrics中打日志)。这里的办法是根据参数return_outputs来判断(eval时return_outputs为True)(这也是重载compute_loss而不重载更小的compute_loss_func的原因)
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
return_outputs: bool = False,
num_items_in_batch: Optional[torch.Tensor] = None,
):
"""
Rewrite Trainer.compute_loss
"""
if self.model_accepts_loss_kwargs:
kwargs = {}
if num_items_in_batch is not None:
kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **kwargs}
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
loss = outputs["loss"]
if return_outputs == False: # in Training loop, return_outputs is False
if self.state.global_step % self.log_step == 0:
self.log(
{
"loss_1": outputs["loss_1"],
"loss_2": outputs["loss_2"],
},
)
if (
self.args.average_tokens_across_devices
and (self.model_accepts_loss_kwargs or self.compute_loss_func)
and num_items_in_batch is not None
):
loss *= (
self.accelerator.num_processes
if self.args.n_gpu <= 1
else self.args.n_gpu
)
return (loss, outputs) if return_outputs else loss
def compute_metrics(self, eval_prediction: EvalPrediction):
"""计算验证指标
聚合一整个验证集的指标,需要mean
"""
loss_1 = np.mean(eval_prediction.predictions[0])
loss_2 = np.mean(eval_prediction.predictions[1])
return {
"loss_1": loss_1,
"loss_2": loss_2,
}
