# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Keyword Generation Dataset."""
import copy
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.transforms.c_transforms as C
import numpy as np
from mindformers.models.base_tokenizer import BaseTokenizer
from mindformers.dataset.base_dataset import BaseDataset
from mindformers.dataset.dataloader import build_dataset_loader
from mindformers.models.build_tokenizer import build_tokenizer
from mindformers.tools.logger import logger
from mindformers.tools.register import MindFormerModuleType, MindFormerRegister
from mindformers.version_control import get_dataset_map
@MindFormerRegister.register(MindFormerModuleType.DATASET)
class KeyWordGenDataset(BaseDataset):
    """Keyword generation dataset.

    Instantiating this class does not return a ``KeyWordGenDataset`` object:
    ``__new__`` builds a dataset pipeline (loader -> tokenize map -> batch ->
    repeat -> int32 cast) and returns that pipeline object.

    Args:
        dataset_config (dict): Dataset configuration. The fields read by this
            class are ``tokenizer``, ``ignore_pad_token_for_loss``,
            ``max_source_length``, ``max_target_length``, ``data_loader``
            (with ``type``, ``phase``, ``version`` and either ``dataset_dir``
            or ``dataset_files``), ``batch_size``, ``drop_remainder``,
            ``num_parallel_workers``, ``repeat`` and ``input_columns``.

    Returns:
        The batched, repeated dataset whose ``input_columns`` are cast to int32.

    Examples:
        >>> from mindformers.dataset.dataloader.adgen_dataloader import ADGenDataLoader
        >>> from mindformers.dataset import build_dataset
        >>> from mindformers import MindFormerConfig
        >>> cfg = MindFormerConfig("./configs/glm/run_glm_6b_finetune.yaml")
        >>> dataset = build_dataset(cfg.eval_dataset_task)
        >>> for item in dataset.create_dict_iterator():
        ...     print(item)
        ...     break
    """

    def __new__(cls, dataset_config: dict = None):
        logger.info("Now Create Keyword Generation Dataset.")
        cls.init_dataset_config(dataset_config)

        # Sharding information comes from the environment and is validated by the base class.
        rank_id = int(os.getenv("RANK_ID", "0"))
        device_num = int(os.getenv("RANK_SIZE", "1"))
        rank_id, device_num = cls._check_device_rank_for_parallel(rank_id, device_num)
        dataset_config.rank_id = rank_id
        dataset_config.device_num = device_num

        # Accept either a ready tokenizer instance or a config to build one from.
        if isinstance(dataset_config.tokenizer, BaseTokenizer):
            cls.tokenizer = dataset_config.tokenizer
        else:
            cls.tokenizer = build_tokenizer(dataset_config.tokenizer)

        # Stored on the class because the map functions below are classmethods.
        cls.ignore_pad_token_for_loss = dataset_config.ignore_pad_token_for_loss
        cls.max_source_length = dataset_config.max_source_length
        cls.max_target_length = dataset_config.max_target_length
        cls.max_seq_length = cls.max_source_length + cls.max_target_length
        cls.phase = dataset_config.data_loader.phase
        cls.version = dataset_config.data_loader.version

        if dataset_config.data_loader.type != 'MindDataset':
            dataset = cls._process_raw_text_data(dataset_config)
        else:
            dataset = cls._process_mindrecord_data(dataset_config)

        dataset = dataset.batch(dataset_config.batch_size,
                                drop_remainder=dataset_config.drop_remainder,
                                num_parallel_workers=dataset_config.num_parallel_workers)
        dataset = dataset.repeat(dataset_config.repeat)

        type_cast_op = C.TypeCast(mstype.int32)
        for input_arg in dataset_config.input_columns:
            dataset = get_dataset_map(dataset, type_cast_op, input_columns=input_arg)
        return dataset

    @classmethod
    def _tokenizer_map(cls, dataset, tokenizer_config):
        """Map the tokenizer over the ``prompt`` and ``answer`` columns.

        Args:
            dataset: Dataset exposing the columns ``prompt`` and ``answer``.
            tokenizer_config: Tokenizer configuration; only used for logging here.

        Returns:
            The dataset with the tokenize function mapped and projected onto
            the output columns when ``cls.phase`` is ``"train"`` or ``"eval"``;
            for any other phase the dataset is returned unchanged.
        """
        phase = cls.phase
        version = cls.version if cls.version else 1
        logger.info("Start tokenize on the dataset using tokenizer: %s", tokenizer_config)

        if version == 2:
            train_dataset_function = cls.train_dataset_functionv2
            train_output_columns = ["input_ids", "labels"]
            eval_dataset_function = cls.eval_dataset_functionv2
        else:
            train_dataset_function = cls.train_dataset_function
            train_output_columns = ["input_ids", "labels", "position_ids", "attention_mask"]
            eval_dataset_function = cls.eval_dataset_function

        input_columns = ["prompt", "answer"]
        eval_output_columns = ["input_ids", "labels"]

        # Avoid to_json error when summary monitor is opened
        def train_dataset_func(prompt, answer):
            return train_dataset_function(prompt, answer)

        def eval_dataset_func(prompt, answer):
            return eval_dataset_function(prompt, answer)

        if phase == "train":
            dataset = get_dataset_map(dataset,
                                      train_dataset_func,
                                      input_columns=input_columns,
                                      output_columns=train_output_columns)
            dataset = dataset.project(columns=train_output_columns)
        elif phase == "eval":
            dataset = get_dataset_map(dataset,
                                      eval_dataset_func,
                                      input_columns=input_columns,
                                      output_columns=eval_output_columns)
            dataset = dataset.project(columns=eval_output_columns)
        return dataset

    @classmethod
    def _process_raw_text_data(cls, dataset_config):
        """Build the dataset from raw text files and tokenize it.

        Args:
            dataset_config: Dataset configuration; ``data_loader`` must contain
                ``dataset_dir``, and ``device_num`` / ``rank_id`` must be set.

        Returns:
            The tokenized dataset.
        """
        # Work on a copy so that popping ``dataset_dir`` does not strip it from the
        # caller's config (which would break a second build from the same config).
        data_loader_config = copy.deepcopy(dataset_config.data_loader)
        dataset_dir = data_loader_config.pop("dataset_dir")
        dataset = build_dataset_loader(
            data_loader_config, default_args={'dataset_dir': dataset_dir,
                                              'num_shards': dataset_config.device_num,
                                              'shard_id': dataset_config.rank_id})
        dataset = cls._tokenizer_map(dataset, dataset_config.tokenizer)
        return dataset

    @classmethod
    def _process_mindrecord_data(cls, dataset_config):
        """Build the dataset from MindRecord files.

        ``data_loader.dataset_dir`` may be a directory (searched recursively for
        ``*.mindrecord`` files) or a single ``.mindrecord`` file; otherwise
        ``data_loader.dataset_files`` is used.

        Args:
            dataset_config: Dataset configuration (it is deep-copied, the
                caller's object is left untouched).

        Returns:
            The dataset created by ``build_dataset_loader``.

        Raises:
            ValueError: If neither ``dataset_dir`` nor ``dataset_files`` is
                configured, or if ``dataset_dir`` yields no ``.mindrecord`` file.
        """
        dataset_config = copy.deepcopy(dataset_config)
        dataset_files = []
        if dataset_config.data_loader.dataset_dir:
            data_dir = dataset_config.data_loader.pop("dataset_dir")
            if os.path.isdir(data_dir):
                for root, _, file_names in os.walk(data_dir):
                    for file_name in file_names:
                        if file_name.endswith(".mindrecord"):
                            dataset_files.append(os.path.join(root, file_name))
                dataset_files.sort()
            elif data_dir.endswith(".mindrecord"):
                dataset_files = data_dir
            if not dataset_files:
                # Fail early with a clear message instead of handing an empty list to the loader.
                raise ValueError(f"No .mindrecord file found in dataset_dir: {data_dir}.")
        elif dataset_config.data_loader.dataset_files:
            dataset_files = dataset_config.data_loader.dataset_files
            if isinstance(dataset_files, (list, tuple)):
                dataset_files = list(dataset_files)
        else:
            raise ValueError(f"data_loader must contain dataset_dir or dataset_files,"
                             f"but get {dataset_config.data_loader}.")

        logger.info("Using args %s to instance the dataset.", dataset_config.data_loader)
        dataset = build_dataset_loader(
            dataset_config.data_loader, default_args={'dataset_files': dataset_files,
                                                      'num_shards': dataset_config.device_num,
                                                      'shard_id': dataset_config.rank_id,
                                                      'columns_list': dataset_config.input_columns})
        return dataset

    @classmethod
    def train_dataset_function(cls, prompt, answer):
        """Tokenize one (prompt, answer) pair for training (data loader version 1).

        Args:
            prompt: Array-like object whose ``tolist()`` yields the prompt text.
            answer: Array-like object whose ``tolist()`` yields the answer text.

        Returns:
            tuple: ``(input_ids, label, position_ids, attention_mask)`` where
            ``input_ids`` and ``label`` are lists of length
            ``cls.max_seq_length`` and the last two are numpy arrays produced
            by :meth:`create_position_ids` and :meth:`get_masks`.
        """
        prompt, answer = prompt.tolist(), answer.tolist()
        pad_token_id = cls.tokenizer.pad_token_id
        bos_token_id = cls.tokenizer.bos_token_id

        prompt_ids = cls.tokenizer.encode(text=prompt, add_special_tokens=False)
        answer_ids = cls.tokenizer.encode(text=answer, add_special_tokens=False)

        # Truncate so that the special tokens added below still fit.
        if len(prompt_ids) > cls.max_source_length - 1:
            prompt_ids = prompt_ids[: cls.max_source_length - 1]
        if len(answer_ids) > cls.max_target_length - 2:
            answer_ids = answer_ids[: cls.max_target_length - 2]

        input_ids = cls.tokenizer.build_inputs_with_special_tokens(prompt_ids, answer_ids)
        context_length = input_ids.index(bos_token_id)
        mask_position = context_length - 1
        label = [-100] * context_length + input_ids[mask_position + 2:]  # +1 for logits shift

        pad_len = cls.max_seq_length - len(input_ids)
        input_ids = input_ids + [pad_token_id] * pad_len
        label = label + [pad_token_id] * (pad_len + 1)  # +1 for logits shift
        if cls.ignore_pad_token_for_loss:
            label = [(token if token != pad_token_id else -100) for token in label]

        # Use the same BOS id that located ``context_length`` above, so that the mask
        # and position ids stay consistent with the tokenizer actually in use.
        ids_array = np.array(input_ids)
        position_ids = cls.create_position_ids(ids_array, bos_token_id=bos_token_id)
        attention_mask = cls.get_masks(ids_array, bos_token_id=bos_token_id)
        return input_ids, label, position_ids, attention_mask

    @classmethod
    def train_dataset_functionv2(cls, prompt, answer):
        """Tokenize one (prompt, answer) pair for training (data loader version 2).

        Args:
            prompt: Array-like object whose ``tolist()`` yields the prompt text.
            answer: Array-like object whose ``tolist()`` yields the answer text.

        Returns:
            tuple: ``(input_ids, labels)``, both lists of length
            ``cls.max_source_length + cls.max_target_length + 1``.
        """
        max_seq_length = cls.max_source_length + cls.max_target_length + 1
        pad_token_id = cls.tokenizer.pad_token_id
        eos_token_id = cls.tokenizer.eos_token_id

        prompt, answer = prompt.tolist(), answer.tolist()
        history = None
        prompt = cls.tokenizer.build_prompt(prompt, history)
        prompt_ids = cls.tokenizer.encode(text=prompt, add_special_tokens=True,
                                          max_length=cls.max_source_length)
        answer_ids = cls.tokenizer.encode(text=answer, add_special_tokens=False,
                                          max_length=cls.max_target_length)

        if len(prompt_ids) > cls.max_source_length - 1:
            prompt_ids = prompt_ids[: cls.max_source_length - 1]
        if len(answer_ids) > cls.max_target_length - 2:
            answer_ids = answer_ids[: cls.max_target_length - 2]

        context_length = len(prompt_ids)
        input_ids = prompt_ids + answer_ids + [eos_token_id]
        labels = [pad_token_id] * context_length + answer_ids[1:] + [eos_token_id]

        pad_len = max_seq_length - len(input_ids)
        input_ids = input_ids + [pad_token_id] * pad_len
        labels = labels + [pad_token_id] * (pad_len + 1)  # +1 for logits shift
        if cls.ignore_pad_token_for_loss:
            labels = [(token if token != pad_token_id else -100) for token in labels]
        return input_ids, labels

    @classmethod
    def eval_dataset_functionv2(cls, prompt, answer):
        """Tokenize one (prompt, answer) pair for evaluation (data loader version 2).

        Args:
            prompt: Array-like object whose ``tolist()`` yields the prompt text.
            answer: Array-like object whose ``tolist()`` yields the answer text.

        Returns:
            tuple: ``(input_ids, label)``, padded with the pad token up to
            ``cls.max_source_length`` and ``cls.max_target_length``.
        """
        pad_token_id = cls.tokenizer.pad_token_id
        prompt, answer = prompt.tolist(), answer.tolist()
        history = None
        prompt = cls.tokenizer.build_prompt(prompt, history)

        # NOTE(review): this truncation is applied to the text (characters), not to token ids.
        if len(prompt) > cls.max_source_length - 1:
            prompt = prompt[: cls.max_source_length - 1]
        if len(answer) > cls.max_target_length - 1:
            answer = answer[: cls.max_target_length - 1]

        input_ids = cls.tokenizer.encode(text=prompt, add_special_tokens=True,
                                         max_length=cls.max_source_length)
        label = cls.tokenizer.encode(text=answer, add_special_tokens=True,
                                     max_length=cls.max_target_length)

        pad_len = cls.max_source_length - len(input_ids)
        input_ids = input_ids + [pad_token_id] * pad_len
        pad_len = cls.max_target_length - len(label)
        label = label + [pad_token_id] * pad_len
        return input_ids, label

    @classmethod
    def eval_dataset_function(cls, prompt, answer):
        """Tokenize one (prompt, answer) pair for evaluation (data loader version 1).

        Args:
            prompt: Array-like object whose ``tolist()`` yields the prompt text.
            answer: Array-like object whose ``tolist()`` yields the answer text.

        Returns:
            tuple: ``(input_ids, label)``, padded with the pad token up to
            ``cls.max_source_length`` and ``cls.max_target_length``.
        """
        pad_token_id = cls.tokenizer.pad_token_id
        prompt, answer = prompt.tolist(), answer.tolist()

        # NOTE(review): truncation happens on the text (characters) before encoding; if the
        # encoded ids are still longer than the maximum, no padding is added and the ids
        # keep their longer length - confirm this cannot happen with the tokenizer in use.
        if len(prompt) > cls.max_source_length - 2:
            prompt = prompt[: cls.max_source_length - 2]
        if len(answer) > cls.max_target_length - 2:
            answer = answer[: cls.max_target_length - 2]

        input_ids = cls.tokenizer.encode(text=prompt, add_special_tokens=True)
        label = cls.tokenizer.encode(text=answer, add_special_tokens=True)

        pad_len = cls.max_source_length - len(input_ids)
        input_ids = input_ids + [pad_token_id] * pad_len
        pad_len = cls.max_target_length - len(label)
        label = label + [pad_token_id] * pad_len
        return input_ids, label

    @classmethod
    def get_masks(cls, input_ids, bos_token_id=130004):
        """Generate the attention mask from the input ids.

        Args:
            input_ids (numpy.ndarray): 1-D array of token ids.
            bos_token_id (int): Id of the BOS token. Default: 130004.

        Returns:
            numpy.ndarray: float32 array of shape ``(1, seq_length, seq_length)``.
            It is the logical negation of a lower-triangular matrix whose
            columns before the BOS position are fully set, i.e. 1.0 marks the
            entries that were 0 in that pattern.
        """
        seq_length = input_ids.shape[0]
        # input_ids is expected to contain exactly one bos_token_id
        context_lengths = np.argwhere(input_ids == bos_token_id)[:, -1]

        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.float32))
        for context_length in context_lengths:
            attention_mask[:, :context_length] = 1
        attention_mask = np.logical_not(attention_mask.astype(np.bool_))
        attention_mask = attention_mask.astype(np.float32)
        return np.expand_dims(attention_mask, 0)

    @classmethod
    def get_position_ids(cls, input_ids, mask_positions, use_gmasks=None,
                         bos_token_id=130004, position_encoding_2d=True):
        """Generate position ids from the input ids and mask positions.

        Args:
            input_ids (numpy.ndarray): 1-D array of token ids.
            mask_positions: Indexable collection holding, per BOS token, the
                position value assigned from the BOS position onwards.
            use_gmasks: Per-BOS flags, only read when ``position_encoding_2d``
                is False; a truthy flag leaves the positions untouched.
                Default: ``[False]``.
            bos_token_id (int): Id of the BOS token. Default: 130004.
            position_encoding_2d (bool): Whether to return 2D position ids.
                Default: True.

        Returns:
            numpy.ndarray: int64 array of shape ``(2, seq_length)`` (position
            ids stacked on block position ids) when ``position_encoding_2d``
            is True, otherwise of shape ``(seq_length,)``.

        Raises:
            ValueError: If ``position_encoding_2d`` is True and ``input_ids``
                does not contain exactly one ``bos_token_id``.
        """
        seq_length = input_ids.shape[0]
        if use_gmasks is None:
            use_gmasks = [False]
        # input_ids is expected to contain exactly one bos_token_id
        context_lengths = np.argwhere(input_ids == bos_token_id)[:, -1]

        position_ids = np.arange(seq_length, dtype=np.int64)
        if position_encoding_2d:
            if len(context_lengths) != 1:
                # Stacking below only works for a single BOS token; say so explicitly.
                raise ValueError(f"input_ids must contain exactly one bos_token_id ({bos_token_id}), "
                                 f"but found {len(context_lengths)}.")
            context_length = context_lengths[0]
            position_ids[context_length:] = mask_positions[0]
            # Zeros over the context, then 1, 2, ... from the BOS position onwards.
            block_position_ids = np.concatenate((
                np.zeros(context_length, dtype=np.int64),
                np.arange(seq_length - context_length, dtype=np.int64) + 1
            ))
            position_ids = np.stack((position_ids, block_position_ids), axis=0)
        else:
            for i, context_length in enumerate(context_lengths):
                if not use_gmasks[i]:
                    position_ids[context_length:] = mask_positions[i]
        return position_ids

    @classmethod
    def create_position_ids(cls, input_ids, gmask_token_id=130001, bos_token_id=130004):
        """Generate 2D position ids from the input ids.

        Args:
            input_ids (numpy.ndarray): 1-D array of token ids.
            gmask_token_id (int): Id of the gMASK token. Default: 130001.
            bos_token_id (int): Id of the BOS token, forwarded to
                :meth:`get_position_ids`. Default: 130004.

        Returns:
            numpy.ndarray: The result of :meth:`get_position_ids` with its
            default ``position_encoding_2d=True``.
        """
        # input_ids is expected to contain exactly one gMASK token
        mask_positions = np.argwhere(input_ids == gmask_token_id)[:, -1]
        # Every located mask is a gMASK, hence one truthy flag per mask position.
        use_gmasks = [True] * len(mask_positions)
        return cls.get_position_ids(input_ids, mask_positions=mask_positions,
                                    use_gmasks=use_gmasks, bos_token_id=bos_token_id)