CAIL2021 Reading Comprehension Task: Data Preprocessing Module (Part 2)

Code repository: https://github.com/china-ai-law-challenge/CAIL2021/

# @Author: Yue.Fan
# @Date: 2022-03-23 11:35:28
# @Last Modified by: Yue.Fan
# @Last Modified time: 2022-03-23 11:35:28
import logging
from dataclasses import dataclass
from typing import List, Dict
import json
from tqdm import tqdm
from transformers import PreTrainedTokenizer, BasicTokenizer, BertTokenizer
from transformers.tokenization_utils import _is_whitespace, _is_punctuation, _is_control
from transformers.models.bert.tokenization_bert import whitespace_tokenize  # used to clean answer text in read_examples()
import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset

YES_TOKEN = "[unused1]"
NO_TOKEN = "[unused2]"

class CAILExample:
    def __init__(self,
                 qas_id: str,
                 question_text: str,
                 context_text: str,
                 answer_texts: List[str],
                 answer_start_indexes: List[int],
                 is_impossible: bool,
                 is_yes_no: bool,
                 is_multi_span: bool,
                 answers: List,
                 case_id: str,
                 case_name: str):
        self.qas_id = qas_id  # unique id of each question
        self.question_text = question_text  # question text
        self.context_text = context_text  # context text
        self.answer_texts = answer_texts  # list of answer texts
        self.answer_start_indexes = answer_start_indexes  # list of answer start positions
        self.is_impossible = is_impossible  # whether the question has no answer
        self.is_yes_no = is_yes_no  # whether it is a yes/no question
        self.is_multi_span = is_multi_span  # whether the answer consists of multiple spans
        self.answers = answers  # raw, unprocessed answer list
        self.case_id = case_id  # unique case id of each context
        self.case_name = case_name  # case type

        self.doc_tokens = []
        self.char_to_word_offset = []

        raw_doc_tokens = customize_tokenizer(context_text, True)  # first-pass tokenization of the context
        k = 0
        temp_word = ""
        # Some contexts contain spaces, newlines, etc. BERT tokenization drops them, which would
        # shift the answer offsets; char_to_word_offset records the mapping. For example:
        #   context:             我\n\t爱北京\n\t天安门
        #   doc_tokens:          ['我', '爱', '北', '京', '天', '安', '门']
        #   char_to_word_offset: [0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6]
        # The two extra 0s between 0 and 1 mean there are two whitespace characters between 我 and 爱.
        for char in self.context_text:
            if _is_whitespace(char):
                self.char_to_word_offset.append(k - 1)
                continue
            else:
                temp_word += char
                self.char_to_word_offset.append(k)
            if temp_word.lower() == raw_doc_tokens[k]:
                self.doc_tokens.append(temp_word)
                temp_word = ""
                k += 1
        assert k == len(raw_doc_tokens)

        if answer_texts is not None:  # if for training
            start_positions = []
            end_positions = []

            if not is_impossible and not is_yes_no:
                for i in range(len(answer_texts)):
                    # Using the example above: the answer 北京 starts at position 4 in the raw context.
                    # Questionable: index() finds the first occurrence, not necessarily the annotated one.
                    answer_offset = context_text.index(answer_texts[i])
                    # answer_offset = answer_start_indexes[i]
                    answer_length = len(answer_texts[i])
                    start_position = self.char_to_word_offset[answer_offset]  # position within doc_tokens
                    end_position = self.char_to_word_offset[answer_offset + answer_length - 1]
                    start_positions.append(start_position)  # actual start position
                    end_positions.append(end_position)  # actual end position
            else:
                start_positions.append(-1)  # no answer: use -1
                end_positions.append(-1)  # no answer: use -1
            self.start_positions = start_positions
            self.end_positions = end_positions

    def __repr__(self):
        string = ""
        for key, value in self.__dict__.items():
            string += f"{key}: {value}\n"
        # return f""
        return string
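
# A quick illustration of the offset bookkeeping above (question/case fields are made up):
#
#   ex = CAILExample(qas_id="q1", question_text="我爱哪里?",
#                    context_text="我\n\t爱北京\n\t天安门",
#                    answer_texts=["北京"], answer_start_indexes=[4],
#                    is_impossible=False, is_yes_no=False, is_multi_span=False,
#                    answers=[{"text": "北京", "answer_start": 4}],
#                    case_id="case_0", case_name="借款合同纠纷")
#   ex.doc_tokens          -> ['我', '爱', '北', '京', '天', '安', '门']
#   ex.char_to_word_offset -> [0, 0, 0, 1, 2, 3, 3, 3, 4, 5, 6]
#   ex.start_positions     -> [2]      ex.end_positions -> [3]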

@dataclass
class CAILFeature:
    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: List[int]
    cls_index: int
    p_mask: List
    example_index: int
    unique_id: int
    paragraph_len: int
    token_is_max_context: object
    tokens: List
    token_to_orig_map: Dict
    start_positions: List[int]
    end_positions: List[int]
    is_impossible: bool

@dataclass
class CAILResult:
    unique_id: int
    start_logits: torch.Tensor
    end_logits: torch.Tensor

def read_examples(file: str, is_training: bool) -> List[CAILExample]:
    example_list = []
    with open(file, "r", encoding="utf-8") as file:
        original_data = json.load(file)["data"]

    for entry in tqdm(original_data):
        case_id = entry["caseid"]
        for paragraph in entry["paragraphs"]:
            context = paragraph["context"]
            case_name = paragraph["casename"]
            for qa in paragraph["qas"]:
                question = qa["question"]
                qas_id = qa["id"]
                answer_texts = None
                answer_starts = None
                is_impossible = None
                is_yes_no = None
                is_multi_span = None
                all_answers = None
                # CAIL2021 covers the following answer types: single-span, yes/no and unanswerable;
                # compared with earlier editions it additionally introduces the multi-span type,
                # where the answer is composed of several segments.
                if is_training:
                    all_answers = qa["answers"]
                    # all_answers == [] means there is no answer
                    if len(all_answers) == 0:
                        answer = []
                    else:
                        # otherwise take the first element
                        answer = all_answers[0]
                    # a little difference between 19 and 21 data.

                    # if it is a single dict, wrap it in a list
                    if type(answer) == dict:
                        answer = [answer]
                    # no answer: initialize the answer text as "" and the start position as -1
                    if len(answer) == 0:  # NO Answer
                        answer_texts = [""]
                        answer_starts = [-1]
                    else:
                        # otherwise collect the answer spans
                        answer_texts = []
                        answer_starts = []
                        # a single span yields one entry;
                        # multiple spans are simply iterated over
                        for a in answer:
                            answer_texts.append(a["text"])
                            answer_starts.append(a["answer_start"])
                    # Judge YES or NO
                    # decide whether this is a yes/no question and set the flag
                    if len(answer_texts) == 1 and answer_starts[0] == -1 and (
                            answer_texts[0] == "YES" or answer_texts[0] == "NO"):
                        is_yes_no = True
                    else:
                        is_yes_no = False
                    # Judge Multi Span
                    # decide whether the answer consists of multiple spans
                    if len(answer_texts) > 1:
                        is_multi_span = True
                    else:
                        is_multi_span = False
                    # Judge No Answer
                    # if there is no answer, mark it as follows
                    if len(answer_texts) == 1 and answer_texts[0] == "":
                        is_impossible = True
                    else:
                        is_impossible = False

                example = CAILExample(
                    qas_id=qas_id,
                    question_text=question,
                    context_text=context,
                    answer_texts=answer_texts,
                    answer_start_indexes=answer_starts,
                    is_impossible=is_impossible,
                    is_yes_no=is_yes_no,
                    is_multi_span=is_multi_span,
                    answers=all_answers,
                    case_id=case_id,
                    case_name=case_name
                )
                # Discard possible bad example
                if is_training and example.answer_start_indexes[0] >= 0:
                    for i in range(len(example.answer_texts)):
                        actual_text = "".join(
                            example.doc_tokens[example.start_positions[i]: (example.end_positions[i] + 1)])
                        cleaned_answer_text = "".join(whitespace_tokenize(example.answer_texts[i]))
                        if actual_text.find(cleaned_answer_text) == -1:
                            logging.info(f"Could not find answer: {actual_text} vs. {cleaned_answer_text}")
                            continue
                example_list.append(example)
    return example_list
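
# For reference, a hypothetical miniature of the JSON structure that read_examples() walks
# (field names follow the parsing code above; ids, texts and offsets are made up):
_SAMPLE_INPUT = {
    "data": [{
        "caseid": "case_0",
        "paragraphs": [{
            "casename": "借款合同纠纷",
            "context": "我\n\t爱北京\n\t天安门",
            "qas": [
                {"id": "q1", "question": "...", "answers": [{"text": "北京", "answer_start": 4}]},   # single span
                {"id": "q2", "question": "...", "answers": [[{"text": "北京", "answer_start": 4},
                                                             {"text": "天安门", "answer_start": 8}]]},  # multi span
                {"id": "q3", "question": "...", "answers": [{"text": "YES", "answer_start": -1}]},   # yes/no
                {"id": "q4", "question": "...", "answers": []},                                      # unanswerable
            ],
        }],
    }],
}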

def convert_examples_to_features(example_list: List[CAILExample], tokenizer: PreTrainedTokenizer, args,
                                 is_training: bool) -> List[CAILFeature]:
    # Validate there are no duplicate ids in example_list
    qas_id_set = set()
    for example in example_list:
        if example.qas_id in qas_id_set:
            raise Exception("Duplicate qas_id!")
        else:
            qas_id_set.add(example.qas_id)

    feature_list = []
    unique_id = 0
    example_index = 0
    i = 0
    for example in tqdm(example_list):
        i += 1
        # if i % 100 == 0:
        #     print(i)
        current_example_features = convert_single_example_to_features(example, tokenizer, args.max_seq_length,
                                                                      args.max_query_length, args.doc_stride,
                                                                      is_training)
        for feature in current_example_features:
            feature.example_index = example_index
            feature.unique_id = unique_id
            unique_id += 1
        example_index += 1
        feature_list.extend(current_example_features)

    return feature_list
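
# Minimal driver sketch (file and model paths follow the run command at the end of this post;
# the +2 on max_query_length reserves room for the YES/NO marker tokens, as in data_process.py below).
def _demo_build_features():
    from argparse import Namespace
    tokenizer = BertTokenizer.from_pretrained("model_hub/chinese-bert-wwm-ext")
    args = Namespace(max_seq_length=512, max_query_length=64 + 2, doc_stride=128)
    examples = read_examples("data_sample/cail2021_mrc_small.json", is_training=True)
    features = convert_examples_to_features(examples, tokenizer, args, is_training=True)
    return convert_features_to_dataset(features, is_training=True)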

def convert_single_example_to_features(example: CAILExample, tokenizer: PreTrainedTokenizer,
                                       max_seq_length, max_query_length, doc_stride, is_training) -> List[CAILFeature]:
"""
    Transfer original text to sequence which can be accepted by ELECTRA
    Format: [CLS] YES_TOKEN NO_TOKEN question [SEP] context [SEP]
"""
    features = []
    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
"""
    ['我', '爱', '北', '京', '15826458891', '天', '安', '门']
    orig_to_tok_index:[0, 1, 2, 3, 4, 9, 10, 11]
    tok_to_orig_index:[0, 1, 2, 3, 4, 4, 4, 4, 4, 5, 6, 7]
    all_doc_tokens:['我', '爱', '北', '京', '158', '##26', '##45', '##88', '##91', '天', '安', '门']
"""
    for (i, token) in enumerate(example.doc_tokens):
        orig_to_tok_index.append(len(all_doc_tokens))
        sub_tokens = tokenizer.tokenize(token)  # further split each token into WordPiece sub-tokens
        for sub_token in sub_tokens:
            tok_to_orig_index.append(i)  # every sub_token of this token maps back to the same index i
            all_doc_tokens.append(sub_token)

    if is_training:
        if example.is_impossible or example.answer_start_indexes[0] == -1:
            start_positions = [-1]
            end_positions = [-1]
        else:
            start_positions = []
            end_positions = []
            # recalibrate the positions after WordPiece tokenization
            for i in range(len(example.start_positions)):
                start_position = orig_to_tok_index[example.start_positions[i]]
                if example.end_positions[i] < len(example.doc_tokens) - 1:
                    end_position = orig_to_tok_index[example.end_positions[i] + 1] - 1
                else:
                    end_position = len(all_doc_tokens) - 1
                (start_position, end_position) = _improve_answer_span(
                    all_doc_tokens, start_position, end_position, tokenizer, example.answer_texts[i]
                )
                start_positions.append(start_position)
                end_positions.append(end_position)
    else:
        start_positions = None
        end_positions = None

    query_tokens = tokenizer.tokenize(example.question_text)
    query_tokens = [YES_TOKEN, NO_TOKEN] + query_tokens  # prepend the yes/no marker tokens to the question
    truncated_query = tokenizer.encode(query_tokens, add_special_tokens=False, max_length=max_query_length,
                                       truncation=True)

    sequence_pair_added_tokens = tokenizer.num_special_tokens_to_add(pair=True)
    assert sequence_pair_added_tokens == 3

    added_tokens_num_before_second_sequence = tokenizer.num_special_tokens_to_add(pair=False)
    assert added_tokens_num_before_second_sequence == 2
    span_doc_tokens = all_doc_tokens
    spans = []

    # print("query_tokens:", query_tokens)
    # print("all_doc_tokens:", all_doc_tokens)

    # print("".join(all_doc_tokens))
    # print("start_positions:", start_positions)
    # print("end_positions:", end_positions)
    # Slide a window over the document.
    while len(spans) * doc_stride < len(all_doc_tokens):
        # print(max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,)
        # each window advances by doc_stride tokens
        encoded_dict = tokenizer.encode_plus(
            truncated_query,
            span_doc_tokens,
            max_length=max_seq_length,
            return_overflowing_tokens=True,
            padding="max_length",
            stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
            truncation="only_second",
            return_token_type_ids=True
        )
        # print(span_doc_tokens)
        # print("stride:", max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens)
        # print(tokenizer.convert_ids_to_tokens(encoded_dict['input_ids']))
        # print(len(encoded_dict['input_ids']))
        # print(tokenizer.convert_ids_to_tokens(encoded_dict['overflowing_tokens']))
        # actual length of this context chunk
        paragraph_len = min(
            len(all_doc_tokens) - len(spans) * doc_stride,
            max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
        )
        # input_ids without the [PAD] tokens
        if tokenizer.pad_token_id in encoded_dict["input_ids"]:
            non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
        else:
            non_padded_ids = encoded_dict["input_ids"]
        # convert the ids back to tokens
        tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

        token_to_orig_map = {}
        token_to_orig_map[0] = -1
        token_to_orig_map[1] = -1
        token_to_orig_map[2] = -1

        token_is_max_context = {0: True, 1: True, 2: True}
        for i in range(paragraph_len):
            # index of this token within the input [CLS] query [SEP] context [SEP]
            index = len(truncated_query) + added_tokens_num_before_second_sequence + i
            # tok_to_orig_index gives the token's index within the original context;
            # len(spans) tells how many windows have been produced so far;
            # token_to_orig_map maps the input index back to the original token index
            token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
        # print(token_to_orig_map)
        encoded_dict["paragraph_len"] = paragraph_len
        encoded_dict["tokens"] = tokens
        encoded_dict["token_to_orig_map"] = token_to_orig_map
        encoded_dict["truncated_query_with_special_tokens_length"] = len(
            truncated_query) + added_tokens_num_before_second_sequence
        encoded_dict["token_is_max_context"] = token_is_max_context
        encoded_dict["start"] = len(spans) * doc_stride  # 文本的起始索引
        encoded_dict["length"] = paragraph_len

        # Mark the YES/NO marker tokens with token_type_ids = 1, i.e. treat them like context
        # tokens so that they can be picked as answers (the p_mask below also leaves them unmasked).
        encoded_dict["token_type_ids"][1] = 1
        encoded_dict["token_type_ids"][2] = 1

        # print(encoded_dict["token_type_ids"])
        spans.append(encoded_dict)

        if "overflowing_tokens" not in encoded_dict or len(encoded_dict["overflowing_tokens"]) == 0:
            break
        else:
            span_doc_tokens = encoded_dict["overflowing_tokens"]

    for doc_span_index in range(len(spans)):
        for j in range(spans[doc_span_index]["paragraph_len"]):
            is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
            index = spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
            spans[doc_span_index]["token_is_max_context"][index] = is_max_context

    for span in spans:
        cls_index = span["input_ids"].index(tokenizer.cls_token_id)

        # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be)
        # i.e. the question and [SEP] positions are set to 1, everything else to 0
        p_mask = np.array(span["token_type_ids"])
        p_mask = np.minimum(p_mask, 1)
        p_mask = 1 - p_mask
        p_mask[np.where(np.array(span["input_ids"]) == tokenizer.sep_token_id)[0]] = 1
        p_mask[cls_index] = 0
        p_mask[1] = 0
        p_mask[2] = 0

        current_start_positions = None
        current_end_positions = None
        span_is_impossible = None
        if is_training:
            current_start_positions = [0 for i in range(max_seq_length)]
            current_end_positions = [0 for i in range(max_seq_length)]
            doc_start = span["start"]
            doc_end = span["start"] + span["length"] - 1  # 文本的截止索引
            doc_offset = len(truncated_query) + added_tokens_num_before_second_sequence  # 偏移量
            for i in range(len(start_positions)):
                start_position = start_positions[i]
                end_position = end_positions[i]
                # 这里重新整合start_position和end_position
                if start_position >= doc_start and end_position  Dataset:
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
    all_example_indexes = torch.tensor([f.example_index for f in features], dtype=torch.long)
    all_feature_indexes = torch.arange(all_input_ids.size(0), dtype=torch.long)
    if is_training:
        all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
        all_start_labels = torch.tensor([f.start_positions for f in features], dtype=torch.float)
        all_end_labels = torch.tensor([f.end_positions for f in features], dtype=torch.float)
        dataset = TensorDataset(
            all_input_ids,
            all_attention_masks,
            all_token_type_ids,
            all_start_labels,
            all_end_labels,
            all_cls_index,
            all_p_mask,
            all_is_impossible,
            all_example_indexes,
            all_feature_indexes
        )
    else:
        dataset = TensorDataset(
            all_input_ids,
            all_attention_masks,
            all_token_type_ids,
            all_cls_index,
            all_p_mask,
            all_example_indexes,
            all_feature_indexes
        )
    return dataset
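
# Usage sketch: with is_training=True the TensorDataset yields its tensors in exactly the order
# they were packed above (the batch size here is arbitrary).
def _demo_training_batches(dataset, batch_size: int = 8):
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for (input_ids, attention_mask, token_type_ids, start_labels, end_labels,
         cls_index, p_mask, is_impossible, example_indexes, feature_indexes) in loader:
        # input_ids / attention_mask / token_type_ids feed the encoder;
        # start_labels / end_labels / is_impossible are the supervision targets.
        pass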

def _is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False

def _new_check_is_max_context(doc_spans, cur_span_index, position):
"""
    Check if this is the 'max context' doc span for the token.

"""
    # if len(doc_spans) == 1:
    # return True
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span["start"] + doc_span["length"] - 1
        if position < doc_span["start"]:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span["start"]
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index
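
# Hand-computed check of the scoring above (illustrative window sizes, doc_stride = 128):
#   token position 130 is covered by window 0 (start=0, length=384) and window 1 (start=128, length=384)
#   window 0: min(130, 253) + 0.01 * 384 = 133.84
#   window 1: min(2, 381)   + 0.01 * 384 =   5.84
# so token_is_max_context is True for position 130 only in window 0.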

def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
"""
    Returns tokenized answer spans that better match the annotated answer.

"""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)

def customize_tokenizer(text: str, do_lower_case=True) -> List[str]:
    temp_x = ""
    for char in text:
        # insert a space on each side of special characters
        if _is_chinese_char(ord(char)) or _is_punctuation(char) or _is_whitespace(char) or _is_control(char):
            temp_x += " " + char + " "
        else:
            temp_x += char
    # optionally lowercase Latin letters
    if do_lower_case:
        temp_x = temp_x.lower()

    return temp_x.split()  # split on whitespace

def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and handled
    # like the all of the other languages.

    if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
    ):
        return True

    return False
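
A quick check of customize_tokenizer (same example as in the comments above): whitespace and control characters disappear from the output, every Chinese character and punctuation mark becomes its own token, and runs of digits or Latin letters stay together:

>>> customize_tokenizer("我\n\t爱北京15826458891天安门")
['我', '爱', '北', '京', '15826458891', '天', '安', '门']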

Note that the text goes through three rounds of realignment in total (see the sketch below):

  • Stage 1: the text is first roughly tokenized, which removes special symbols, whitespace, etc., so the answer start positions have to be recalibrated (char_to_word_offset).
  • Stage 2: the tokenizer then applies WordPiece to each character/word, which changes the sequence length and hence the answer positions, so they are recalibrated again (orig_to_tok_index / tok_to_orig_index).
  • Stage 3: the question and the context are combined and a sliding window is applied, so the answer positions have to be recalibrated once more relative to each window.
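
Chaining the three mappings together gives roughly the following lookup (a sketch only: the helper name and argument list are mine, but every field it touches appears in the code above):

def to_label_index(char_pos, example, orig_to_tok_index, span, doc_offset):
    word_pos = example.char_to_word_offset[char_pos]    # stage 1: whitespace/control characters removed
    subtoken_pos = orig_to_tok_index[word_pos]          # stage 2: WordPiece re-tokenization
    return subtoken_pos - span["start"] + doc_offset    # stage 3: offset inside one sliding window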
(Sample output, truncated: one printed CAILFeature for a question asking what insurance 阮x4 took out at 平安财险. Its input_ids begin with [101, 1, 2, ...], i.e. [CLS] [unused1] [unused2] followed by the question, [SEP], and one sliding-window chunk of the context; token_to_orig_map sends the three leading special positions to -1 and every context position back to its index in doc_tokens; token_is_max_context records which positions have their maximal context in this window; start_positions and end_positions are 0/1 vectors of length max_seq_length; is_impossible is False.)

Finally, pay attention to how the yes/no, single-span and multi-span cases are unified under one labeling scheme, and to how token_type_ids is set in the input.
There is also a driver script that saves the processed data to disk, so the time-consuming preprocessing does not have to be repeated on every run:

"""
数据处理相关代码
"""
import argparse
import json

from transformers import PreTrainedTokenizer, BertTokenizer
from data_process_utils import *
import gzip
import pickle
import os
from os.path import join
import logging

def convert_and_write(args, tokenizer: PreTrainedTokenizer, file, examples_fn, features_fn, is_training):
    logging.info(f"Reading examples from :{file} ...")
    example_list = read_examples(file, is_training=is_training)
    logging.info(f"Total examples:{len(example_list)}")

    logging.info(f"Start converting examples to features.")
    feature_list = convert_examples_to_features(example_list, tokenizer, args, is_training)
    logging.info(f"Total features:{len(feature_list)}")

    logging.info(f"Converting complete, writing examples and features to file.")
    with gzip.open(join(args.output_path, examples_fn), "wb") as file:
        pickle.dump(example_list, file)
    with gzip.open(join(args.output_path, features_fn), "wb") as file:
        pickle.dump(feature_list, file)

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help="The file to be processed."
    )

    parser.add_argument(
        "--for_training",
        action="store_true",
        help="Process for training or not."
    )

    parser.add_argument(
        "--output_prefix",
        type=str,
        required=True,
        help="The prefix of output file's name."
    )

    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model."
    )

    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="Path to tokenizer which will be used to tokenize text.(ElectraTokenizer)"
    )

    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. "
             "Longer will be truncated, and shorter will be padded."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer will be truncated to the length."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks."
    )

    parser.add_argument(
        "--output_path",
        default="./processed_data/",
        type=str,
        help="Output path of the constructed examples and features."
    )

    args = parser.parse_args()
    args.max_query_length += 2  # reserve positions for the YES and NO marker tokens
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )

    logging.info("All input parameters:")
    print(json.dumps(vars(args), sort_keys=False, indent=2))

    tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path)

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    convert_and_write(args, tokenizer, args.input_file, args.output_prefix + "_examples.pkl.gz",
                      args.output_prefix + "_features.pkl.gz", args.for_training)

if __name__ == "__main__":
    main()

Command to run:

python data_process.py --input_file data_sample/cail2021_mrc_small.json --output_prefix cail2021_mrc_small --tokenizer_path model_hub/chinese-bert-wwm-ext --max_seq_length 512 --max_query_length 64 --doc_stride 128 --do_lower_case --for_training
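
Once the script has run, the cached examples and features can be reloaded without repeating the preprocessing (a minimal sketch; the file names follow the --output_path and --output_prefix used above):

import gzip
import pickle

from data_process_utils import convert_features_to_dataset

with gzip.open("./processed_data/cail2021_mrc_small_examples.pkl.gz", "rb") as f:
    examples = pickle.load(f)
with gzip.open("./processed_data/cail2021_mrc_small_features.pkl.gz", "rb") as f:
    features = pickle.load(f)

dataset = convert_features_to_dataset(features, is_training=True)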

Original: https://www.cnblogs.com/xiximayou/p/16359133.html
Author: 西西嘛呦
Title: CAIL2021 Reading Comprehension Task: Data Preprocessing Module (Part 2)

