transformers Trainer训练模型实践踩坑(含多图推理data加载、多机多卡、多loss日志等)

3 分钟阅读时长

发布时间:


概述

使用开源Qwen3-VL-8B-instruct模型+打分头,在一个有标注小规模数据集上训练分类头

数据集构建

数据集组织使用的是webdataset。

多图推理在processor时遇到了问题(上一篇blog提到)

最终做法:

message = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": images[0]},
            {"type": "image", "image": images[1]},
            {
                "type": "text",
                "text": prompt,
            },
        ],
    }
]
text_input = processor.apply_chat_template(
    message,
    tokenize=False,
)
image_inputs, _ = process_vision_info(message, image_patch_size=16)

inputs = processor(
    images=image_inputs,
    text=[text_input],
    padding=True,
    padding_side="left",
    return_tensors="pt",
).to(torch.bfloat16)

process_vision_info能够提取messages图像,看似是多此一举(为什么不直接用images?),实际上它还有自动将图像resize到patch大小的倍数的功能。

padding_side=left是为了最后一个token的hidden_state对应整个序列的特征

processor输出的结果是一个BatchFeatures,包含

  • input_ids (1,L),L为tokenize后(包含文本和图像)的序列长度。以及标识图像开始结尾的一些special token(但是图像token实际上是占位符,输出就可以发现都是相同的一个token id,因为此时还没有真正的对图像进行tokenize,这一过程在model.forward()内部才会正式执行)。
  • attention_mask (1, L) 不过多解释,注意力mask
  • pixel_values: (num_images * size0, size1) 图像真正的像素值,但是shape并不严格等于图像大小(甚至想去甚远),我还没搞明白是为什么。值得注意的是,在Qwen3VL中,同一个sample中的多张图像会堆叠在第1维,而不是第0维或者另开一维。
  • image_grid_thw: (num_images, 3) 表示图像是如何被分patch的,模型内部会根据这个向量来分patch,tokenize,并对应占位符。

而当我们需要进行批次聚合时就会发现,模型目标的批次输入格式是不同的:

  • input_ids (B,L)
  • attention_mask (B,L)
  • pixel_values: (B, num_images * size0, size1)
  • image_grid_thw: (B*num_images, 3)

坑点1:image_grid_thw的坑点就在于它没有第0维 B,而是直接乘以num_images。 因为在Qwen模型内部代码中硬编码要取这个变量的第1维(3这一维),所以它总共只能有两维。

坑点2:如果只传一个sample给processor,那么它生成的pixel_values没有第一维的1。

这也就意味着,我们在写批次聚合函数collate_fn时,pixel_values要用torch.stack(),而另外三个是torch.concat()

关于labels要放在哪里,我们后面来说。

聚合函数sample:

# ==================== 自定义collate_fn ====================
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    自定义collate_fn,聚合输入
    """
    first = features[0]
    batch = {}

    for k, v in first.items():
        if v is not None and isinstance(v, dict):
            batch[k] = collate_fn([f[k] for f in features])
        elif v is not None and not isinstance(v, str):
            if k in ["image_grid_thw", "input_ids", "attention_mask"]:
                batch[k] = torch.concatenate([f[k] for f in features], dim=0)
            else:
                batch[k] = torch.stack([f[k] for f in features], dim=0)
        else:
            batch[k] = [f[k] for f in features]

    return batch

多机多卡数据加载,训练

这部分其实也是慢慢摸索出来的

首先作为一个新手踩了一个大坑:用Trainer分布式训练,尽管它已经包装好了,但运行训练脚本必须使用torchrun命令而不是python直接执行,后者进行的DDP是错误的(尽管也能训,每张卡都在用,但是并没有实际利用到多卡并行加速)。

我的TrainingArguments详细配置:

training_args = TrainingArguments(
    output_dir=str(save_dir),
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    logging_steps=100,
    save_steps=2000,
    save_strategy="steps",
    eval_steps=2000,
    eval_strategy="steps",
    metric_for_best_model="loss",
    greater_is_better=False,
    report_to="wandb",
    max_steps=num_steps,
    learning_rate=args.lr,
    ddp_find_unused_parameters=False, # 重要,不然多卡加载会有问题
    optim="adamw_torch",
    lr_scheduler_type="polynomial",
    lr_scheduler_kwargs={"power": 3.0},
    warmup_steps=500,
    save_total_limit=5,
    weight_decay=0.02,
    label_names=["score_labels", "type_labels"],
)
if args.ckpt_path is not None:
    ckpt_path = Path(args.ckpt_path) / "model.safetensors"
    checkpoint = load_file(str(ckpt_path))
    qwen_model.load_state_dict(state_dict=checkpoint, strict=True)
    print(f"Loaded model from {str(ckpt_path)}")
else:
    print("No checkpoint path provided, starting from the beginning")
# 创建训练器
trainer = ScorerTrainer(
    model=qwen_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    collate_fn=collate_fn,
)

由于我使用了webdataset,为了实现分布式数据读取,需要指定分片方式:

dataset = WebDataset(
    tar_files,
    shardshuffle=False,
    workersplitter=wds.split_by_worker,
    nodesplitter=wds.split_by_node,
)

我使用了wds.WebLoader作为加载器,而不是使用默认的加载器,在Trainer的子类中需要重写方法:

def get_train_dataloader(self) -> WebLoader:
    dataloader_params = {
        "batch_size": self.args.per_device_train_batch_size,
        "collate_fn": self.data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
    }
    dataloader = self.accelerator.prepare(
        WebLoader(self.train_dataset, **dataloader_params)
    )
    return dataloader

def get_eval_dataloader(self, eval_dataset: WebDataset) -> WebLoader:
    dataloader_params = {
        "batch_size": self.args.per_device_eval_batch_size,
        "collate_fn": self.data_collator,
        "num_workers": self.args.dataloader_num_workers,
        "pin_memory": self.args.dataloader_pin_memory,
        "persistent_workers": self.args.dataloader_persistent_workers,
        "sampler": self._get_eval_sampler(eval_dataset),
    }
    dataloader = self.accelerator.prepare(
        WebLoader(eval_dataset, **dataloader_params)
    )
    return dataloader

其中必须使用accelerator.prepare()来准备数据加载器(不然无法正常多卡加载)。

然后碰上一个更坑的事:webdataset多卡必须按照shard切分,当验证集比较小(小于卡数时)就爆了

  • (解决方案是把验证集的resampled打开)

多loss日志

默认的Trainer是不支持多loss打日志的,并且按照他那个写法,train和eval的labels是完全割裂的,非常难用。

train的labels加载和eval的labels加载颇有自相矛盾的意味

  • 两个不同的函数compute_losscompute_metrics,前者只支持返回单个loss,后者可以看多个,但是传入的参数很奇怪
  • train_loop中会从inputs pop("labels")字段,根据这个字段是否存在判断loss如何计算
  • 而eval中有一个独立的参数 label_names,要求用户指明labels的种类并且放在inputs的第一层,而用户可重写的compute_metrics函数接收到的居然是一个numpy,甚至无法重载它的聚合函数
  • 在eval中重写就变得非常困难
  • 先前的做法是重写compute_loss_func,将log代码注入在这,但是这会和eval的部分冲突

最终解法:子类重写compute_losscompute_metrics

model.forward中用一个类同时将loss封装起来。(我有两个task,接了两个输出头,两个loss加权为总loss)

@dataclass
class ScorerOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    loss_1: Optional[torch.FloatTensor] = None
    loss_2: Optional[torch.FloatTensor] = None
    outputs1: torch.FloatTensor = None
    outputs2: torch.FloatTensor = None

TrainingArguments中配置好label_names,在eval 调用compute_metrics时打日志。好处是eval的时候不会乱gather,训练过程就只能看到总loss了

至于训练过程,Tricky的办法是同时做封装,compute_loss可以注入log代码,通过self.state.global_step来算步数,compute_metric可以直接返回字典。

关键在于compute_loss内区分当前是在train还是在eval(eval的时候只在compute_metrics中打日志)。这里的办法是根据参数return_outputs来判断(eval时return_outputs为True)(这也是重载compute_loss而不重载更小的compute_loss_func的原因)

def compute_loss(
    self,
    model: nn.Module,
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs: bool = False,
    num_items_in_batch: Optional[torch.Tensor] = None,
):
    """
    Rewrite Trainer.compute_loss
    """
    if self.model_accepts_loss_kwargs:
        kwargs = {}
        if num_items_in_batch is not None:
            kwargs["num_items_in_batch"] = num_items_in_batch
        inputs = {**inputs, **kwargs}
    outputs = model(**inputs)
    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    loss = outputs["loss"]
    if return_outputs == False:  # in Training loop, return_outputs is False
        if self.state.global_step % self.log_step == 0:
            self.log(
                {
                    "loss_1": outputs["loss_1"],
                    "loss_2": outputs["loss_2"],
                },
            )

    if (
        self.args.average_tokens_across_devices
        and (self.model_accepts_loss_kwargs or self.compute_loss_func)
        and num_items_in_batch is not None
    ):
        loss *= (
            self.accelerator.num_processes
            if self.args.n_gpu <= 1
            else self.args.n_gpu
        )

    return (loss, outputs) if return_outputs else loss

def compute_metrics(self, eval_prediction: EvalPrediction):
    """计算验证指标
    聚合一整个验证集的指标,需要mean
    """
    loss_1 = np.mean(eval_prediction.predictions[0])
    loss_2 = np.mean(eval_prediction.predictions[1])
    return {
        "loss_1": loss_1,
        "loss_2": loss_2,
    }
阅读量: 加载中...