找回密码
 立即注册
首页 业界区 业界 pytorch入门 - 修改huggingface大模型配置参数 ...

pytorch入门 - 修改huggingface大模型配置参数

烯八 2025-6-10 12:12:53
介绍

Hugging Face的Transformers库提供了大量预训练模型,但有时我们需要修改这些模型的默认参数来适应特定任务。
本文将详细介绍如何修改BERT模型的最大序列长度(max_position_embeddings)参数,并解释相关原理和实现细节。
原理

BERT等Transformer模型对输入序列长度有固定限制,这主要由位置编码(position embeddings)决定。
原始BERT-base-chinese模型的max_position_embeddings为512,意味着它最多只能处理512个token的输入。当我们需要处理更长的文本时,必须修改这一参数。
修改过程涉及三个关键步骤:

  • 调整模型配置中的max_position_embeddings值
  • 替换位置嵌入层(position_embeddings)为新尺寸
  • 初始化新位置嵌入层的权重(复制原有权重,其余随机初始化)
实现代码详解

下面我们逐行分析实现代码:
1. 数据集准备 (news_finetuing_data_set.py)
  1. from datasets import load_dataset, load_from_disk
  2. from torch.utils.data import Dataset
  3. class MyDataset(Dataset):
  4.     def __init__(self, split):
  5.         # 指定CSV文件路径,支持train/test/validation三种分割
  6.         data_file = rf"cache\datasets\csv\THUCNewsText\{split}.csv"
  7.         self.dataset = load_dataset(
  8.             "csv",
  9.             data_files={split: data_file},
  10.             split=split if split in ["train", "test", "validation"] else "train",
  11.         )
  12.     def __len__(self):
  13.         return len(self.dataset)  # 返回数据集样本数量
  14.     def __getitem__(self, idx):
  15.         return self.dataset[idx]["text"], self.dataset[idx]["label"]  # 返回文本和标签
复制代码
2. 模型修改 (news_finetuing_net.py)
  1. from transformers import BertModel, BertConfig
  2. import torch
  3. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  4. # 1. 加载预训练模型和配置
  5. model = BertModel.from_pretrained(
  6.     "bert-base-chinese", cache_dir="./cache/bertbasechinese"
  7. ).to(device)
  8. # 2. 修改max_position_embeddings配置
  9. model.config.max_position_embeddings = 1500
  10. # 3. 替换position_embeddings层
  11. old_embeddings = model.embeddings.position_embeddings
  12. new_embeddings = torch.nn.Embedding(1500, old_embeddings.embedding_dim)
  13. # 拷贝原有权重
  14. num = min(old_embeddings.weight.size(0), 1500)
  15. new_embeddings.weight.data[:num, :] = old_embeddings.weight.data[:num, :]
  16. model.embeddings.position_embeddings = new_embeddings
  17. # 4. 冻结除position_embeddings外的所有参数
  18. for name, param in pretrained.named_parameters():
  19.     if "embeddings.position_embeddings" in name:
  20.         param.requires_grad = True
  21.     else:
  22.         param.requires_grad = False
  23. class Model(torch.nn.Module):
  24.     def __init__(self):
  25.         super(Model, self).__init__()
  26.         self.classifier = torch.nn.Linear(768, 10)  # 添加分类头
  27.     def forward(self, input_ids, attention_mask, token_type_ids):
  28.         position_ids = (
  29.             torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
  30.             .unsqueeze(0)
  31.             .expand_as(input_ids)
  32.         )
  33.         outputs = pretrained(
  34.             input_ids=input_ids,
  35.             attention_mask=attention_mask,
  36.             token_type_ids=token_type_ids,
  37.             position_ids=position_ids,
  38.         )
  39.         cls_output = outputs.last_hidden_state[:, 0]  # 取[CLS] token的输出
  40.         out = self.classifier(cls_output)
  41.         return out
复制代码
3. 训练过程 (news_finetuing_train.py)
  1. import torch
  2. from news_finetuing_data_set import MyDataset
  3. from torch.utils.data import DataLoader
  4. from news_finetuing_net import Model
  5. from transformers import BertTokenizer
  6. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  7. EPOCH = 100
  8. # 加载分词器并设置最大长度
  9. token = BertTokenizer.from_pretrained(
  10.     "bert-base-chinese",
  11.     cache_dir="./cache/tokenizer/bert-base-chinese",
  12. )
  13. token.model_max_length = 1500  # 设置分词器最大长度
  14. def collate_fn(batch):
  15.     # 数据处理函数
  16.     sentes = [item[0] for item in batch]
  17.     labels = [item[1] for item in batch]
  18.     data = token.batch_encode_plus(
  19.         sentes,
  20.         truncation=True,
  21.         padding="max_length",
  22.         max_length=1500,
  23.         return_tensors="pt",
  24.         return_length=True,
  25.     )
  26.     # 返回模型需要的各种输入
  27.     return (
  28.         data["input_ids"],
  29.         data["attention_mask"],
  30.         data["token_type_ids"],
  31.         torch.LongTensor(labels),
  32.     )
  33. # 创建数据集和DataLoader
  34. train_dataset = MyDataset(split="train")
  35. val_dateset = MyDataset(split="validation")
  36. train_loader = DataLoader(
  37.     train_dataset,
  38.     batch_size=32,
  39.     shuffle=True,
  40.     drop_last=True,
  41.     collate_fn=collate_fn,
  42. )
  43. val_loader = DataLoader(
  44.     val_dateset,
  45.     batch_size=32,
  46.     shuffle=False,
  47.     drop_last=True,
  48.     collate_fn=collate_fn,
  49. )
  50. # 训练主循环
  51. if __name__ == "__main__":
  52.     model = Model().to(device)
  53.     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
  54.     loss_func = torch.nn.CrossEntropyLoss()
  55.     for epoch in range(EPOCH):
  56.         model.train()
  57.         for step, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
  58.             # 数据移动到设备
  59.             input_ids = input_ids.to(device)
  60.             attention_mask = attention_mask.to(device)
  61.             token_type_ids = token_type_ids.to(device)
  62.             labels = labels.to(device)
  63.             # 前向传播和反向传播
  64.             outputs = model(input_ids, attention_mask, token_type_ids)
  65.             loss = loss_func(outputs, labels)
  66.             optimizer.zero_grad()
  67.             loss.backward()
  68.             optimizer.step()
  69.             # 打印训练信息
  70.             if step % 5 == 0:
  71.                 out = outputs.argmax(dim=1)
  72.                 acc = (out == labels).sum().item() / len(labels)
  73.                 print(f"Epoch: {epoch + 1}/{EPOCH}, Step: {step + 1}/{len(train_loader)}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")
  74.         # 保存模型
  75.         torch.save(model.state_dict(), f"./model/news_finetuning_epoch_{epoch}.pth")
  76.         print(f"epoch {epoch} 保存成功")
复制代码
关键点解释

模型修改部分


  • model.config.max_position_embeddings = 1500 - 修改配置中的最大位置嵌入数
  • 创建新的位置嵌入层时,我们保留了原始嵌入维度(embedding_dim),只扩展了位置数量
  • 权重初始化策略是复制原有512个位置的权重,剩余位置使用随机初始化
训练策略


  • 我们冻结了除位置嵌入外的所有BERT参数,只训练位置嵌入和新添加的分类头
  • 这种策略在长文本微调中很常见,可以防止过拟合
数据处理


  • 分词器也需要设置model_max_length以匹配新的序列长度
  • collate_fn函数确保所有输入都被填充/截断到1500的长度
总结

本文详细介绍了如何修改Hugging Face模型的max_position_embeddings参数,包括原理说明和完整代码实现。这种方法可以扩展到其他参数的修改,为定制化预训练模型提供了参考。关键点在于正确修改配置、替换相应层并合理初始化参数。

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册