跟着AI学AI - 诊断结论信息抽取 - 模型压缩与部署
在医疗场景中,电子病历里的诊断结论信息抽取是构建智能诊疗系统的关键环节。我们通常会用预训练大语言模型完成这项任务,但这类模型参数量大、计算资源消耗高,直接部署到实际业务系统中会面临响应慢、成本高的问题。因此模型压缩与部署是落地前必须完成的步骤。
一、诊断结论信息抽取任务简介
诊断结论信息抽取的目标是从非结构化的病历文本中,提取出疾病名称、诊断依据、严重程度、治疗方案等结构化信息。比如输入一段病历描述:“患者因反复头晕3天入院,查体血压160/95mmHg,诊断为原发性高血压2级,高危组,给予硝苯地平控释片口服治疗”,模型需要输出对应的结构化字段。
我们之前基于此任务微调的模型是基于BERT-base架构的,参数量约1.1亿,在测试集上的F1值达到0.91,但推理单条文本需要约120ms,部署到并发量较高的问诊系统中时,延迟和GPU资源占用都难以满足要求。
二、常用模型压缩方法
针对我们微调后的诊断结论抽取模型,常用的压缩方法有以下几种,我们可以根据实际需求组合使用:
- 知识蒸馏:用已经训练好的大模型(教师模型)指导小模型(学生模型)学习,让学生模型模拟教师模型的输出分布,在减小参数量的同时尽量保留效果。
- 量化:将模型参数从32位浮点数(FP32)转换为16位浮点数(FP16)或者8位整数(INT8),减少内存占用和计算量。
- 剪枝:移除模型中不重要的权重或者神经元,减少参数量和计算量,分为结构化剪枝和非结构化剪枝两类。
2.1 知识蒸馏实践
这里我们选择更小的BERT-tiny作为学生模型,它只有2层Transformer、隐藏层维度768,参数量仅为教师模型的1/10左右。蒸馏的核心思路是让学生模型不仅学习真实标签的交叉熵损失,还要学习教师模型的软标签损失。
下面是蒸馏过程的核心代码示例,使用PyTorch框架实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义蒸馏损失,结合硬标签损失和软标签损失
class DistillationLoss(nn.Module):
def __init__(self, temperature=4.0, alpha=0.5):
super().__init__()
self.temperature = temperature # 软标签的温度参数
self.alpha = alpha # 软标签损失的权重
self.hard_loss = nn.CrossEntropyLoss() # 硬标签损失,对应真实标注
self.soft_loss = nn.KLDivLoss(reduction="batchmean") # 软标签损失,对应教师模型输出
def forward(self, student_logits, teacher_logits, labels):
# 计算硬标签损失
loss_hard = self.hard_loss(student_logits, labels)
# 计算软标签损失,对教师和学生的logits做温度缩放后计算KL散度
student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
loss_soft = self.soft_loss(student_soft, teacher_soft) * (self.temperature ** 2)
# 总损失加权求和
total_loss = self.alpha * loss_soft + (1 - self.alpha) * loss_hard
return total_loss
# 蒸馏训练的单步逻辑示例
def train_step(student_model, teacher_model, batch, optimizer, criterion, device):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
# 教师模型推理,不计算梯度
with torch.no_grad():
teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
# 学生模型推理
student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
# 计算蒸馏损失
loss = criterion(student_logits, teacher_logits, labels)
# 反向传播更新参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()经过蒸馏训练后,学生模型在测试集上的F1值仅下降了0.03,达到0.88,而推理单条文本的时间降到了35ms左右,效果提升明显。
2.2 量化与剪枝实践
量化我们采用PyTorch自带的动态量化方法,不需要额外校准数据,直接对模型权重做INT8转换:
import torch
from transformers import BertForTokenClassification
# 加载蒸馏后的学生模型
student_model = BertForTokenClassification.from_pretrained("./distilled_bert_tiny_ner")
# 动态量化模型,将线性层权重转为INT8
quantized_model = torch.quantization.quantize_dynamic(
student_model,
{torch.nn.Linear}, # 对线性层做量化
dtype=torch.qint8
)
# 保存量化后的模型
torch.save(quantized_model.state_dict(), "./quantized_bert_tiny_ner.pth")量化后模型体积从400MB降到了100MB左右,推理速度进一步提升到20ms/条。如果还需要更极致的压缩,可以对模型做结构化剪枝,移除注意力头和全连接层的冗余神经元,不过剪枝后通常需要再做少量epoch的微调来恢复效果。
三、模型部署方案
压缩后的模型需要部署到服务中供业务调用,这里介绍两种常用的部署方式,我们可以根据实际场景选择。
3.1 基于FastAPI的轻量部署
如果并发量不高,用FastAPI构建HTTP服务是最简单的方式,代码结构清晰,容易维护。下面是部署服务的核心代码:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import BertTokenizer, BertForTokenClassification
from seqeval.metrics import classification_report
import numpy as np
# 定义请求体的格式
class DiagnosisInput(BaseModel):
text: str
# 初始化FastAPI应用
app = FastAPI(title="诊断结论信息抽取服务")
# 加载tokenizer和量化后的模型
tokenizer = BertTokenizer.from_pretrained("./bert_tiny_tokenizer")
model = BertForTokenClassification.from_pretrained("./distilled_bert_tiny_ner")
model.load_state_dict(torch.load("./quantized_bert_tiny_ner.pth", map_location="cpu"))
model.eval()
# 标签映射,根据实际训练的标签调整
label_map = {0: "O", 1: "B-疾病", 2: "I-疾病", 3: "B-严重程度", 4: "I-严重程度", 5: "B-治疗方案", 6: "I-治疗方案"}
# 定义推理接口
@app.post("/extract")
def extract_diagnosis(input_data: DiagnosisInput):
try:
# 文本编码
inputs = tokenizer(input_data.text, return_tensors="pt", truncation=True, padding=True, max_length=512)
# 模型推理
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predictions = torch.argmax(logits, dim=2).squeeze().numpy()
# 转换预测结果为标签
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze().numpy())
pred_labels = [label_map[p] for p in predictions]
# 整理抽取结果,这里简化为输出标签序列,实际可以根据需求解析为结构化字段
result = []
for token, label in zip(tokens, pred_labels):
if token not in ["[CLS]", "[SEP]", "[PAD]"]:
result.append({"token": token, "label": label})
return {"status": "success", "result": result}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)启动服务后,我们可以通过POST请求访问 http://127.0.0.1:8000/extract 接口,传入病历文本即可得到抽取结果,单接口响应时间稳定在30ms以内。
3.2 基于TensorRT的高性能部署
如果业务并发量很高,需要更低的延迟和更高的吞吐量,可以把模型转换为TensorRT格式。TensorRT是NVIDIA推出的推理加速引擎,支持FP16、INT8量化,还能做层融合、内核优化等操作,进一步提升推理速度。
转换步骤大致是:先把PyTorch模型转换为ONNX格式,再用TensorRT的ONNX解析器转换为引擎文件。核心转换代码如下:
import torch
from transformers import BertTokenizer, BertForTokenClassification
# 1. 导出为ONNX格式
model = BertForTokenClassification.from_pretrained("./distilled_bert_tiny_ner")
model.eval()
tokenizer = BertTokenizer.from_pretrained("./bert_tiny_tokenizer")
# 构造虚拟输入,匹配实际输入的shape
dummy_input = tokenizer("示例诊断文本", return_tensors="pt")
torch.onnx.export(
model,
(dummy_input["input_ids"], dummy_input["attention_mask"]),
"diagnosis_model.onnx",
input_names=["input_ids", "attention_mask"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "seq_len"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
"logits": {0: "batch_size", 1: "seq_len"}
},
opset_version=14
)
# 2. 转换为TensorRT引擎(需要安装TensorRT环境)
import tensorrt as trt
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
# 解析ONNX文件
with open("diagnosis_model.onnx", "rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
raise RuntimeError("ONNX解析失败")
# 配置构建参数,开启FP16推理
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB工作空间
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
# 构建引擎并保存
serialized_engine = builder.build_serialized_network(network, config)
with open("diagnosis_model.trt", "wb") as f:
f.write(serialized_engine)转换后的TensorRT引擎推理速度可以达到10ms/条以内,吞吐量相比原生PyTorch部署提升3倍以上,适合高并发的医疗业务场景。
四、效果验证与注意事项
部署完成后需要验证几个核心指标:
- 效果指标:在测试集上的F1值下降不超过0.05,满足业务精度要求。
- 性能指标:单请求延迟不超过50ms,单GPU卡每秒能处理的请求数(QPS)不低于200。
- 稳定性:连续运行24小时无内存泄漏、无请求超时问题。
另外需要注意,医疗场景下的模型部署需要符合相关监管要求,抽取的结果仅作为辅助参考,不能直接作为临床诊疗的最终依据,同时要做好患者数据的脱敏和隐私保护,避免敏感信息泄露。