代码地址:https://github.com/china-ai-law-challenge/CAIL2021/
/*
* @Author: Yue.Fan
* @Date: 2022-03-23 11:35:28
* @Last Modified by: Yue.Fan
* @Last Modified time: 2022-03-23 11:35:28
*/
import logging
from dataclasses import dataclass
from typing import List, Dict
import json
from tqdm import tqdm
from transformers import PreTrainedTokenizer, BasicTokenizer, BertTokenizer
from transformers.tokenization_utils import _is_whitespace, _is_punctuation, _is_control
import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset
# Placeholder vocabulary tokens prepended to every query
# (query_tokens = [YES_TOKEN, NO_TOKEN] + query_tokens), so the encoded input reads
# "[CLS] YES_TOKEN NO_TOKEN question [SEP] context [SEP]".
# Positions 1 and 2 are kept selectable in p_mask, giving the model slots to point at
# for yes/no questions. Assumes the tokenizer vocab contains these "[unusedN]" entries.
YES_TOKEN = "[unused1]"
NO_TOKEN = "[unused2]"
class CAILExample:
    r"""One CAIL question/context pair plus its character-to-token alignment.

    ``context_text`` is split into ``doc_tokens`` with ``customize_tokenizer``
    (CJK characters and punctuation become single tokens, whitespace is dropped).
    ``char_to_word_offset[c]`` is the index in ``doc_tokens`` of the token that
    contains character ``c`` of ``context_text``; whitespace characters map to the
    preceding token (-1 for leading whitespace). Example::

        context_text:        "我\n\t爱北京\n\t天安门"
        doc_tokens:          ['我', '爱', '北', '京', '天', '安', '门']
        char_to_word_offset: [0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6]

    (the two extra 0s between 0 and 1 are the two whitespace characters between
    '我' and '爱').

    When ``answer_texts`` is not None (training data), ``start_positions`` and
    ``end_positions`` hold, per answer span, the first/last index into
    ``doc_tokens``; unanswerable and yes/no questions get ``[-1]``.
    """

    def __init__(self,
                 qas_id: str,
                 question_text: str,
                 context_text: str,
                 answer_texts: List[str],
                 answer_start_indexes: List[int],
                 is_impossible: bool,
                 is_yes_no: bool,
                 is_multi_span: bool,
                 answers: List,
                 case_id: str,
                 case_name: str):
        self.qas_id = qas_id  # unique id of the question
        self.question_text = question_text  # question text
        self.context_text = context_text  # context (case document) text
        self.answer_texts = answer_texts  # answer span texts (None at inference)
        self.answer_start_indexes = answer_start_indexes  # annotated char offsets of the spans
        self.is_impossible = is_impossible  # True if the question has no answer
        self.is_yes_no = is_yes_no  # True if the answer is "YES" / "NO"
        self.is_multi_span = is_multi_span  # True if the answer consists of several spans
        self.answers = answers  # raw, unprocessed answer list from the JSON
        self.case_id = case_id  # unique id of the case the context belongs to
        self.case_name = case_name  # case type
        self.doc_tokens = []
        self.char_to_word_offset = []

        # Coarse tokens; they contain no whitespace, so whitespace in the raw text
        # would otherwise shift answer offsets -- hence char_to_word_offset.
        raw_doc_tokens = customize_tokenizer(context_text, True)
        k = 0
        temp_word = ""
        for char in self.context_text:
            if _is_whitespace(char):
                # Whitespace belongs to the previous token.
                self.char_to_word_offset.append(k - 1)
                continue
            temp_word += char
            self.char_to_word_offset.append(k)
            if temp_word.lower() == raw_doc_tokens[k]:
                # Accumulated characters now spell token k (raw tokens are lowercased).
                self.doc_tokens.append(temp_word)
                temp_word = ""
                k += 1
        assert k == len(raw_doc_tokens)

        if answer_texts is not None:  # training data
            start_positions = []
            end_positions = []
            if not is_impossible and not is_yes_no:
                for i, answer_text in enumerate(answer_texts):
                    answer_offset = self._locate_answer(i, answer_text)
                    answer_length = len(answer_text)
                    # Convert character offsets into doc_tokens indices.
                    start_positions.append(self.char_to_word_offset[answer_offset])
                    end_positions.append(self.char_to_word_offset[answer_offset + answer_length - 1])
            else:
                # No extractive span: mark with -1.
                start_positions.append(-1)
                end_positions.append(-1)
            self.start_positions = start_positions
            self.end_positions = end_positions

    def _locate_answer(self, i: int, answer_text: str) -> int:
        """Return the character offset of answer span ``i`` in ``context_text``.

        The annotated ``answer_start_indexes[i]`` is trusted only when the context
        really contains ``answer_text`` at that offset. This keeps answers that
        occur several times anchored where the annotator put them. Otherwise the
        first occurrence is used (``str.index``; raises ValueError if absent).
        """
        starts = self.answer_start_indexes
        if answer_text and starts is not None and i < len(starts):
            start = starts[i]
            if (isinstance(start, int) and start >= 0
                    and self.context_text[start:start + len(answer_text)] == answer_text):
                return start
        return self.context_text.index(answer_text)

    def __repr__(self):
        # One "key: value" line per attribute, in insertion order.
        return "".join(f"{key}: {value}\n" for key, value in self.__dict__.items())
@dataclass
class CAILFeature:
    """One model-input window derived from a CAILExample.

    Presumably constructed in convert_single_example_to_features (the constructor
    call is not visible in this chunk). ``example_index`` and ``unique_id`` are
    overwritten afterwards by convert_examples_to_features.
    """
    # presumably encode_plus ids of "[CLS] yes no query [SEP] context [SEP]", padded to max_seq_length -- confirm
    input_ids: List[int]
    attention_mask: List[int]
    # segment ids; the window code forces positions 1 and 2 (yes/no tokens) to 1
    token_type_ids: List[int]
    cls_index: int  # index of the [CLS] token in input_ids
    p_mask: List  # 1 = token cannot be part of an answer, 0 = it can
    example_index: int  # index of the source example in the example list
    unique_id: int  # running feature id across all examples
    paragraph_len: int  # number of context tokens in this window
    # {input index: bool}: True if this window is the "max context" window for that token
    token_is_max_context: object
    tokens: List  # tokens of the non-padded input_ids
    # input index -> index into CAILExample.doc_tokens (-1 for [CLS] and the yes/no tokens)
    token_to_orig_map: Dict
    # training labels; NOTE(review): appear to be per-position 0/1 vectors of length max_seq_length -- confirm
    start_positions: List[int]
    end_positions: List[int]
    is_impossible: bool  # training only; converted to a float tensor for the dataset
@dataclass
class CAILResult:
    """Model output for one CAILFeature.

    NOTE(review): consumers are not visible in this chunk; presumably the logits
    cover every input position of the feature -- confirm.
    """
    unique_id: int  # presumably the CAILFeature.unique_id this result belongs to
    start_logits: torch.Tensor  # answer-start scores
    end_logits: torch.Tensor  # answer-end scores
def read_examples(file: str, is_training: bool) -> List[CAILExample]:
    """Read a CAIL JSON file into a list of CAILExample objects.

    Layout read below: ``{"data": [{"caseid", "paragraphs": [{"context",
    "casename", "qas": [{"question", "id", "answers"}]}]}]}``.

    When ``is_training`` is True the ``answers`` field is normalised:
      * ``[]`` -> unanswerable (text ``""``, start ``-1``);
      * otherwise only ``answers[0]`` is used; a dict (single span, older data
        format) is wrapped into a one-element list, a list is taken as-is;
      * more than one span -> multi-span;
      * one span with start -1 and text "YES"/"NO" -> yes/no question.
    Training examples whose annotated span text cannot be found in the aligned
    ``doc_tokens`` are logged and dropped.

    Args:
        file: path to the JSON file.
        is_training: whether to parse the answers.

    Returns:
        The examples in file order.
    """
    example_list = []
    with open(file, "r", encoding="utf-8") as json_file:
        original_data = json.load(json_file)["data"]
    for entry in tqdm(original_data):
        case_id = entry["caseid"]
        for paragraph in entry["paragraphs"]:
            context = paragraph["context"]
            case_name = paragraph["casename"]
            for qa in paragraph["qas"]:
                question = qa["question"]
                qas_id = qa["id"]
                answer_texts = None
                answer_starts = None
                is_impossible = None
                is_yes_no = None
                is_multi_span = None
                all_answers = None
                # CAIL2021 has single-span, yes/no and unanswerable questions and
                # additionally multi-span answers composed of several fragments.
                if is_training:
                    all_answers = qa["answers"]
                    # An empty list means "no answer"; otherwise take the first entry.
                    answer = all_answers[0] if len(all_answers) > 0 else []
                    # One data format stores a single span as a dict: wrap it in a list.
                    if isinstance(answer, dict):
                        answer = [answer]
                    if len(answer) == 0:  # no answer
                        answer_texts = [""]
                        answer_starts = [-1]
                    else:
                        answer_texts = [a["text"] for a in answer]
                        answer_starts = [a["answer_start"] for a in answer]
                    is_yes_no = (len(answer_texts) == 1 and answer_starts[0] == -1
                                 and answer_texts[0] in ("YES", "NO"))
                    is_multi_span = len(answer_texts) > 1
                    is_impossible = len(answer_texts) == 1 and answer_texts[0] == ""
                example = CAILExample(
                    qas_id=qas_id,
                    question_text=question,
                    context_text=context,
                    answer_texts=answer_texts,
                    answer_start_indexes=answer_starts,
                    is_impossible=is_impossible,
                    is_yes_no=is_yes_no,
                    is_multi_span=is_multi_span,
                    answers=all_answers,
                    case_id=case_id,
                    case_name=case_name
                )
                # Discard possible bad examples: skip the WHOLE example if any span
                # cannot be recovered (a `continue` inside the per-span loop would
                # only skip to the next span, never the example).
                if (is_training and example.answer_start_indexes[0] >= 0
                        and not _answer_spans_match(example)):
                    continue
                example_list.append(example)
    return example_list


def _answer_spans_match(example) -> bool:
    """Return True if every answer span of a training example matches its doc tokens; log each mismatch."""
    all_match = True
    for i, answer_text in enumerate(example.answer_texts):
        actual_text = "".join(
            example.doc_tokens[example.start_positions[i]: (example.end_positions[i] + 1)])
        # Same as "".join(whitespace_tokenize(answer_text)): strip and drop all whitespace.
        cleaned_answer_text = "".join(answer_text.split())
        if actual_text.find(cleaned_answer_text) == -1:
            logging.info(f"Could not find answer: {actual_text} vs. {cleaned_answer_text}")
            all_match = False
    return all_match
def convert_examples_to_features(example_list: List[CAILExample], tokenizer: PreTrainedTokenizer, args,
                                 is_training: bool) -> List[CAILFeature]:
    """Convert every example into one or more sliding-window features.

    Args:
        example_list: examples to convert; their ``qas_id`` values must be unique.
        tokenizer: tokenizer forwarded to convert_single_example_to_features.
        args: object providing ``max_seq_length``, ``max_query_length`` and
            ``doc_stride`` (e.g. the argparse namespace of the processing script).
        is_training: whether to build answer labels.

    Returns:
        All features, in example order. Each feature's ``example_index`` is the
        index of its example in ``example_list``. ``unique_id`` is a running
        counter over all features.

    Raises:
        ValueError: if two examples share a ``qas_id`` (subclass of Exception,
            so existing ``except Exception`` handlers still apply).
    """
    seen_ids = set()
    for example in example_list:
        if example.qas_id in seen_ids:
            raise ValueError("Duplicate qas_id!")
        seen_ids.add(example.qas_id)

    feature_list = []
    unique_id = 0
    for example_index, example in enumerate(tqdm(example_list)):
        current_example_features = convert_single_example_to_features(
            example, tokenizer, args.max_seq_length, args.max_query_length, args.doc_stride, is_training)
        for feature in current_example_features:
            feature.example_index = example_index
            feature.unique_id = unique_id
            unique_id += 1
        feature_list.extend(current_example_features)
    return feature_list
def convert_single_example_to_features(example: CAILExample, tokenizer: PreTrainedTokenizer,
max_seq_length, max_query_length, doc_stride, is_training) -> List[CAILFeature]:
"""
Transfer original text to sequence which can be accepted by ELECTRA
Format: [CLS] YES_TOKEN NO_TOKEN question [SEP] context [SEP]
"""
features = []
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []
"""
['我', '爱', '北', '京', '15826458891', '天', '安', '门']
orig_to_tok_index:[0, 1, 2, 3, 4, 9, 10, 11]
tok_to_orig_index:[0, 1, 2, 3, 4, 4, 4, 4, 4, 5, 6, 7]
all_doc_tokens:['我', '爱', '北', '京', '158', '##26', '##45', '##88', '##91', '天', '安', '门']
"""
for (i, token) in enumerate(example.doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
sub_tokens = tokenizer.tokenize(token) # 这里进一步对token尽可能进行切分
for sub_token in sub_tokens:
tok_to_orig_index.append(i) # 每一个sub_token对应的i是相同的
all_doc_tokens.append(sub_token)
if is_training:
if example.is_impossible or example.answer_start_indexes[0] == -1:
start_positions = [-1]
end_positions = [-1]
else:
start_positions = []
end_positions = []
# 以下是对tokenize化之后校准position
for i in range(len(example.start_positions)):
start_position = orig_to_tok_index[example.start_positions[i]]
if example.end_positions[i] < len(example.doc_tokens) - 1:
end_position = orig_to_tok_index[example.end_positions[i] + 1] - 1
else:
end_position = len(all_doc_tokens) - 1
(start_position, end_position) = _improve_answer_span(
all_doc_tokens, start_position, end_position, tokenizer, example.answer_texts[i]
)
start_positions.append(start_position)
end_positions.append(end_position)
else:
start_positions = None
end_positions = None
query_tokens = tokenizer.tokenize(example.question_text)
query_tokens = [YES_TOKEN, NO_TOKEN] + query_tokens # 是否类和问题拼接
truncated_query = tokenizer.encode(query_tokens, add_special_tokens=False, max_length=max_query_length,
truncation=True)
sequence_pair_added_tokens = tokenizer.num_special_tokens_to_add(pair=True)
assert sequence_pair_added_tokens == 3
added_tokens_num_before_second_sequence = tokenizer.num_special_tokens_to_add(pair=False)
assert added_tokens_num_before_second_sequence == 2
span_doc_tokens = all_doc_tokens
spans = []
# print("query_tokens:", query_tokens)
# print("all_doc_tokens:", all_doc_tokens)
# print("".join(all_doc_tokens))
# print("start_positions:", start_positions)
# print("end_positions:", end_positions)
# 这里使用滑动窗口法
while len(spans) * doc_stride < len(all_doc_tokens):
# print(max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,)
# 以步长为doc_stride进行滑窗
encoded_dict = tokenizer.encode_plus(
truncated_query,
span_doc_tokens,
max_length=max_seq_length,
return_overflowing_tokens=True,
padding="max_length",
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
truncation="only_second",
return_token_type_ids=True
)
# print(span_doc_tokens)
# print("stride:", max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens)
# print(tokenizer.convert_ids_to_tokens(encoded_dict['input_ids']))
# print(len(encoded_dict['input_ids']))
# print(tokenizer.convert_ids_to_tokens(encoded_dict['overflowing_tokens']))
# 句子的真实长度
paragraph_len = min(
len(all_doc_tokens) - len(spans) * doc_stride,
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
)
# 不包含[PAD]的token_ids
if tokenizer.pad_token_id in encoded_dict["input_ids"]:
non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
else:
non_padded_ids = encoded_dict["input_ids"]
# 重新将ids转换为tokens
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
token_to_orig_map = {}
token_to_orig_map[0] = -1
token_to_orig_map[1] = -1
token_to_orig_map[2] = -1
token_is_max_context = {0: True, 1: True, 2: True}
for i in range(paragraph_len):
# token在输入[CLS]query[SEP]context[SEP]里面的索引
index = len(truncated_query) + added_tokens_num_before_second_sequence + i
# tok_to_orig_index是token在context里面的索引
# spans的长度表明当前总共有几个片段
# token_to_orig_map是将index映射到真实的i上
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
# print(token_to_orig_map)
encoded_dict["paragraph_len"] = paragraph_len
encoded_dict["tokens"] = tokens
encoded_dict["token_to_orig_map"] = token_to_orig_map
encoded_dict["truncated_query_with_special_tokens_length"] = len(
truncated_query) + added_tokens_num_before_second_sequence
encoded_dict["token_is_max_context"] = token_is_max_context
encoded_dict["start"] = len(spans) * doc_stride # 文本的起始索引
encoded_dict["length"] = paragraph_len
# 这里将是否类的标记token_type_ids设置为1,为什么?
encoded_dict["token_type_ids"][1] = 1
encoded_dict["token_type_ids"][2] = 1
# print(encoded_dict["token_type_ids"])
spans.append(encoded_dict)
if "overflowing_tokens" not in encoded_dict or len(encoded_dict["overflowing_tokens"]) == 0:
break
else:
span_doc_tokens = encoded_dict["overflowing_tokens"]
for doc_span_index in range(len(spans)):
for j in range(spans[doc_span_index]["paragraph_len"]):
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
for span in spans:
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
# p_mask是将问题和SEP对应位置设置为1,其余位置设置为0
p_mask = np.array(span["token_type_ids"])
p_mask = np.minimum(p_mask, 1)
p_mask = 1 - p_mask
p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
p_mask[cls_index] = 0
p_mask[1] = 0
p_mask[2] = 0
current_start_positions = None
current_end_positions = None
span_is_impossible = None
if is_training:
current_start_positions = [0 for i in range(max_seq_length)]
current_end_positions = [0 for i in range(max_seq_length)]
doc_start = span["start"]
doc_end = span["start"] + span["length"] - 1 # 文本的截止索引
doc_offset = len(truncated_query) + added_tokens_num_before_second_sequence # 偏移量
for i in range(len(start_positions)):
start_position = start_positions[i]
end_position = end_positions[i]
# 这里重新整合start_position和end_position
if start_position >= doc_start and end_position Dataset:
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
all_example_indexes = torch.tensor([f.example_index for f in features], dtype=torch.long)
all_feature_indexes = torch.arange(all_input_ids.size(0), dtype=torch.long)
if is_training:
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
all_start_labels = torch.tensor([f.start_positions for f in features], dtype=torch.float)
all_end_labels = torch.tensor([f.end_positions for f in features], dtype=torch.float)
dataset = TensorDataset(
all_input_ids,
all_attention_masks,
all_token_type_ids,
all_start_labels,
all_end_labels,
all_cls_index,
all_p_mask,
all_is_impossible,
all_example_indexes,
all_feature_indexes
)
else:
dataset = TensorDataset(
all_input_ids,
all_attention_masks,
all_token_type_ids,
all_cls_index,
all_p_mask,
all_example_indexes,
all_feature_indexes
)
return dataset
def _is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
return True
return False
def _new_check_is_max_context(doc_spans, cur_span_index, position):
"""
Check if this is the 'max context' doc span for the token.
"""
# if len(doc_spans) == 1:
# return True
best_score = None
best_span_index = None
for (span_index, doc_span) in enumerate(doc_spans):
end = doc_span["start"] + doc_span["length"] - 1
if position < doc_span["start"]:
continue
if position > end:
continue
num_left_context = position - doc_span["start"]
num_right_context = end - position
score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
if best_score is None or score > best_score:
best_score = score
best_span_index = span_index
return cur_span_index == best_span_index
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
"""
Returns tokenized answer spans that better match the annotated answer.
"""
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
if text_span == tok_answer_text:
return (new_start, new_end)
return (input_start, input_end)
def customize_tokenizer(text: str, do_lower_case=True) -> List[str]:
    """Split ``text`` into coarse tokens.

    CJK characters, punctuation, whitespace and control characters are surrounded
    by spaces, so each becomes its own token. Runs of other characters (e.g. Latin
    letters, digits) stay together. The final ``split()`` uses whitespace as the
    separator, so whitespace never appears in the result. The text is lowercased
    first when ``do_lower_case`` is True.
    """
    pieces = []
    for ch in text:
        isolate = (_is_chinese_char(ord(ch)) or _is_punctuation(ch)
                   or _is_whitespace(ch) or _is_control(ch))
        pieces.append(" " + ch + " " if isolate else ch)
    spaced = "".join(pieces)
    if do_lower_case:
        spaced = spaced.lower()
    return spaced.split()
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(cp >= 0x4E00 and cp = 0x3400 and cp = 0x20000 and cp = 0x2A700 and cp = 0x2B740 and cp = 0x2B820 and cp = 0xF900 and cp = 0x2F800 and cp
需要注意的是,数据总共经过了三个重整(校准)阶段:
- 第一阶段:先初步将文本进行token化,这一步会去除文本中的空格、换行等空白字符(中文字符和标点则各自成为一个token),因此要对答案的起始位置进行校准。
- 第二阶段:这一步利用tokenizer对每一个字(词)进行token化,由于是wordpiece,会影响句子的长度以及答案,因此也要重新进行校准。
- 第三阶段:这一步是要整合问题和文本,同时采用滑动窗口法,因此也要重新校准答案在文本中的位置。
<input_ids: 0 4 7 487 [101, 1, 2, 7342, 12124, 1762, 2398, 2128, 6568, 7372, 2551, 1146, 1062, 1385, 1905, 2832, 749, 784, 720, 924, 8043, 102, 9595, 119, 125, 1039, 132, 124, 1161, 808, 510, 3342, 10871, 3118, 802, 2526, 2360, 6589, 11960, 8129, 2824, 2857, 3315, 3428, 4638, 6401, 6390, 4500, 752, 2141, 680, 4415, 4507, 131, 8138, 2399, 8110, 3299, 3189, 117, 6206, 3724, 711, 1071, 704, 1744, 1093, 689, 7213, 6121, 5500, 819, 3300, 7361, 4689, 5852, 6956, 113, 809, 678, 5042, 4917, 114, 8416, 9086, 6587, 3621, 2990, 897, 702, 782, 3867, 928, 6395, 2970, 1358, 2400, 5041, 1355, 1296, 5356, 1384, 8752, 9723, 9131, 8756, 11906, 9446, 8311, 8152, 5373, 5287, 3175, 2466, 3309, 3680, 9649, 8158, 7313, 5632, 1394, 1398, 7555, 3123, 722, 6629, 5635, 3926, 985, 1059, 2622, 3632, 6421, 5276, 2137, 2870, 3612, 818, 862, 671, 6809, 1168, 8188, 1921, 6228, 3125, 4495, 898, 2945, 2190, 6822, 6608, 794, 2496, 2458, 1993, 6631, 6814, 8114, 793, 3313, 1403, 2495, 6820, 1156, 6824, 7444, 2213, 1825, 3144, 6369, 5050, 2902, 1283, 3403, 1114, 7032, 1400, 5815, 2533, 2207, 7583, 1066, 6854, 2382, 3833, 955, 8216, 5023, 1728, 2130, 2252, 721, 1218, 754, 8119, 123, 8132, 4509, 6435, 5164, 102] attention_mask: [1, 1] token_type_ids: [0, 0, cls_index: p_mask: example_index: unique_id: paragraph_len: token_is_max_context: {0: true, 1: 2: 24: false, 25: 26: 27: 28: 29: 30: 31: 32: 33: 34: 35: 36: 37: 38: 39: 40: 41: 42: 43: 44: 45: 46: 47: 48: 49: 50: 51: 52: 53: 54: 55: 56: 57: 58: 59: 60: 61: 62: 63: 64: 65: 66: 67: 68: 69: 70: 71: 72: 73: 74: 75: 76: 77: 78: 79: 80: 81: 82: 83: 84: 85: 86: 87: 88: 89: 90: 91: 92: 93: 94: 95: 96: 97: 98: 99: 100: 101: 102: 103: 104: 105: 106: 107: 108: 109: 110: 111: 112: 113: 114: 115: 116: 117: 118: 119: 120: 121: 122: 123: 124: 125: 126: 127: 128: 129: 130: 131: 132: 133: 134: 135: 136: 137: 138: 139: 140: 141: 142: 143: 144: 145: 146: 147: 148: 149: 150: 151: 152: 153: 154: 155: 156: 157: 158: 159: 160: 161: 162: 163: 164: 165: 166: 167: 168: 169: 
170: 171: 172: 173: 174: 175: 176: 177: 178: 179: 180: 181: 182: 183: 184: 185: 186: 187: 188: 189: 190: 191: 192: 193: 194: 195: 196: 197: 198: 199: 200: 201: 202: 203: 204: 205: 206: 207: 208: 209: 210: 211: 212: 213: 214: 215: 216: 217: 218: 219: 220: 221: 222: 223: 224: 225: 226: 227: 228: 229: 230: 231: 232: 233: 234: 235: 236: 237: 238: 239: 240: 241: 242: 243: 244: 245: 246: 247: 248: 249: 250: 251: 252: 253: 254: 255: 256: 257: 258: 259: 260: 261: 262: 263: 264: 265: 266: 267: 268: 269: 270: 271: 272: 273: 274: 275: 276: 277: 278: 279: 280: 281: 282: 283: 284: 285: 286: 287: 288: 289: 290: 291: 292: 293: 294: 295: 296: 297: 298: 299: 300: 301: 302: 303: 304: 305: 306: 307: 308: 309: 310: 311: 312: 313: 314: 315: 316: 317: 318: 319: 320: 321: 322: 323: 324: 325: 326: 327: 328: 329: 330: 331: 332: 333: 334: 335: 336: 337: 338: 339: 340: 341: 342: 343: 344: 345: 346: 347: 348: 349: 350: 351: 352: 353: 354: 355: 356: 357: 358: 359: 360: 361: 362: 363: 364: 365: 366: 367: 368: 369: 370: 371: 372: 373: 374: 375: 376: 377: 378: 379: 380: 381: 382: 383: 384: 385: 386: 387: 388: 389: 390: 391: 392: 393: 394: 395: 396: 397: 398: 399: 400: 401: 402: 403: 404: 405: 406: 407: 408: 409: 410: 411: 412: 413: 414: 415: 416: 417: 418: 419: 420: 421: 422: 423: 424: 425: 426: 427: 428: 429: 430: 431: 432: 433: 434: 435: 436: 437: 438: 439: 440: 441: 442: 443: 444: 445: 446: 447: 448: 449: 450: 451: 452: 453: 454: 455: 456: 457: 458: 459: 460: 461: 462: 463: 464: 465: 466: 467: 468: 469: 470: 471: 472: 473: 474: 475: 476: 477: 478: 479: 480: 481: 482: 483: 484: 485: 486: 487: 488: 489: 490: 491: 492: 493: 494: 495: 496: 497: 498: 499: 500: 501: 502: 503: 504: 505: 506: 507: 508: 509: 510: false} tokens: ['[cls]', '[unused1]', '[unused2]', '阮', 'x4', '在', '平', '安', '财', '险', '徽', '分', '公', '司', '处', '投', '了', '什', '么', '保', '?', '[sep]', '##92', '.', '4', '元', ';', '3', '判', '令', '、', '杨', 'x5', '支', '付', '律', '师', '费', '690', '##0', '承', '担', '本', '案', '的', '诉', '讼', '用', '事', 
'实', '与', '理', '由', ':', '2013', '年', '12', '月', '日', ',', '要', '求', '为', '其', '中', '国', '农', '业', '银', '行', '股', '份', '有', '限', '省', '营', '部', '(', '以', '下', '简', '称', ')', '94', '##000', '贷', '款', '提', '供', '个', '人', '消', '信', '证', '接', '受', '并', '签', '发', '单', '编', '号', '125', '##94', '##07', '##26', '##010', '##87', '##10', '##3', '缴', '纳', '方', '式', '期', '每', '178', '##6', '间', '自', '合', '同', '项', '放', '之', '起', '至', '清', '偿', '全', '息', '止', '该', '约', '定', '拖', '欠', '任', '何', '一', '达', '到', '80', '天', '视', '故', '生', '依', '据', '对', '进', '赔', '从', '当', '开', '始', '超', '过', '30', '仍', '未', '向', '归', '还', '则', '违', '需', '尚', '基', '数', '计', '算', '按', '千', '标', '准', '金', '后', '获', '得', '小', '额', '共', '途', '常', '活', '借', '36', '等', '因', '完', '履', '义', '务', '于', '2015', '2', '25', '申', '请', '索', '[sep]'] token_to_orig_map: -1, 126, 127, 128, 129, 130, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 
361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598} start_positions: 0] end_positions: is_impossible: false>
</input_ids:>
最后需要注意的是:如何对是否类、单片段类和多片段类的答案进行统一标识,以及输入中的token_type_ids是如何设置的。
最后还有一个处理脚本,用于将处理后的数据保存下来,避免每次都重复处理而耗费时间:
"""
数据处理相关代码
"""
import argparse
import json
from transformers import PreTrainedTokenizer, BertTokenizer
from data_process_utils import *
import gzip
import pickle
import os
from os.path import join
import logging
def convert_and_write(args, tokenizer: PreTrainedTokenizer, file, examples_fn, features_fn, is_training):
    """Read ``file``, convert it to examples and features, and pickle both.

    Each list is written gzip-compressed into ``args.output_path``, under the
    file names ``examples_fn`` and ``features_fn`` respectively.
    """
    def dump(obj, filename):
        # Pickle obj into a gzip file inside the output directory.
        with gzip.open(join(args.output_path, filename), "wb") as handle:
            pickle.dump(obj, handle)

    logging.info(f"Reading examples from :{file} ...")
    examples = read_examples(file, is_training=is_training)
    logging.info(f"Total examples:{len(examples)}")

    logging.info(f"Start converting examples to features.")
    features = convert_examples_to_features(examples, tokenizer, args, is_training)
    logging.info(f"Total features:{len(features)}")

    logging.info(f"Converting complete, writing examples and features to file.")
    for payload, filename in ((examples, examples_fn), (features, features_fn)):
        dump(payload, filename)
def main():
    """Command-line entry point of the preprocessing script.

    Parses the options below, loads a BertTokenizer from ``--tokenizer_path``,
    creates ``--output_path`` if it does not exist, and writes
    ``<output_prefix>_examples.pkl.gz`` and ``<output_prefix>_features.pkl.gz``.

    NOTE(review): ``--do_lower_case`` is parsed but not passed to the tokenizer
    here; the context text is always lowercased by customize_tokenizer.
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ("--input_file", {"type": str, "required": True,
                          "help": "The file to be processed."}),
        ("--for_training", {"action": "store_true",
                            "help": "Process for training or not."}),
        ("--output_prefix", {"type": str, "required": True,
                             "help": "The prefix of output file's name."}),
        ("--do_lower_case", {"action": "store_true",
                             "help": "Set this flag if you are using an uncased model."}),
        ("--tokenizer_path", {"type": str, "required": True,
                              "help": "Path to tokenizer which will be used to tokenize text.(ElectraTokenizer)"}),
        ("--max_seq_length", {"default": 512, "type": int,
                              "help": "The maximum total input sequence length after WordPiece tokenization. "
                                      "Longer will be truncated, and shorter will be padded."}),
        ("--max_query_length", {"default": 64, "type": int,
                                "help": "The maximum number of tokens for the question. "
                                        "Questions longer will be truncated to the length."}),
        ("--doc_stride", {"default": 128, "type": int,
                          "help": "When splitting up a long document into chunks, "
                                  "how much stride to take between chunks."}),
        ("--output_path", {"default": "./processed_data/", "type": str,
                           "help": "Output path of the constructed examples and features."}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Reserve two query positions for the yes/no placeholder tokens.
    args.max_query_length += 2

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    logging.info("All input parameters:")
    print(json.dumps(vars(args), sort_keys=False, indent=2))

    tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path)
    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    examples_name = args.output_prefix + "_examples.pkl.gz"
    features_name = args.output_prefix + "_features.pkl.gz"
    convert_and_write(args, tokenizer, args.input_file, examples_name, features_name, args.for_training)


if __name__ == "__main__":
    main()
运行指令:
python data_process.py --input_file data_sample/cail2021_mrc_small.json --output_prefix cail2021_mrc_small --tokenizer_path model_hub/chinese-bert-wwm-ext --max_seq_length 512 --max_query_length 64 --doc_stride 128 --do_lower_case --for_training
Original: https://www.cnblogs.com/xiximayou/p/16359133.html
Author: 西西嘛呦
Title: CAIL2021-阅读理解任务-数据预处理模块(二)
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/554114/
转载文章受原作者版权保护。转载请注明原作者出处!