# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TranslationPipeline"""
import os.path
from typing import Union, Optional
import mindspore
from mindspore import Tensor, Model
from ..auto_class import AutoProcessor, AutoModel
from ..mindformer_book import MindFormerBook
from .base_pipeline import BasePipeline
from ..tools.register import MindFormerRegister, MindFormerModuleType
from ..models import BaseModel, BaseTokenizer
# Public names exported by this module.
__all__ = ['TranslationPipeline']
@MindFormerRegister.register(MindFormerModuleType.PIPELINE, alias="translation")
class TranslationPipeline(BasePipeline):
    r"""Pipeline for Translation

    Args:
        model (Union[str, BaseModel, Model]): The model used to perform the task.
            The input could be a supported model name, a local directory,
            or a model instance inherited from BaseModel or Model.
        tokenizer (Optional[BaseTokenizer]): A tokenizer (None or Tokenizer)
            for text processing. When ``model`` is a string and no tokenizer
            is given, the tokenizer is taken from
            ``AutoProcessor.from_pretrained(model)``.

    Raises:
        TypeError: If the types of the input model or tokenizer are not correct.
        ValueError: If the input model is not in the support list, or if no
            tokenizer is available.

    Examples:
        >>> from mindformers.pipeline import TranslationPipeline
        >>> translator = TranslationPipeline("t5_small")
        >>> output = translator("abc")
    """
    # Model names accepted when ``model`` is passed as a string.
    _support_list = MindFormerBook.get_model_support_list()['t5']
    # Prefix of the key used in the dict returned by ``postprocess``.
    return_name = 'translation'

    def __init__(self, model: Union[str, BaseModel, Model],
                 tokenizer: Optional[BaseTokenizer] = None,
                 **kwargs):
        if isinstance(model, str):
            if model not in self._support_list and not os.path.isdir(model):
                raise ValueError(f"{model} is not supported by {self.__class__.__name__}, "
                                 f"please select from {self._support_list}.")
            if tokenizer is None:
                tokenizer = AutoProcessor.from_pretrained(model).tokenizer
            # Check the tokenizer first so an invalid tokenizer is reported
            # before the model is loaded.
            if not isinstance(tokenizer, BaseTokenizer):
                raise TypeError(f"tokenizer should be inherited from"
                                f" BaseTokenizer, but got {type(tokenizer)}.")
            model = AutoModel.from_pretrained(model)

        if not isinstance(model, (BaseModel, Model)):
            raise TypeError(f"model should be inherited from BaseModel or Model, but got type {type(model)}.")

        if tokenizer is None:
            raise ValueError(f"{self.__class__.__name__}"
                             " requires for a tokenizer.")

        super().__init__(model, tokenizer, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        r"""Sanitize Parameters

        Splits the keyword arguments into the parameter dicts used by the
        preprocess, forward and postprocess stages. Unrecognized keys are
        dropped.

        Args:
            pipeline_parameters (Optional[dict]): The parameter dict to be parsed.

        Return:
            A tuple ``(preprocess_params, forward_kwargs, postprocess_params)``;
            ``postprocess_params`` is always an empty dict.
        """
        preprocess_keys = ['keys']
        forward_key_name = ['top_k', 'top_p', 'do_sample', 'eos_token_id', 'repetition_penalty', 'max_length']

        preprocess_params = {name: pipeline_parameters[name]
                             for name in preprocess_keys if name in pipeline_parameters}
        forward_kwargs = {name: pipeline_parameters[name]
                          for name in forward_key_name if name in pipeline_parameters}
        postprocess_params = {}
        return preprocess_params, forward_kwargs, postprocess_params

    def preprocess(self, inputs: Union[str, dict, Tensor],
                   **preprocess_params):
        r"""The Preprocess For Translation

        Args:
            inputs (Union[str, dict, Tensor]): The text to be translated. For a
                dict, the text is read from the key named by
                ``preprocess_params['keys']['src_language']``, defaulting to
                ``'text'``.
            preprocess_params (dict): The parameter dict for preprocess.

        Return:
            A dict with the tokenized ``"input_ids"``.

        Raises:
            KeyError: If ``inputs`` is a dict that lacks the source text key.
        """
        if isinstance(inputs, dict):
            keys = preprocess_params.get('keys', None)
            default_src_language_name = 'text'
            feature_name = keys.get('src_language', default_src_language_name) if keys else default_src_language_name
            try:
                inputs = inputs[feature_name]
            except KeyError as err:
                # Same exception type as a plain lookup, but with a message
                # that tells the caller which key was expected.
                raise KeyError(f"The input dict has no key '{feature_name}' for the source text, "
                               f"available keys are {list(inputs.keys())}.") from err

        if isinstance(inputs, mindspore.Tensor):
            inputs = inputs.asnumpy().tolist()

        input_ids = self.tokenizer(inputs, return_tensors=None)["input_ids"]
        return {"input_ids": input_ids}

    def forward(self, model_inputs: dict,
                **forward_params):
        r"""The Forward Process of Model

        Args:
            model_inputs (dict): The output of preprocess.
            forward_params (dict): The parameter dict for model forward.

        Return:
            A dict with the generated ``"output_ids"``.
        """
        # NOTE(review): this removes a parameter whose name is the literal
        # string "None"; the reason is not visible here - confirm it is needed.
        forward_params.pop("None", None)
        input_ids = model_inputs["input_ids"]
        output_ids = self.network.generate(input_ids, **forward_params)
        return {"output_ids": output_ids}

    def postprocess(self, model_outputs: dict,
                    **postprocess_params):
        r"""Postprocess

        Args:
            model_outputs (dict): Outputs of forward process.
            postprocess_params (dict): The parameter dict for postprocess;
                not used by this method.

        Return:
            A single-element list holding a dict that maps
            ``return_name + '_text'`` to the decoded translation results.
        """
        outputs = self.tokenizer.decode(model_outputs["output_ids"], skip_special_tokens=True)
        return [{self.return_name + '_text': outputs}]