
在自然语言处理任务中,尤其是命名实体识别(ner)或令牌分类等,处理长度远超模型最大输入序列(如512个token)的文档是一个常见挑战。为了解决这个问题,hugging face transformers库引入了“滑动窗口”(sliding window)策略,通过truncation、max_length、stride和return_overflowing_tokens等参数来实现。然而,这些参数的错误配置可能导致模型在处理长文本时出现预测中断或不完整的问题。
核心问题在于,许多用户尝试在AutoTokenizer.from_pretrained()方法中设置stride等参数,但这些参数并非用于加载分词器配置,而是用于实际执行分词操作时的运行时参数。
在深入探讨解决方案之前,我们首先明确几个关键参数的含义:
当直接使用AutoTokenizer对文本进行分词时,stride、max_length、truncation和return_overflowing_tokens等参数必须在分词器的__call__方法中传递,而不是在from_pretrained方法中。from_pretrained仅用于加载预训练分词器的配置和词汇表。
以下是一个示例,展示了如何正确应用这些参数:
from transformers import AutoModelForTokenClassification, AutoTokenizer
# 假设我们有一个预训练模型ID
model_id = 'Davlan/distilbert-base-multilingual-cased-ner-hrl'
# 加载分词器,这里不设置stride等参数
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 示例文本,模拟长文档
sample_text = "这是一个非常长的示例文本,我们需要使用滑动窗口技术来对其进行处理和分析。"*200
# 错误用法:在from_pretrained中设置的参数无效
# tokenizer_wrong = AutoTokenizer.from_pretrained(model_id, stride=3, return_overflowing_tokens=True, max_length=10, truncation=True)
# print(f"错误用法分词结果长度 (不应用滑动窗口): {len(tokenizer_wrong(sample_text).input_ids)}")
# 正确用法:在__call__方法中传递参数
tokenized_output = tokenizer(
sample_text,
max_length=10, # 示例中的短max_length,实际应用中通常为512
truncation=True,
stride=3, # 示例中的短stride,实际应用中通常为128
return_overflowing_tokens=True
)
print(f"正确用法分词结果长度 (应用滑动窗口): {len(tokenized_output.input_ids)}")
# 预期输出将是多个批次,因为滑动窗口被应用在上述代码中,tokenizer(sample_text, ...) 才是实际执行分词操作的地方,因此所有与分词行为相关的参数都应在此处传递。
对于Hugging Face pipeline,特别是token-classification管道,stride和其他相关参数可以直接在其构造函数中传递。pipeline会在内部处理这些参数,将其传递给其使用的分词器实例。
以下是使用pipeline进行长文本令牌分类的正确示例:
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
# 假设我们有一个预训练模型ID
model_id = 'Davlan/distilbert-base-multilingual-cased-ner-hrl'
# 正确用法:在pipeline构造函数中传递stride等参数
# pipeline会自动加载模型和分词器,并将stride等参数传递给分词器
ner_pipeline = pipeline(
"token-classification",
model=model_id,
stride=128, # 设置滑动窗口的步长
aggregation_strategy="first", # 对于重叠区域的实体,采用第一个预测结果
tokenizer=model_id # 可以显式指定tokenizer,或让pipeline自动加载
)
# 示例文本,模拟一个非常长的文档
long_sample_text = "Hi my name is cronoik and I live in Germany. "*3000
# 使用pipeline进行预测
predictions = ner_pipeline(long_sample_text)
print(f"预测结果数量: {len(predictions)}")
print("前5个预测结果:")
for i, pred in enumerate(predictions[:5]):
print(pred)在这个例子中,stride=128被直接传递给了pipeline的构造函数。pipeline会负责在内部调用分词器时应用这个stride参数,从而确保整个长文本都能被处理,并且实体识别不会在文本中途停止。aggregation_strategy="first"则指定了当多个重叠窗口都对同一区域的实体进行预测时,如何合并这些预测结果。
正确配置stride和相关参数是利用Transformer模型处理长文本的关键。通过将这些参数传递给tokenizer.__call__方法或pipeline的构造函数,开发者可以有效地实现滑动窗口机制,确保模型对整个长文档进行全面、准确的分析,避免预测中断的问题。理解这些参数的正确作用域和使用方式,是构建鲁棒长文本处理系统的基础。
以上就是Transformer模型处理长文本:stride参数的正确应用与实践的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号