Source code for mindformers.pipeline.base_pipeline

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file refers to the project:
# https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/base.py
# ============================================================================

"""BasePipeline"""
from abc import ABC, abstractmethod
from typing import Optional, Union
import numpy as np

from tqdm import tqdm
from mindspore import Tensor, Model
from mindspore.dataset import (
    GeneratorDataset, VisionBaseDataset,
    SourceDataset, MappableDataset
)
from mindspore.dataset.engine.datasets import BatchDataset, RepeatDataset

from mindformers.tools import logger
from mindformers.mindformer_book import print_dict
from ..auto_class import AutoModel
from ..models import BaseModel, BaseTokenizer, BaseImageProcessor


class BasePipeline(ABC):
    r"""Base Pipeline For All Task Pipelines

    Args:
        model (Union[str, BaseModel, Model]): The model used to perform the task. It can be
            a supported model name (a member of ``_support_list``, loaded with
            ``AutoModel.from_pretrained``), a model instance inherited from BaseModel, or a
            ``mindspore.Model`` whose ``predict_network`` is used.
        tokenizer (Optional[BaseTokenizer]): The tokenizer of the model. It can be None if
            the model does not need a tokenizer.
        image_processor (Optional[BaseImageProcessor]): The image processor of the model. It
            can be None if the model does not need an image processor.
        **kwargs: Pipeline parameters. ``batch_size`` is consumed here as the default batch
            size for dataset inputs. The remaining keys are parsed by
            ``_sanitize_parameters`` into preprocess/forward/postprocess parameters.

    Raises:
        TypeError: If ``model`` is a name not in ``_support_list``, or is neither a str, a
            BaseModel instance nor a Model instance.
    """
    # Model names accepted as the ``model`` argument; empty in the base class.
    _support_list = {}

    def __init__(self, model: Union[str, BaseModel, Model],
                 tokenizer: Optional[BaseTokenizer] = None,
                 image_processor: Optional[BaseImageProcessor] = None,
                 **kwargs):
        super().__init__()
        self.model = model
        if isinstance(model, str):
            if model not in self._support_list:
                # The old generic message ("model should be str ... but got type str")
                # contradicted itself for an unsupported name.
                raise TypeError(f"model name '{model}' is not supported by "
                                f"{type(self).__name__}; use a supported model name "
                                f"(see show_support_list()) or pass a model instance "
                                f"inherited from BaseModel or Model.")
            self.network = AutoModel.from_pretrained(model)
        elif isinstance(model, BaseModel):
            self.network = model
        elif isinstance(model, Model):
            self.network = model.predict_network
        else:
            raise TypeError(f"model should be str or inherited from BaseModel or Model, "
                            f"but got type {type(model)}.")

        self.tokenizer = tokenizer
        self.image_processor = image_processor
        # Number of __call__ invocations; used to hint users toward dataset inputs.
        self.call_count = 0
        # Pop batch_size BEFORE parsing the remaining parameters so that it is not
        # forwarded to _sanitize_parameters (matching __call__, where batch_size is an
        # explicit argument and never reaches _sanitize_parameters).
        self._batch_size = kwargs.pop("batch_size", None)
        self._preprocess_params, self._forward_params, \
            self._postprocess_params = self._sanitize_parameters(**kwargs)

    def __call__(self, inputs: Union[GeneratorDataset, list, str, Tensor, np.ndarray],
                 batch_size: Optional[int] = None,
                 **kwargs):
        r"""Call Method

        Runs the pipeline on ``inputs``. Routing depends on the input type:

        - Dataset inputs are batched (unless already a BatchDataset/RepeatDataset) and
          processed batch by batch with ``run_single``.
        - List inputs are processed element by element with ``run_multi``.
        - Any other input is processed as a single sample with ``run_single``.

        Args:
            inputs (Union[GeneratorDataset, list, str, Tensor, np.ndarray]): The inputs of
                the pipeline; the accepted type depends on the task.
            batch_size (Optional[int]): The batch size for a dataset input. If None, the
                ``batch_size`` given at construction time is used, or 1 if none was given.
                For non-dataset inputs the batch size is always forced to 1.
            **kwargs: Per-call pipeline parameters. They are parsed by
                ``_sanitize_parameters`` and override the parameters given at construction.

        Returns:
            The postprocessed outputs. For dataset and list inputs, the per-sample results
            are concatenated into one list.
        """
        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(**kwargs)
        # Per-call parameters take precedence over the ones set in __init__.
        preprocess_params = {**self._preprocess_params, **preprocess_params}
        forward_params = {**self._forward_params, **forward_params}
        postprocess_params = {**self._postprocess_params, **postprocess_params}

        is_dataset = isinstance(inputs, (
            GeneratorDataset, VisionBaseDataset, MappableDataset, SourceDataset))
        is_list = isinstance(inputs, list)

        self.call_count += 1
        if self.call_count == 20 and not is_dataset:
            # Repeated per-sample calls are slow; suggest a dataset input once.
            logger.info("You seem to be using the pipeline sequentially for"
                        " numerous samples. In order to maximize efficiency"
                        " please set input a mindspore.dataset.GeneratorDataset.")

        if batch_size is None:
            batch_size = 1 if self._batch_size is None else self._batch_size
        if batch_size > 1 and not is_dataset:
            batch_size = 1
            logger.info("batch_size is set to 1 for non-dataset inputs.")

        if is_dataset:
            logger.info("dataset is processing.")
            # Do not re-batch a dataset that is already batched/repeated by the caller.
            if not isinstance(inputs, (BatchDataset, RepeatDataset)):
                inputs = inputs.batch(batch_size)
            outputs = []
            for items in tqdm(inputs.create_dict_iterator()):
                outputs.extend(self.run_single(items, preprocess_params,
                                               forward_params, postprocess_params))
        elif is_list:
            outputs = self.run_multi(inputs, preprocess_params,
                                     forward_params, postprocess_params)
        else:
            outputs = self.run_single(inputs, preprocess_params,
                                      forward_params, postprocess_params)
        return outputs

    @abstractmethod
    def _sanitize_parameters(self, **pipeline_parameters):
        r"""Sanitize Parameters

        Splits pipeline parameters into the three parameter dicts used by ``preprocess``,
        ``forward`` and ``postprocess``.

        Args:
            pipeline_parameters (Optional[dict]): The parameter dict to be parsed.

        Returns:
            A 3-tuple ``(preprocess_params, forward_params, postprocess_params)`` of dicts,
            as unpacked by ``__init__`` and ``__call__``.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError("_sanitize_parameters not implemented")

    def run_single(self, inputs: Union[dict, str, np.ndarray, Tensor],
                   preprocess_params: dict,
                   forward_params: dict,
                   postprocess_params: dict):
        r"""Run Single Method

        Runs one preprocess -> forward -> postprocess pass for the task.

        Args:
            inputs (Union[dict, str, np.ndarray, Tensor]): The inputs of the pipeline; the
                accepted type depends on the task.
            preprocess_params (dict): The parameter dict for preprocess.
            forward_params (dict): The parameter dict for the model forward process.
            postprocess_params (dict): The parameter dict for postprocess.

        Returns:
            The result of ``postprocess``.
        """
        model_inputs = self.preprocess(inputs, **preprocess_params)
        model_outputs = self.forward(model_inputs, **forward_params)
        outputs = self.postprocess(model_outputs, **postprocess_params)
        return outputs

    def run_multi(self, inputs: Union[list, tuple],
                  preprocess_params: dict,
                  forward_params: dict,
                  postprocess_params: dict):
        r"""Run Multiple Method

        Runs ``run_single`` on every element of an iterable input.

        Args:
            inputs (Union[list, tuple, iterator]): The iterable input for the pipeline.
            preprocess_params (dict): The parameter dict for preprocess.
            forward_params (dict): The parameter dict for the model forward process.
            postprocess_params (dict): The parameter dict for postprocess.

        Returns:
            list: The per-element results concatenated into one list (each ``run_single``
            result is expected to be iterable).
        """
        outputs = []
        for item in inputs:
            outputs.extend(self.run_single(item, preprocess_params,
                                           forward_params, postprocess_params))
        return outputs

    @abstractmethod
    def preprocess(self, inputs: Union[dict, str, np.ndarray, Tensor],
                   **preprocess_params):
        r"""The Preprocess For Task

        Args:
            inputs (Union[dict, str, np.ndarray, Tensor]): The inputs of the pipeline; the
                accepted type depends on the task.
            preprocess_params (dict): The parameter dict for preprocess.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError("preprocess not implemented.")

    @abstractmethod
    def forward(self, model_inputs: Union[dict, str, np.ndarray, Tensor],
                **forward_params):
        r"""The Forward Process of Model

        Args:
            model_inputs (Union[dict, str, np.ndarray, Tensor]): The output of preprocess;
                its type depends on the task.
            forward_params (dict): The parameter dict for the model forward.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError("forward not implemented.")

    @abstractmethod
    def postprocess(self, model_outputs: Union[dict, str, np.ndarray, Tensor],
                    **postprocess_params):
        r"""The Postprocess of Task

        Args:
            model_outputs (Union[dict, str, np.ndarray, Tensor]): The output of the model
                forward; its type depends on the task.
            postprocess_params (dict): The parameter dict for postprocess.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError("postprocess not implemented.")

    @classmethod
    def show_support_list(cls):
        """Log the name of this pipeline class and print its supported model list."""
        logger.info("support list of %s is:", cls.__name__)
        print_dict(cls._support_list)