mindformers.models.base_config 源代码

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
BaseConfig class,
which is all model configs' base class
"""
import os
import shutil

import yaml

from ..mindformer_book import MindFormerBook
from ..mindformer_book import print_path_or_list
from ..tools import logger
from ..tools.register.config import MindFormerConfig
from ..models.build_config import build_model_config


[文档]class BaseConfig(dict): """ Base Config for all models' config Examples: >>> from mindformers.mindformer_book import MindFormerBook >>> from mindformers.models.base_config import BaseConfig >>> class MyConfig(BaseConfig): ... _support_list = MindFormerBook.get_config_support_list()['my_model'] ... ... def __init__(self, ... data_size: int = 32, ... net_size: list = [1, 2, 3], ... loss_type: str = "MyLoss", ... checkpoint_name_or_path: str = '', ... **kwargs): ... self.data_size = data_size ... self.net_size = net_size ... loss_type = loss_type ... self.checkpoint_name_or_path = checkpoint_name_or_path ... super(MyConfig, self).__init__(**kwargs) ... >>> mynet = MyModel(MyConfig) >>> output = mynet(input) """ _support_list = [] _model_type = 0 _model_name = 1 def __init__(self, **kwargs): super(BaseConfig, self).__init__() self.update(kwargs) def __getattr__(self, key): if key not in self: return None return self[key] def __setattr__(self, key, value): self[key] = value def __delattr__(self, key): del self[key]
[文档] def to_dict(self): """ for yaml dump, transform from Config to a strict dict class """ return_dict = {} for key, val in self.items(): if isinstance(val, BaseConfig): val = val.to_dict() return_dict[key] = val return return_dict
[文档] @classmethod def from_pretrained(cls, yaml_name_or_path, **kwargs): """ From pretrain method, which instantiates a config by yaml name or path. Args: yaml_name_or_path (str): A supported model name or a path to model config (.yaml), the supported model name could be selected from AutoConfig.show_support_list(). If yaml_name_or_path is model name, it supports model names beginning with mindspore or the model name itself, such as "mindspore/vit_base_p16" or "vit_base_p16". pretrained_model_name_or_path (Optional[str]): Equal to "yaml_name_or_path", if "pretrained_model_name_or_path" is set, "yaml_name_or_path" is useless. Returns: A model config, which inherited from BaseConfig. """ pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", None) if pretrained_model_name_or_path is not None: yaml_name_or_path = pretrained_model_name_or_path if not isinstance(yaml_name_or_path, str): raise TypeError(f"yaml_name_or_path should be a str," f" but got {type(yaml_name_or_path)}.") if os.path.exists(yaml_name_or_path): if not yaml_name_or_path.endswith(".yaml"): raise ValueError(f"{yaml_name_or_path} should be a .yaml file for model" " config.") config_args = MindFormerConfig(yaml_name_or_path) logger.info("the content in %s is used for" " config building.", yaml_name_or_path) elif yaml_name_or_path not in cls._support_list: raise ValueError(f"{yaml_name_or_path} is not a supported" f" model type or a valid path to model config." f" supported model could be selected from {cls._support_list}.") else: yaml_name = yaml_name_or_path if yaml_name_or_path.startswith('mindspore'): # Adaptation the name of yaml at the beginning of mindspore, # the relevant file will be downloaded from the Xihe platform. # such as "mindspore/vit_base_p16" yaml_name = yaml_name_or_path.split('/')[cls._model_name] checkpoint_path = os.path.join(MindFormerBook.get_xihe_checkpoint_download_folder(), yaml_name.split('_')[cls._model_type]) else: # Default the name of yaml, # the relevant file will be downloaded from the Obs platform. # such as "vit_base_p16" checkpoint_path = os.path.join(MindFormerBook.get_default_checkpoint_download_folder(), yaml_name_or_path.split('_')[cls._model_type]) if not os.path.exists(checkpoint_path): os.makedirs(checkpoint_path, exist_ok=True) yaml_file = os.path.join(checkpoint_path, yaml_name + ".yaml") def get_default_yaml_file(model_name): default_yaml_file = "" for model_dict in MindFormerBook.get_trainer_support_task_list().values(): if model_name in model_dict: default_yaml_file = model_dict.get(model_name) break return default_yaml_file if not os.path.exists(yaml_file): default_yaml_file = get_default_yaml_file(yaml_name) if os.path.realpath(default_yaml_file) and os.path.exists(default_yaml_file): shutil.copy(default_yaml_file, yaml_file) logger.info("default yaml config in %s is used.", yaml_file) else: raise FileNotFoundError(f'default yaml file path must be correct, but get {default_yaml_file}') config_args = MindFormerConfig(yaml_file) config = build_model_config(config_args.model.model_config) MindFormerBook.set_model_config_to_name(id(config), config_args.model.arch.type) return config
[文档] def save_pretrained(self, save_directory=None, save_name="mindspore_model"): """ Save_pretrained. Args: save_directory (str): a directory to save config yaml save_name (str): the name of save files. """ if save_directory is None: save_directory = MindFormerBook.get_default_checkpoint_save_folder() if not isinstance(save_directory, str) or not isinstance(save_name, str): raise TypeError(f"save_directory and save_name should be a str," f" but got {type(save_directory)} and {type(save_name)}.") if not os.path.exists(save_directory): os.makedirs(save_directory, exist_ok=True) save_path = os.path.join(save_directory, save_name + ".yaml") parsed_config, removed_list = self._inverse_parse_config() wraped_config = self._wrap_config(parsed_config) for key, val in removed_list: self[key] = val self.remove_type() meraged_dict = {} if os.path.exists(save_path): with open(save_path, 'r') as file_reader: meraged_dict = yaml.load(file_reader.read(), Loader=yaml.Loader) file_reader.close() meraged_dict.update(wraped_config) with open(save_path, 'w') as file_pointer: file_pointer.write(yaml.dump(meraged_dict)) file_pointer.close() logger.info("config saved successfully!")
[文档] def remove_type(self): """remove type caused by save’""" if isinstance(self, BaseConfig): self.pop("type") for key, val in self.items(): if isinstance(val, BaseConfig): val.pop("type") self.update({key: val})
[文档] def inverse_parse_config(self): """inverse_parse_config""" val, _ = self._inverse_parse_config() return val
def _inverse_parse_config(self): """ Inverse parse config method, which builds yaml file content for model config. Returns: A model config, which follows the yaml content. """ self.update({"type": self.__class__.__name__}) removed_list = [] for key, val in self.items(): if isinstance(val, BaseConfig): val = val.inverse_parse_config() elif not isinstance(val, (str, int, float, bool)): removed_list.append((key, val)) continue self.update({key: val}) for key, _ in removed_list: self.pop(key) return self, removed_list def _wrap_config(self, config): """ Wrap config function, which wraps a config to rebuild content of yaml file. Args: config (BaseConfig): a config processed by _inverse_parse_config function. Returns: A (config) dict for yaml.dump. """ model_name = self.pop("model_name", None) if model_name is None: model_name = MindFormerBook.get_model_config_to_name().get(id(config), None) return {"model": {"model_config": config.to_dict(), "arch": {"type": model_name}}}
[文档] @classmethod def show_support_list(cls): """show support list method""" logger.info("support list of %s is:", cls.__name__) print_path_or_list(cls._support_list)
[文档] @classmethod def get_support_list(cls): """get support list method""" return cls._support_list