# Source code for mindformers.models.mae.mae_processor

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ViTMAEProcessor
"""
import numpy as np
from PIL import Image

from mindspore import Tensor
from mindspore.dataset.vision.transforms import ToTensor, Normalize

from mindformers.mindformer_book import MindFormerBook
from mindformers.dataset import Resize
from mindformers.dataset.base_dataset import BaseDataset
from mindformers.models.base_processor import BaseProcessor, BaseImageProcessor
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType


# Names exported by ``from mindformers.models.mae.mae_processor import *``.
__all__ = [
    "ViTMAEProcessor",
    "ViTMAEImageProcessor",
]


@MindFormerRegister.register(MindFormerModuleType.PROCESSOR)
class ViTMAEImageProcessor(BaseImageProcessor):
    """
    ViTMAEImageProcessor.

    Resizes, converts and normalizes images, and draws an independent random
    patch shuffle / mask for every image.

    Args:
        size (int): the target height and width that images are resized to;
            also forwarded to the base class as ``image_resolution``. Default: 224.
        patch_size (int): side length of a square patch; the image is split into
            ``(size // patch_size) ** 2`` patches. Default: 16.
        mask_ratio (float): fraction of patches to mask, must lie strictly
            between 0 and 1. Default: 0.75.

    Raises:
        ValueError: if ``mask_ratio`` is not strictly between 0 and 1.
    """

    def __init__(self, size=224, patch_size=16, mask_ratio=0.75):
        super(ViTMAEImageProcessor, self).__init__(image_resolution=size)
        self.resize = Resize((size, size), interpolation='cubic')
        self.to_tensor = ToTensor()
        # The widely used ImageNet RGB mean/std; is_hwc=False because the
        # normalize step runs on the output of to_tensor.
        self.normalize = Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225],
                                   is_hwc=False)
        if not 0 < mask_ratio < 1:
            # f-string so the offending value actually appears in the message.
            raise ValueError(f'masking ratio must be kept between 0 and 1, '
                             f'but got mask_ratio {mask_ratio}.')
        # Number of patches per image (the sequence length).
        self.num_patches = (size // patch_size) ** 2
        # Number of patches that stay visible (NOT masked) per image.
        self.keep_num = int((1 - mask_ratio) * self.num_patches)

    def preprocess(self, images, **kwargs):
        """
        Preprocess required by base processor.

        Every image is resized, passed through ``ToTensor`` and normalized. For
        every image an independent random permutation of the patch indices is
        drawn with ``np.random.uniform``; its first ``keep_num`` entries are the
        patches that stay visible.

        Args:
            images (tensor, PIL.Image, numpy.array, list): a batch of images.
            **kwargs: accepted for interface compatibility; unused.

        Return:
            A tuple ``(images, masks, ids_restores, unmask_indexes)`` of Tensors,
            each built from the per-image results:
              - images: the processed images;
              - masks: int32, ``num_patches`` per image; 0 on the first
                ``keep_num`` positions of the shuffled order, 1 elsewhere;
              - ids_restores: int32, ``num_patches`` per image; the inverse of
                the random permutation;
              - unmask_indexes: int32, ``keep_num`` per image; indices of the
                patches that stay visible.
        """
        images = self._format_inputs(images)
        # NOTE(review): a BaseDataset is returned unchanged by _format_inputs and
        # iterated directly here, and list items arrive as (1, h, w, c) arrays —
        # confirm Resize/ToTensor handle those inputs.
        res = []
        ids_restores = []
        masks = []
        unmask_indexes = []
        for image in images:
            image = self.resize(image)
            image = self.to_tensor(image)
            image = self.normalize(image)
            res.append(image)

            # Random permutation of patch indices (argsort of uniform noise).
            rand_indices = np.argsort(
                np.random.uniform(size=(self.num_patches,)), axis=0).astype(np.int32)
            # argsort of a permutation is its inverse: undoes the shuffle.
            ids_restore = np.argsort(rand_indices, axis=0).astype(np.int32)
            ids_restores.append(ids_restore)

            # Mask expressed in the shuffled order: 0 = visible, 1 = masked.
            # NOTE(review): not un-shuffled here; presumably the consumer applies
            # ids_restore — TODO confirm.
            mask = np.ones((self.num_patches,)).astype(np.int32)
            mask[:self.keep_num] = 0
            masks.append(mask)

            unmask_index = rand_indices[:self.keep_num]
            unmask_indexes.append(unmask_index)
        return Tensor(res), Tensor(masks), Tensor(ids_restores), Tensor(unmask_indexes)

    def _format_inputs(self, inputs):
        """
        Transform image classification inputs into (bz, h, w, c) or (h, w, c) numpy array.

        Args:
            inputs (tensor, numpy.array, PIL.Image, list, BaseDataset): for numpy or
                tensor input, the channel could be (bz, c, h, w), (c, h, w) or
                (bz, h, w, c), (h, w, c); for list, the item could be PIL.Image,
                numpy.array, Tensor; for BaseDataset, return without any operations.

        Return:
            transformed images: for PIL.Image, numpy or tensor input, a 4-rank
            (bz, h, w, c) numpy array (3-rank input gets a leading batch axis);
            for list, a list holding the transformed result of each element;
            for BaseDataset, the input itself.

        Raises:
            TypeError: if the input type is not supported.
            ValueError: if an array/tensor input is not 3-rank or 4-rank.
        """
        if not isinstance(inputs, (list, Image.Image, Tensor, np.ndarray, BaseDataset)):
            raise TypeError("input type is not Tensor, numpy, Image, list of Image or MindFormer BaseDataset")

        if isinstance(inputs, list):
            return [self._format_inputs(item) for item in inputs]

        if isinstance(inputs, Image.Image):
            inputs = np.array(inputs)

        if isinstance(inputs, Tensor):
            inputs = inputs.asnumpy()

        if isinstance(inputs, np.ndarray):
            if len(inputs.shape) == 3:
                inputs = np.expand_dims(inputs, 0)
                inputs = self._chw2hwc(inputs)
            elif len(inputs.shape) == 4:
                inputs = self._chw2hwc(inputs)
            else:
                raise ValueError(f"the rank of image_batch should be 3 or 4,"
                                 f" but got {len(inputs.shape)}")
        return inputs

    @staticmethod
    def _chw2hwc(inputs):
        """
        Move the channel axis of a 4-rank batch last.

        Heuristic: any array whose last axis is not 3 is assumed to be
        (bz, c, h, w) and is transposed to (bz, h, w, c).
        """
        if inputs.shape[-1] != 3:
            inputs = inputs.transpose(0, 2, 3, 1)
        return inputs
@MindFormerRegister.register(MindFormerModuleType.PROCESSOR)
class ViTMAEProcessor(BaseProcessor):
    """
    ViTMAEProcessor, which holds an image processor for image input.

    Args:
        image_processor: the image processor forwarded to ``BaseProcessor``.
        return_tensors (str): the tensor type forwarded to ``BaseProcessor``.
            Default: 'ms'.

    Examples:
        >>> import os
        >>> from mindformers import MindFormerBook
        >>> from mindformers.models import ViTMAEProcessor
        >>> yaml_path = os.path.join(MindFormerBook.get_project_path(), "configs",
        ...                          "mae", "model_config", "mae_vit_base_p16.yaml")
        >>> # build ViTMAEProcessor from pretrained
        >>> pro_a = ViTMAEProcessor.from_pretrained('mae_vit_base_p16')
        >>> # build ViTMAEProcessor from config
        >>> pro_b = ViTMAEProcessor.from_pretrained(yaml_path)
    """

    # Model names this processor can be loaded for via from_pretrained.
    _support_list = MindFormerBook.get_processor_support_list()['mae']

    def __init__(self, image_processor=None, return_tensors='ms'):
        super().__init__(image_processor=image_processor,
                         return_tensors=return_tensors)