mindformers.models.base_processor 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
BaseProcessor
"""
import os
import shutil

import yaml

from ..mindformer_book import print_path_or_list, MindFormerBook
from .build_processor import build_processor
from .base_tokenizer import BaseTokenizer
from ..tools import logger
from ..tools.register import MindFormerConfig


class BaseImageProcessor:
    """
    Common parent class of every image processor.

    All keyword arguments given at construction time are kept in
    ``self.config``. Calling an instance forwards to ``preprocess``,
    which each subclass has to override.

    Examples:
        >>> from mindspore.dataset.vision.transforms import CenterCrop
        >>> from mindformers.models.base_processor import BaseImageProcessor
        >>> image_resolution = 224
        >>> class MyImageProcessor(BaseImageProcessor):
        ...     def __init__(self, image_resolution):
        ...         super(MyImageProcessor, self).__init__(image_resolution=image_resolution)
        ...         self.center_crop = CenterCrop(image_resolution)
        ...
        ...     def preprocess(self, images, **kwargs):
        ...         res = []
        ...         for image in images:
        ...             image = self.center_crop(image)
        ...             res.append(image)
        ...         return res
        ...
        >>> my_image_processor = MyImageProcessor(image_resolution)
        >>> output = my_image_processor(image)
    """

    def __init__(self, **kwargs):
        # Shallow copy of the construction arguments.
        self.config = dict(kwargs)

    def __call__(self, image_data, **kwargs):
        """Run ``preprocess`` on ``image_data``, passing extra keyword arguments through."""
        processed = self.preprocess(image_data, **kwargs)
        return processed

    def preprocess(self, images, **kwargs):
        """Transform ``images``; the base implementation always raises NotImplementedError."""
        raise NotImplementedError("Each image processor must implement its own preprocess method")
class BaseAudioProcessor:
    """
    Common parent class of every audio processor.

    All keyword arguments given at construction time are kept in
    ``self.config``. Calling an instance forwards to ``preprocess``,
    which each subclass has to override.

    Examples:
        >>> from mindspore.dataset.audio import AllpassBiquad
        >>> from mindformers.models.base_processor import BaseAudioProcessor
        >>> sample_rate = 44100
        >>> central_freq = 200.0
        >>> class MyAudioProcessor(BaseAudioProcessor):
        ...     def __init__(self, sample_rate, central_freq):
        ...         super(MyAudioProcessor, self).__init__(sample_rate=sample_rate,
        ...                                                central_freq=central_freq)
        ...         self.all_pass_biquad = AllpassBiquad(sample_rate, central_freq)
        ...
        ...     def preprocess(self, audio_data, **kwargs):
        ...         res = []
        ...         for audio in audio_data:
        ...             audio = self.all_pass_biquad(audio)
        ...             res.append(audio)
        ...         return res
        ...
        >>> my_audio_processor = MyAudioProcessor(sample_rate, central_freq)
        >>> output = my_audio_processor(audio)
    """

    def __init__(self, **kwargs):
        # Shallow copy of the construction arguments.
        self.config = dict(kwargs)

    def __call__(self, audio_data, **kwargs):
        """Run ``preprocess`` on ``audio_data``, passing extra keyword arguments through."""
        processed = self.preprocess(audio_data, **kwargs)
        return processed

    def preprocess(self, audio_data, **kwargs):
        """Transform ``audio_data``; the base implementation always raises NotImplementedError."""
        raise NotImplementedError("Each audio processor must implement its own preprocess method")
class BaseProcessor:
    """
    Base processor that bundles an image processor, an audio processor and a tokenizer.

    The keyword arguments given at construction time are stored in
    ``self.config`` (used by ``save_pretrained``); the well-known ones
    (``image_processor``, ``audio_processor``, ``tokenizer``, ``max_length``,
    ``padding``, ``return_tensors``) are also exposed as attributes.

    Examples:
        >>> from mindformers.mindformer_book import MindFormerBook
        >>> from mindformers.models.base_processor import BaseProcessor
        >>> class MyProcessor(BaseProcessor):
        ...     _support_list = MindFormerBook.get_processor_support_list()['my_model']
        ...
        ...     def __init__(self, image_processor=None, audio_processor=None, tokenizer=None, return_tensors='ms'):
        ...         super(MyProcessor, self).__init__(
        ...             image_processor=image_processor,
        ...             audio_processor=audio_processor,
        ...             tokenizer=tokenizer,
        ...             return_tensors=return_tensors)
        ...
        >>> myprocessor = MyProcessor(image_processor, audio_processor, tokenizer)
        >>> output = myprocessor(image_input=image, text_input=text)
    """
    _support_list = []
    # Indices used by from_pretrained when splitting a model name:
    # "vit_base_p16".split('_')[_model_type] and
    # "mindspore/vit_base_p16".split('/')[_model_name].
    _model_type = 0
    _model_name = 1

    def __init__(self, **kwargs):
        # Keep a copy of every construction argument before popping the
        # known ones; save_pretrained serializes this dict.
        self.config = {}
        self.config.update(kwargs)
        self.image_processor = kwargs.pop("image_processor", None)
        self.audio_processor = kwargs.pop("audio_processor", None)
        self.tokenizer = kwargs.pop("tokenizer", None)
        self.max_length = kwargs.pop("max_length", None)
        self.padding = kwargs.pop("padding", False)
        self.return_tensors = kwargs.pop("return_tensors", None)

    def __call__(self, image_input=None, text_input=None):
        """
        Process image and/or text input.

        Args:
            image_input: data passed to ``self.image_processor``; skipped when None
                or when no image processor is set.
            text_input: a str or a batch of str passed to ``self.tokenizer``;
                skipped when None or when no tokenizer is set.

        Returns:
            A dict that may contain the keys ``'image'`` (image processor output)
            and ``'text'`` (the tokenizer's ``"input_ids"``).

        Raises:
            TypeError: if the image processor or tokenizer does not inherit from
                BaseImageProcessor / BaseTokenizer respectively.
        """
        output = {}

        if image_input is not None and self.image_processor:
            if not isinstance(self.image_processor, BaseImageProcessor):
                raise TypeError(f"image_processor should inherit from the BaseImageProcessor,"
                                f" but got {type(self.image_processor)}.")
            output['image'] = self.image_processor(image_input)

        if text_input is not None and self.tokenizer:
            if not isinstance(self.tokenizer, BaseTokenizer):
                raise TypeError(f"tokenizer should inherited from the BaseTokenizer,"
                                f" but got {type(self.tokenizer)}.")
            # Format the input into a batch
            if isinstance(text_input, str):
                text_input = [text_input]
            output['text'] = self.tokenizer(text_input, return_tensors=self.return_tensors,
                                            max_length=self.max_length,
                                            padding=self.padding)["input_ids"]

        return output

    def save_pretrained(self, save_directory=None, save_name="mindspore_model"):
        """
        Save the processor config into ``<save_directory>/<save_name>.yaml``.

        If the yaml file already exists, its content is loaded and the
        ``processor`` entry is merged into it; other top-level entries are kept.

        Args:
            save_directory (str): a directory to save config yaml. Defaults to
                ``MindFormerBook.get_default_checkpoint_save_folder()``.
            save_name (str): the name of save files.

        Raises:
            TypeError: if ``save_directory`` or ``save_name`` is not a str.
        """
        if save_directory is None:
            save_directory = MindFormerBook.get_default_checkpoint_save_folder()

        if not isinstance(save_directory, str) or not isinstance(save_name, str):
            raise TypeError(f"save_directory and save_name should be a str,"
                            f" but got {type(save_directory)} and {type(save_name)}.")

        if not os.path.exists(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        parsed_config = self._inverse_parse_config(self.config)
        wrapped_config = self._wrap_config(parsed_config)

        config_path = os.path.join(save_directory, save_name + '.yaml')
        merged_dict = {}
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as file_reader:
                # NOTE(security): yaml.Loader can construct arbitrary Python
                # objects, so this must only read trusted files. It is kept
                # because yaml.dump below may emit python-specific tags that
                # the safe loader would reject.
                loaded = yaml.load(file_reader.read(), Loader=yaml.Loader)
            # An empty yaml file loads as None; treat it as an empty mapping
            # instead of failing on ``None.update``.
            if loaded is not None:
                merged_dict = loaded
        merged_dict.update(wrapped_config)

        with open(config_path, 'w', encoding='utf-8') as file_pointer:
            file_pointer.write(yaml.dump(merged_dict))
        logger.info("processor saved successfully!")

    def _inverse_parse_config(self, config):
        """
        Inverse parse config method, which builds yaml file content for the processor config.

        Args:
            config (dict): a dict, which contains input parameters of the processor.

        Returns:
            A dict, which follows the yaml content: a ``"type"`` entry with this
            class name plus one entry per config key. Tokenizers and image/audio
            processors are expanded into ``{"type": <class name>, **their kwargs}``.
        """
        parsed_config = {"type": self.__class__.__name__}

        for key, val in config.items():
            if isinstance(val, BaseTokenizer):
                parsed_sub_config = {"type": val.__class__.__name__}
                parsed_sub_config.update(val.init_kwargs)
                parsed_config[key] = parsed_sub_config
            elif isinstance(val, (BaseImageProcessor, BaseAudioProcessor)):
                parsed_sub_config = {"type": val.__class__.__name__}
                parsed_sub_config.update(val.config)
                parsed_config[key] = parsed_sub_config
            else:
                parsed_config[key] = val
        return parsed_config

    def _wrap_config(self, config):
        """
        Wrap config function, which wraps a config to rebuild content of yaml file.

        Args:
            config (dict): a dict processed by _inverse_parse_config function.

        Returns:
            A dict for yaml.dump.
        """
        return {"processor": config}

    @staticmethod
    def _get_default_yaml_file(model_name):
        """Return the default yaml path registered for ``model_name``, or "" if none is found."""
        for model_dict in MindFormerBook.get_trainer_support_task_list().values():
            if model_name in model_dict:
                return model_dict.get(model_name)
        return ""

    @classmethod
    def from_pretrained(cls, yaml_name_or_path, **kwargs):
        """
        From pretrain method, which instantiates a processor by yaml name or path.

        Args:
            yaml_name_or_path (str): A supported yaml name or a path to .yaml file,
                the supported model name could be selected from .show_support_list().
                If yaml_name_or_path is model name,
                it supports model names beginning with mindspore or
                the model name itself, such as "mindspore/vit_base_p16" or "vit_base_p16".
            pretrained_model_name_or_path (Optional[str]): Equal to "yaml_name_or_path",
                if "pretrained_model_name_or_path" is set, "yaml_name_or_path" is useless.

        Returns:
            A processor which inherited from BaseProcessor.

        Raises:
            TypeError: if the name or path is not a str.
            ValueError: if the path does not exist and the name is not supported.
            FileNotFoundError: if no default yaml file can be found for a supported name.
        """
        pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", None)
        if pretrained_model_name_or_path is not None:
            yaml_name_or_path = pretrained_model_name_or_path

        if not isinstance(yaml_name_or_path, str):
            raise TypeError(f"yaml_name_or_path should be a str,"
                            f" but got {type(yaml_name_or_path)}")

        is_exist = os.path.exists(yaml_name_or_path)
        if not is_exist and yaml_name_or_path not in cls._support_list:
            raise ValueError(f'{yaml_name_or_path} does not exist,'
                             f' or it is not supported by {cls.__name__}. '
                             f'please select from {cls._support_list}.')

        if is_exist:
            logger.info("config in %s is used for processor"
                        " building.", yaml_name_or_path)
            config_args = MindFormerConfig(yaml_name_or_path)
        else:
            yaml_name = yaml_name_or_path
            if yaml_name_or_path.startswith('mindspore'):
                # Adaptation the name of yaml at the beginning of mindspore,
                # the relevant file will be downloaded from the Xihe platform.
                # such as "mindspore/vit_base_p16"
                yaml_name = yaml_name_or_path.split('/')[cls._model_name]
                checkpoint_path = os.path.join(MindFormerBook.get_xihe_checkpoint_download_folder(),
                                               yaml_name.split('_')[cls._model_type])
            else:
                # Default the name of yaml,
                # the relevant file will be downloaded from the Obs platform.
                # such as "vit_base_p16"
                checkpoint_path = os.path.join(MindFormerBook.get_default_checkpoint_download_folder(),
                                               yaml_name_or_path.split('_')[cls._model_type])

            if not os.path.exists(checkpoint_path):
                os.makedirs(checkpoint_path, exist_ok=True)

            yaml_file = os.path.join(checkpoint_path, yaml_name + ".yaml")

            if not os.path.exists(yaml_file):
                default_yaml_file = cls._get_default_yaml_file(yaml_name)
                # A missing entry ("" or None) must end in FileNotFoundError,
                # so test truthiness before touching the filesystem.
                if default_yaml_file and os.path.exists(default_yaml_file):
                    shutil.copy(default_yaml_file, yaml_file)
                    logger.info("default yaml config in %s is used.", yaml_file)
                else:
                    raise FileNotFoundError(f'default yaml file path must be correct, but get {default_yaml_file}')
            config_args = MindFormerConfig(yaml_file)

        processor = build_processor(config_args.processor)
        logger.info("processor built successfully!")
        return processor

    @classmethod
    def show_support_list(cls):
        """Log the class name and print its supported yaml/model names."""
        logger.info("support list of %s is:", cls.__name__)
        print_path_or_list(cls._support_list)

    @classmethod
    def get_support_list(cls):
        """Return the list of supported yaml/model names of this class."""
        return cls._support_list