# mindformers.core.metric.metric source code

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file was refer to project:
# https://github.com/lonePatient/daguan_2019_rank9/blob/master/pydatagrand/train/ner_utils.py
# ============================================================================
"""MindFormer Self-Define Metric."""
import os
import sys
import re
import collections
import json
import math
import string
import shutil
import six
import jieba
import numpy as np
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

import mindspore.nn as nn
import mindspore as ms
from mindspore.ops import operations as P
from mindspore.communication import get_group_size, get_rank
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.models import BasicTokenizer
from mindformers.core.loss import CrossEntropyLoss
from ...dataset.labels import cluener_labels

# Public metric classes of this module; each of them is also registered under
# MindFormerModuleType.METRIC by its class decorator.
__all__ = ['EntityScore', 'SQuADMetric', 'PerplexityMetric', 'ADGENMetric', 'PromptAccMetric', 'EmF1Metric']


@MindFormerRegister.register(MindFormerModuleType.METRIC)
class EntityScore(nn.Metric):
    """Compute the precision, recall and F1 score of named entities, overall and per entity type.

    Predicted and gold tag-id sequences are mapped to tag strings through
    ``cluener_labels`` and decoded into entity spans with the BIOS scheme
    (see :meth:`get_entities_bios`). A predicted entity is counted as correct
    when the identical ``[type, start, end]`` span appears in the gold entities
    of the same sample.
    """

    def __init__(self):
        super(EntityScore, self).__init__()
        self.label2id = {label: label_id for label_id, label in enumerate(cluener_labels)}
        self.id2label = {label_id: label for label, label_id in self.label2id.items()}
        self.clear()

    def clear(self):
        """Reset the accumulated entities."""
        self.origins = []  # entities decoded from the gold labels
        self.founds = []   # entities decoded from the predictions
        self.rights = []   # predicted entities that also occur in the gold entities of the same sample

    def update(self, *inputs):
        """Accumulate the entities of one batch.

        Args:
            inputs[0]: Logits tensor; the predicted tag id is its argmax over axis 2
                (presumably shaped (batch, seq_len, num_labels) -- TODO confirm).
            inputs[1]: Gold tag-id tensor, iterated per sample and per position
                (presumably (batch, seq_len)).
        """
        batch_logits = inputs[0].asnumpy()
        batch_label_ids = inputs[1].asnumpy()
        batch_pred_ids = np.argmax(batch_logits, axis=2).tolist()
        pred_paths = [[self.id2label[id_] for id_ in pred_ids] for pred_ids in batch_pred_ids]
        label_paths = [[self.id2label[id_] for id_ in label_ids] for label_ids in batch_label_ids]
        for label_path, pred_path in zip(label_paths, pred_paths):
            label_entities = self.get_entities_bios(label_path)
            pred_entities = self.get_entities_bios(pred_path)
            self.origins.extend(label_entities)
            self.founds.extend(pred_entities)
            self.rights.extend(entity for entity in pred_entities if entity in label_entities)

    def eval(self):
        """Compute the final scores.

        Returns:
            tuple: ``(overall, class_info)``. ``overall`` is a dict with ``precision``,
            ``recall`` and ``f1`` (rounded to 4 decimals). ``class_info`` maps every
            entity type present in the gold labels to a dict of the same form.
        """
        class_info = {}
        origin_counter = collections.Counter(x[0] for x in self.origins)
        found_counter = collections.Counter(x[0] for x in self.founds)
        right_counter = collections.Counter(x[0] for x in self.rights)
        for type_, origin in origin_counter.items():
            found = found_counter.get(type_, 0)
            right = right_counter.get(type_, 0)
            recall, precision, f1 = self.compute(origin, found, right)
            class_info[type_] = {"precision": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}
        recall, precision, f1 = self.compute(len(self.origins), len(self.founds), len(self.rights))
        return {"precision": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}, class_info

    def compute(self, origin, found, right):
        """Compute recall, precision and F1 from entity counts.

        Args:
            origin (int): Number of gold entities.
            found (int): Number of predicted entities.
            right (int): Number of correctly predicted entities.

        Returns:
            tuple: ``(recall, precision, f1)``; each is 0 when its denominator is 0.
        """
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def get_entities_bios(self, seq):
        """Extract entity spans from a BIOS tag sequence.

        Args:
            seq (list[str]): Tag strings such as ``"B-name"``, ``"I-name"``, ``"S-name"`` or ``"O"``.

        Returns:
            list[list]: One ``[entity_type, start_index, end_index]`` per entity, with
            inclusive indices. A ``B-`` tag that is not continued by a matching ``I-``
            tag yields no entity.
        """
        chunks = []
        chunk = [-1, -1, -1]
        last_index = len(seq) - 1
        for indx, tag in enumerate(seq):
            if tag.startswith("S-"):
                # Single-token entity: flush any completed multi-token entity first.
                if chunk[2] != -1:
                    chunks.append(chunk)
                chunks.append([tag.split('-')[1], indx, indx])
                chunk = [-1, -1, -1]
            if tag.startswith("B-"):
                if chunk[2] != -1:
                    chunks.append(chunk)
                chunk = [tag.split('-')[1], indx, -1]
            elif tag.startswith('I-') and chunk[1] != -1:
                if tag.split('-')[1] == chunk[0]:
                    chunk[2] = indx
                # Flush at the end of the sequence, but only a chunk with a valid end;
                # a B- followed only by I- tags of another type is not an entity.
                if indx == last_index and chunk[2] != -1:
                    chunks.append(chunk)
            else:
                if chunk[2] != -1:
                    chunks.append(chunk)
                chunk = [-1, -1, -1]
        return chunks
@MindFormerRegister.register(MindFormerModuleType.METRIC)
class SQuADMetric(nn.Metric):
    """Compute the exact-match and F1 scores of SQuAD v1.1 style extractive question answering.

    :meth:`update` collects per-feature start/end logits. :meth:`eval` turns them
    into one answer string per question (the best-scoring valid span), compares it
    with the ground-truth answers in ``dev-v1.1.json`` and returns both scores in
    percent.

    Args:
        dataset_dir (str): Directory that contains ``dev-v1.1.json``.
        n_best_size (int): Number of top start/end positions tried per feature, and the
            maximum number of candidate answers kept per example. Default: 20.
        max_answer_len (int): Maximum answer span length in tokens. Default: 30.
        do_lower_case (bool): Forwarded to the ``BasicTokenizer`` that maps predicted
            spans back to the original text. Default: True.
        temp_file_dir (str): Directory with ``temp_examples.json`` and
            ``temp_features.json`` (one JSON object per line), read at construction
            time. :meth:`eval` deletes this directory. Default: "./squad_temp".
    """

    # Raw model output for one feature, collected by `update`.
    _RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])

    def __init__(self, dataset_dir, n_best_size=20, max_answer_len=30, do_lower_case=True,
                 temp_file_dir="./squad_temp"):
        super(SQuADMetric, self).__init__()
        self.outputs = []
        self.temp_file_dir = temp_file_dir
        temp_examples_file = os.path.join(temp_file_dir, "temp_examples.json")
        temp_features_file = os.path.join(temp_file_dir, "temp_features.json")
        self.all_examples = self._load_temp_data(temp_examples_file)
        self.all_features = self._load_temp_data(temp_features_file)
        self.dev_file_path = os.path.join(dataset_dir, "dev-v1.1.json")
        self.basic_tokenizer = BasicTokenizer(do_lower_case)
        self.n_best_size = n_best_size
        self.max_answer_len = max_answer_len

    def clear(self):
        """Drop the raw results collected by previous `update` calls."""
        self.outputs = []

    def update(self, *inputs):
        """Collect the start/end logits of one batch.

        Args:
            inputs[0]: Unique feature ids, one per sample.
            inputs[1]: Start logits, one row per sample (flattened with ``.flat``).
            inputs[2]: End logits, one row per sample (flattened with ``.flat``).
        """
        ids = inputs[0].asnumpy()
        start = inputs[1].asnumpy()
        end = inputs[2].asnumpy()
        for unique_id, start_row, end_row in zip(ids, start, end):
            self.outputs.append(self._RawResult(
                unique_id=int(unique_id),
                start_logits=[float(x) for x in start_row.flat],
                end_logits=[float(x) for x in end_row.flat]))

    def eval(self):
        """Score the collected predictions against ``dev-v1.1.json``.

        Questions without a prediction score 0 and are reported on stderr.
        The temporary data directory is removed afterwards.

        Returns:
            dict: ``{'exact_match': float, 'f1': float}``, both in percent (0.0 for an empty dev set).
        """
        predictions = self._get_predictions()
        with open(self.dev_file_path, encoding="utf-8") as ds:
            dataset = json.load(ds)['data']
        f1 = exact_match = total = 0
        for article in dataset:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    total += 1
                    if qa['id'] not in predictions:
                        message = 'Unanswered question ' + qa['id'] + \
                                  ' will receive score 0.'
                        print(message, file=sys.stderr)
                        continue
                    ground_truths = [answer['text'] for answer in qa['answers']]
                    if not ground_truths:
                        continue
                    prediction = predictions[qa['id']]
                    exact_match += self._metric_max_over_ground_truths(
                        self._exact_match_score, prediction, ground_truths)
                    f1 += self._metric_max_over_ground_truths(
                        self._f1_score, prediction, ground_truths)
        exact_match = 100.0 * exact_match / total if total else 0.0
        f1 = 100.0 * f1 / total if total else 0.0
        self._remove_temp_data()
        return {'exact_match': exact_match, 'f1': f1}

    def _remove_temp_data(self):
        """Delete the temporary example/feature directory if it still exists.

        The existence check keeps a second `eval` call from failing on a directory
        that the first call already removed.
        """
        if os.path.isdir(self.temp_file_dir):
            shutil.rmtree(self.temp_file_dir)

    def _load_temp_data(self, temp_file_path):
        """Load a JSON-lines file into a list of objects, skipping blank lines."""
        with open(temp_file_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    def _normalize_answer(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""

        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def _f1_score(self, prediction, ground_truth):
        """Token-level F1 between the normalized prediction and ground truth."""
        prediction_tokens = self._normalize_answer(prediction).split()
        ground_truth_tokens = self._normalize_answer(ground_truth).split()
        common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        return (2 * precision * recall) / (precision + recall)

    def _exact_match_score(self, prediction, ground_truth):
        """Whether prediction and ground truth are equal after normalization."""
        return self._normalize_answer(prediction) == self._normalize_answer(ground_truth)

    def _metric_max_over_ground_truths(self, metric_fn, prediction, ground_truths):
        """Best score of `prediction` against any of the ground truths."""
        return max(metric_fn(prediction, ground_truth) for ground_truth in ground_truths)

    def _get_predictions(self):
        """Return an OrderedDict mapping each example's ``qas_id`` to its predicted answer text."""
        example_index_to_features = collections.defaultdict(list)
        for feature in self.all_features:
            example_index_to_features[feature["example_index"]].append(feature)
        # A later result with the same id overwrites an earlier one.
        unique_id_to_result = {result.unique_id: result for result in self.outputs}
        all_predictions = collections.OrderedDict()
        for example_index, example in enumerate(self.all_examples):
            features = example_index_to_features[example_index]
            prelim_predictions = self._get_prelim_predictions(features, unique_id_to_result)
            nbest = self._get_nbest(prelim_predictions, features, example)
            # `nbest` keeps the descending start+end logit order of the preliminary
            # predictions and is never empty, so its head is the best answer.
            all_predictions[example["qas_id"]] = nbest[0].text
        return all_predictions

    def _get_prelim_predictions(self, features, unique_id_to_result):
        """Enumerate the valid (start, end) spans of all features of one example.

        Returns:
            list: PrelimPrediction tuples sorted by descending ``start_logit + end_logit``.
        """
        _PrelimPrediction = collections.namedtuple(
            "PrelimPrediction",
            ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
        prelim_predictions = []
        for feature_index, feature in enumerate(features):
            if feature["unique_id"] not in unique_id_to_result:
                continue
            result = unique_id_to_result[feature["unique_id"]]
            start_indexes = self._get_best_indexes(result.start_logits)
            end_indexes = self._get_best_indexes(result.end_logits)
            num_tokens = len(feature["tokens"])
            # Both maps are keyed by the stringified token index.
            token_to_orig_map = feature["token_to_orig_map"]
            token_is_max_context = feature["token_is_max_context"]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Throw out invalid spans: out of range, not mapped back to the
                    # document (e.g. inside the question), not in this feature's
                    # max-context window, reversed, or too long.
                    if start_index >= num_tokens or end_index >= num_tokens:
                        continue
                    if str(start_index) not in token_to_orig_map or str(end_index) not in token_to_orig_map:
                        continue
                    if not token_is_max_context.get(str(start_index), False):
                        continue
                    if end_index < start_index:
                        continue
                    if end_index - start_index + 1 > self.max_answer_len:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))
        prelim_predictions.sort(key=lambda x: x.start_logit + x.end_logit, reverse=True)
        return prelim_predictions

    def _get_nbest(self, prelim_predictions, features, example):
        """Turn preliminary spans into at most `n_best_size` distinct answer texts.

        Returns:
            list: NbestPrediction tuples in input order; a single ``"empty"`` entry
            if no valid span exists.
        """
        _NbestPrediction = collections.namedtuple(
            "NbestPrediction", ["text", "start_logit", "end_logit"])
        seen_predictions = set()
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= self.n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature["tokens"][pred.start_index:(pred.end_index + 1)]
                orig_doc_start = feature["token_to_orig_map"][str(pred.start_index)]
                orig_doc_end = feature["token_to_orig_map"][str(pred.end_index)]
                orig_tokens = example["doc_tokens"][orig_doc_start:(orig_doc_end + 1)]
                tok_text = " ".join(tok_tokens)
                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "").replace("##", "")
                # Clean whitespace.
                tok_text = " ".join(tok_text.strip().split())
                orig_text = " ".join(orig_tokens)
                final_text = self._get_final_text(tok_text, orig_text)
                if final_text in seen_predictions:
                    continue
                seen_predictions.add(final_text)
            else:
                final_text = ""
                seen_predictions.add(final_text)
            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit))
        # In very rare edge cases there is no valid prediction; create a nonce one
        # to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
        assert len(nbest) >= 1
        return nbest

    def _compute_softmax(self, scores):
        """Compute softmax probabilities over raw logits (shifted by the max for stability)."""
        if not scores:
            return []
        max_score = max(scores)
        exp_scores = [math.exp(score - max_score) for score in scores]
        total_sum = sum(exp_scores)
        return [score / total_sum for score in exp_scores]

    def _get_final_text(self, pred_text, orig_text):
        """Project the tokenized prediction back to the original text.

        Returns `orig_text` unchanged whenever the alignment between the
        re-tokenized original text and `pred_text` cannot be established.
        """

        def _strip_spaces(text):
            """Remove spaces; return the result and a map from stripped to original indices."""
            ns_chars = []
            ns_to_s_map = collections.OrderedDict()
            for i, c in enumerate(text):
                if c == " ":
                    continue
                ns_to_s_map[len(ns_chars)] = i
                ns_chars.append(c)
            return "".join(ns_chars), ns_to_s_map

        tok_text = " ".join(self.basic_tokenizer.tokenize(orig_text))
        start_position = tok_text.find(pred_text)
        if start_position == -1:
            return orig_text
        end_position = start_position + len(pred_text) - 1

        orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
        tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)
        if len(orig_ns_text) != len(tok_ns_text):
            return orig_text
        tok_s_to_ns_map = {tok_index: i for i, tok_index in tok_ns_to_s_map.items()}

        orig_start_position = None
        if start_position in tok_s_to_ns_map:
            ns_start_position = tok_s_to_ns_map[start_position]
            if ns_start_position in orig_ns_to_s_map:
                orig_start_position = orig_ns_to_s_map[ns_start_position]
        if orig_start_position is None:
            return orig_text

        orig_end_position = None
        if end_position in tok_s_to_ns_map:
            ns_end_position = tok_s_to_ns_map[end_position]
            if ns_end_position in orig_ns_to_s_map:
                orig_end_position = orig_ns_to_s_map[ns_end_position]
        if orig_end_position is None:
            return orig_text

        return orig_text[orig_start_position:(orig_end_position + 1)]

    def _get_best_indexes(self, logits):
        """Return the indexes of the `n_best_size` largest logits, best first."""
        index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
        return [index for index, _ in index_and_score[:self.n_best_size]]
@MindFormerRegister.register(MindFormerModuleType.METRIC)
class PerplexityMetric(nn.Metric):
    """Compute the average loss and perplexity (PPL) of a causal language model.

    Every `update` computes one masked next-token cross-entropy loss for the batch.
    `eval` averages those per-batch losses and reports ``PPL = exp(average loss)``.
    With pipeline parallelism only ranks of the last pipeline stage accumulate
    results; the other ranks return ``None`` from `eval`.
    """

    def __init__(self):
        super(PerplexityMetric, self).__init__()
        self.num_data = None
        self.total_loss = None
        self.loss = CrossEntropyLoss()
        self.pipeline_stages = ms.get_auto_parallel_context('pipeline_stages')
        self.pipeline_parallel = self.pipeline_stages > 1
        self.rank_id = 0
        self.device_num = 1
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.not_equal = P.NotEqual()
        self.sub = P.Sub()
        if self.pipeline_parallel:
            self.rank_id = get_rank()
            self.device_num = get_group_size()
            per_stage_device_num = self.device_num // self.pipeline_stages
            stage_id = self.rank_id // per_stage_device_num
            self.is_last_stage = (stage_id == self.pipeline_stages - 1)
        self.parallel_mode = ms.get_auto_parallel_context("parallel_mode")
        self.full_batch = ms.get_auto_parallel_context("full_batch")
        self.auto_parallel = self.parallel_mode in ['semi_auto_parallel', 'auto_parallel']

    def clear(self):
        """Clearing the internal evaluation result."""
        self.num_data = 0
        self.total_loss = 0.0

    def update(self, *inputs):
        """Accumulate the loss of one batch.

        Args:
            inputs[0]: Logits, unpacked as (batch_size, seq_length, vocab_size).
            inputs[1]: Label token ids, sliced along axis 1 (presumably (batch_size, seq_length)).
            inputs[2]: Input mask aligned with the labels.
        """
        if not self.pipeline_parallel:
            self._accumulate_batch_loss(inputs[0], inputs[1], inputs[2])
            return
        if not self.is_last_stage:
            # Ranks outside the last pipeline stage do not contribute to the metric.
            return
        if not self.auto_parallel:
            self._accumulate_batch_loss(inputs[0], inputs[1], inputs[2], mask_offset_by_one=True)
            return
        # Temporarily switch to data-parallel mode (presumably so the eager loss runs on
        # this rank's local data), and always restore the previous context, even on error.
        previous_full_batch = ms.get_auto_parallel_context("full_batch")
        ms.set_auto_parallel_context(parallel_mode='data_parallel', full_batch=False)
        try:
            self._accumulate_batch_loss(inputs[0], inputs[1], inputs[2], mask_offset_by_one=True)
        finally:
            ms.set_auto_parallel_context(parallel_mode=self.parallel_mode, full_batch=previous_full_batch,
                                         pipeline_stages=self.pipeline_stages)

    def _accumulate_batch_loss(self, logits, labels, input_mask, mask_offset_by_one=False):
        """Add the masked next-token cross-entropy of one batch to the running totals.

        Logits at position t are scored against the label at position t + 1, so the last
        logit position and the first label/mask position are dropped.
        """
        if mask_offset_by_one:
            # input_mask was added 1 in GPT2LMModel to avoid allgather issue in Mindspore1.10
            input_mask = self.sub(input_mask, 1)
        batch_size, seq_length, _ = logits.shape
        logits = self.reshape(logits[::, :-1, ::], (batch_size * (seq_length - 1), -1))
        labels = self.reshape(labels[::, 1:], (-1,))
        input_mask = self.reshape(input_mask[::, 1:], (-1,))
        loss = self.loss(logits, labels, input_mask)
        self.total_loss += float(loss.asnumpy())
        self.num_data += 1

    def eval(self):
        """Compute the final result.

        Returns:
            dict or None: ``{"loss": average batch loss, "PPL": exp(loss)}``; ``None`` on
            ranks outside the last pipeline stage.
        """
        if self.pipeline_parallel and not self.is_last_stage:
            return None
        avg_loss = float(self.total_loss / self.num_data)
        result = {"loss": avg_loss, "PPL": math.exp(avg_loss)}
        if self.pipeline_parallel:
            print("Average Loss and PPL Metric:", result)
        return result
@MindFormerRegister.register(MindFormerModuleType.METRIC)
class ADGENMetric(nn.Metric):
    """Compute ROUGE-1/2/L and BLEU-4 scores for Chinese text generation.

    ROUGE F-scores are computed with ``rouge_chinese`` on jieba-segmented words.
    BLEU-4 is computed with NLTK ``sentence_bleu`` on characters, using smoothing
    method 3. Every per-sample score is multiplied by 100 and rounded to 4 decimals.
    `eval` returns the mean of each score.
    """

    def __init__(self):
        super(ADGENMetric, self).__init__()
        self.score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }

    def clear(self):
        """Drop all accumulated per-sample scores."""
        self.score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }

    def update(self, *inputs):
        """Score every (prediction, label) pair of one batch.

        Args:
            inputs[0]: Predicted texts (or a tuple whose first item holds them).
            inputs[1]: Reference texts, paired with the predictions by position.
        """
        preds = inputs[0]
        labels = inputs[1]
        if isinstance(preds, tuple):
            preds = preds[0]
        rouge = Rouge()
        smoothing = SmoothingFunction().method3
        for pred, label in zip(preds, labels):
            print(f"pred is:\n {pred}\n", f"label is:\n {label}")
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            try:
                scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
            except ValueError:
                # rouge cannot score an empty hypothesis/reference; count it as 0
                # instead of aborting the whole evaluation.
                scores = [{key: {"f": 0.0} for key in ("rouge-1", "rouge-2", "rouge-l")}]
            for k, v in scores[0].items():
                self.score_dict[k].append(round(v["f"] * 100, 4))
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=smoothing)
            self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))

    def eval(self):
        """Compute the final result.

        The accumulated per-sample scores are left untouched, so `update` may still be
        called afterwards.

        Returns:
            dict: Mean of each score, keyed by ``rouge-1``, ``rouge-2``, ``rouge-l`` and ``bleu-4``.
        """
        result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
        print('metric: ADGENMetric\n' +
              f'rouge-1: {result["rouge-1"]:.4f}\n' +
              f'rouge-2: {result["rouge-2"]:.4f}\n' +
              f'rouge-l: {result["rouge-l"]:.4f}\n' +
              f'bleu-4:  {result["bleu-4"]:.4f}')
        return result
@MindFormerRegister.register(MindFormerModuleType.METRIC)
class PromptAccMetric(nn.Metric):
    r"""
    Computes the prompt accuracy of text classification.

    Each sample is paired with one prompt per candidate label. The predicted label
    is the index of the prompt with the minimum perplexity; the per-prompt
    cross-entropy loss is compared, which has the same argmin.

    1. Build one prompt per candidate label, for example::

        这是关于**体育**的文章:$passage    ("This is an article about **sports**: $passage")
        这是关于**文化**的文章:$passage    ("This is an article about **culture**: $passage")

    2. Compute the perplexity of each prompt-filled context. Perplexity measures how well
       a probability distribution or model predicts a sample; a low perplexity means the
       model predicts the sample well:

        .. math::
            PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}}

       where :math:`w` represents words in the corpus.

    3. Take the index of the prompt with the minimum perplexity as the classification result.

    4. Count the correctly classified samples and all samples and compute:

        .. math::
            \text{accuracy} =\frac{\text{correct\_sample\_nums}}{\text{total\_sample\_nums}}
    """

    def __init__(self):
        super(PromptAccMetric, self).__init__()
        self.num_data = None
        self.total_acc_num = None
        self.loss = CrossEntropyLoss()
        self.pipeline_stages = ms.get_auto_parallel_context('pipeline_stages')
        self.pipeline_parallel = self.pipeline_stages > 1
        self.last_card_id = 0
        self.rank_id = 0
        self.device_num = 1
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.equal = P.Equal()
        self.softmax = P.Softmax()
        self.argmin = P.Argmin()
        self.sum = P.ReduceSum()
        if self.pipeline_parallel:
            self.rank_id = get_rank()
            self.device_num = get_group_size()
            per_stage_device_num = self.device_num // self.pipeline_stages
            stage_id = self.rank_id // per_stage_device_num
            self.is_last_stage = (stage_id == self.pipeline_stages - 1)
        self.parallel_mode = ms.get_auto_parallel_context("parallel_mode")
        self.full_batch = ms.get_auto_parallel_context("full_batch")
        self.auto_parallel = self.parallel_mode in ['semi_auto_parallel', 'auto_parallel']

    def clear(self):
        """Clearing the internal evaluation result."""
        self.num_data = 0
        self.total_acc_num = 0

    def calculate_circle(self, *inputs):
        """Classify one batch and accumulate the number of correct predictions.

        Args:
            inputs[0]: Logits, unpacked as (batch_size, num_labels, seq_length, vocab) and
                reshaped to (batch_size * num_labels, seq_length, vocab).
            inputs[1]: Input token ids, indexed per prompt (presumably
                (batch_size * num_labels, seq_length) -- TODO confirm).
            inputs[2]: Input mask, indexed like the input ids.
            inputs[3]: Gold label indexes, flattened to (batch_size,).
        """
        logits, input_ids, input_mask, labels = inputs[0], inputs[1], inputs[2], inputs[3]
        batch_size, num_labels, seq_length, _ = logits.shape
        logits = self.reshape(logits, (batch_size * num_labels, seq_length, -1))
        ppl_list = []
        for index in range(batch_size * num_labels):
            # Logits at position t predict the token at position t + 1.
            sub_logits = logits[index][:-1, ::]
            sub_tokens = input_ids[index][1:]
            sub_mask = input_mask[index][1:]
            loss = self.loss(sub_logits, sub_tokens, sub_mask)
            ppl_list.append(float(loss.asnumpy()))  # smaller is better
        ppl_ms = ms.Tensor(ppl_list, dtype=ms.float32)
        ppl_ms = self.reshape(ppl_ms, (batch_size, num_labels))
        pred = self.cast(self.argmin(ppl_ms), ms.int32)
        label = self.reshape(labels, (-1,))
        # Store the count as a Python int: adding a float16 scalar to the running total
        # turned it into a float16 accumulator, which cannot count past 2048.
        cur_acc_num = int(self.cast(self.equal(pred, label), ms.float32).sum().asnumpy())
        self.num_data += batch_size
        self.total_acc_num += cur_acc_num

    def update(self, *inputs):
        """Update results for every batch."""
        if not self.pipeline_parallel:
            self.calculate_circle(*inputs)
        else:
            if not self.is_last_stage:
                # Ranks outside the last pipeline stage do not contribute to the metric.
                return
            if self.auto_parallel:
                # Temporarily switch to data-parallel mode and always restore the previous
                # context, even if the computation raises.
                previous_full_batch = ms.get_auto_parallel_context("full_batch")
                ms.set_auto_parallel_context(parallel_mode='data_parallel', full_batch=False)
                try:
                    self.calculate_circle(*inputs)
                finally:
                    ms.set_auto_parallel_context(parallel_mode=self.parallel_mode, full_batch=previous_full_batch,
                                                 pipeline_stages=self.pipeline_stages)
            else:
                self.calculate_circle(*inputs)
        print("Current data num is {}, total acc num is {}, ACC is {}".format(
            self.num_data, self.total_acc_num, "%.3f" % (self.total_acc_num / self.num_data)))

    def eval(self):
        """Compute the final result.

        Returns:
            dict or None: ``{"Acc": accuracy}``; ``None`` on ranks outside the last pipeline stage.
        """
        if self.pipeline_parallel and not self.is_last_stage:
            return None
        acc_rate = float(self.total_acc_num / self.num_data)
        result = {"Acc": acc_rate}
        print(f"Acc: {('%.3f' % result.get('Acc', 0))}, total_acc_num: {self.total_acc_num}, "
              f"total_num: {self.num_data}")
        return result


@MindFormerRegister.register(MindFormerModuleType.METRIC)
class EmF1Metric(nn.Metric):
    """
    Compute the Em/F1 scores of generated answers against reference answers.

    Em (exact match) counts a prediction as correct when it equals the label after
    lower-casing, stripping and removing punctuation. For example, for the question
    "河南的省会是哪里?" ("What is the capital of Henan?") with label "郑州市":
    prediction "郑州市" gives Em 100, "郑州市。" gives Em 100, and "郑州" gives Em 0.

    F1 is 2*precision*recall/(precision+recall), where
    precision = lcs_length/len(prediction_segments) and
    recall = lcs_length/len(label_segments). lcs_length is the length of the longest
    common *contiguous* run of segments (longest common substring) between the
    punctuation-free mixed segmentations of prediction and label.
    """

    # Punctuation treated as segment separators and removed before comparison.
    # Both ASCII and full-width (CJK) forms are listed so Chinese punctuation such as
    # "，" or "？" is handled like its ASCII counterpart.
    _PUNCTUATION = frozenset([
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=',
        ',', '。', '?', '!', '“', '”', ';', '’', '《', '》', '……', '…', '·', '、',
        '「', '」', '(', ')', '『', '』',
        '，', '：', '？', '！', '；', '（', '）', '－', '～',
    ])

    def __init__(self):
        super(EmF1Metric, self).__init__()
        self.gens = None
        self.labels = None
        self.metrics = None
        self.num_data = None

    def clear(self):
        """Clearing the internal evaluation result."""
        self.gens = []
        self.labels = []
        self.metrics = {
            'Em': 0.0,
            'F1': 0.0
        }
        self.num_data = 0

    def update(self, *inputs):
        """Accumulate one batch of generated texts and labels.

        Only the first line of each stripped generation is kept. A new list is built,
        so the caller's sequence is not modified.
        """
        gen, label = inputs[0], inputs[1]
        gen = [text.strip().split("\n")[0] for text in gen]
        print(f"pred is:\n {gen}\n", f"label is:\n {label}")
        self.gens.extend(gen)
        self.labels.extend(label)
        self.num_data += len(gen)
        result, _ = self.evaluate_pairs(gen, label)
        print("The F1/Em of this example is: ", result)
        if self.num_data % 10 == 0:
            result, current_count = self.evaluate_pairs(self.gens, self.labels)
            print(f"F1 score: {result.get('F1', 0)}, Em score: {result.get('Em', 0)}, current_count: {current_count}")

    def eval(self):
        """Compute the final result over all accumulated pairs."""
        result, total_count = self.evaluate_pairs(self.gens, self.labels)
        print(f"F1 score: {result.get('F1', 0)}, Em score: {result.get('Em', 0)}, total_count: {total_count}")
        return result

    def mixed_segmentation(self, in_str, rm_punc=False):
        """Segment text for LCS computation.

        Each CJK character and each punctuation mark becomes its own segment (punctuation
        is dropped when `rm_punc` is True); other runs of characters are cut with jieba.
        """
        in_str = str(in_str).lower().strip()
        segs_out = []
        temp_str = ""
        for char in in_str:
            if rm_punc and char in self._PUNCTUATION:
                continue
            if re.search(r'[\u4e00-\u9fa5]', char) or char in self._PUNCTUATION:
                if temp_str != "":
                    segs_out.extend(jieba.cut(temp_str))
                    temp_str = ""
                segs_out.append(char)
            else:
                temp_str += char
        if temp_str != "":
            segs_out.extend(jieba.cut(temp_str))
        return segs_out

    def remove_punctuation(self, in_str):
        """Lower-case and strip the input, then drop every punctuation character."""
        in_str = str(in_str).lower().strip()
        return ''.join(char for char in in_str if char not in self._PUNCTUATION)

    def find_lcs(self, s1, s2):
        """Find the longest common substring (contiguous run) of two sequences.

        Returns:
            tuple: ``(substring_of_s1, length)``.
        """
        m = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]
        mmax = 0
        p = 0
        for i, item1 in enumerate(s1):
            for j, item2 in enumerate(s2):
                if item1 == item2:
                    m[i + 1][j + 1] = m[i][j] + 1
                    if m[i + 1][j + 1] > mmax:
                        mmax = m[i + 1][j + 1]
                        p = i + 1
        return s1[p - mmax:p], mmax

    def calc_f1_score(self, answers, prediction):
        """Return the best segment-level F1 of `prediction` against any of `answers`."""
        prediction_segs = self.mixed_segmentation(prediction, rm_punc=True)
        f1_scores = []
        for ans in answers:
            ans_segs = self.mixed_segmentation(ans, rm_punc=True)
            _, lcs_len = self.find_lcs(ans_segs, prediction_segs)
            if lcs_len == 0:
                f1_scores.append(0)
                continue
            precision = 1.0 * lcs_len / len(prediction_segs)
            recall = 1.0 * lcs_len / len(ans_segs)
            f1_scores.append((2 * precision * recall) / (precision + recall))
        return max(f1_scores)

    def calc_em_score(self, answers, prediction):
        """Return 1 if `prediction` equals any answer after punctuation removal, else 0."""
        prediction_ = self.remove_punctuation(prediction)
        return 1 if any(self.remove_punctuation(ans) == prediction_ for ans in answers) else 0

    def evaluate_pairs(self, pred_, ans_):
        """Compute the F1/Em (in percent) over paired predictions and answers.

        Returns:
            tuple: ``(result, total_count)``; `result` is ``{'F1': ..., 'Em': ...}``, or an
            empty dict when there are no pairs.
        """
        f1 = 0
        em = 0
        total_count = 0
        for prediction, answer in zip(pred_, ans_):
            total_count += 1
            f1 += self.calc_f1_score([answer], prediction)
            em += self.calc_em_score([answer], prediction)
        if total_count > 0:
            f1_score = 100.0 * f1 / total_count
            em_score = 100.0 * em / total_count
            result = {'F1': f1_score, 'Em': em_score}
        else:
            print("total_count is zero")
            result = {}
        return result, total_count