# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file refers to the following project:
# https://github.com/facebookresearch/mae
# ============================================================================
"""ViT Model."""
import math
import numpy as np
from mindspore import load_param_into_net, Parameter, nn
from mindspore import ops as P
from mindspore import dtype as mstype
import mindspore.common.initializer as weight_init
from mindformers.mindformer_book import MindFormerBook
from mindformers.core.loss import build_loss
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.models.base_model import BaseModel
from mindformers.models.vit.vit_modules import Block, LayerNorm, Linear, Dropout, PixelShuffle
from mindformers.models.vit.vit_modules import PatchEmbed
from mindformers.models.vit.vit_config import ViTConfig
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class ViTModel(BaseModel):
    """
    Vision Transformer with support for patch or hybrid CNN input stage.
    The supported model name could be selected from ViTConfig.show_support_list().

    Args:
        config (ViTConfig): the config of Vit model.

    Examples:
        >>> # input model name, load model and weights
        >>> model_a = ViTModel.from_pretrained('vit_base_p16')
        >>> # input config, load model without weights
        >>> from mindformers import AutoConfig
        >>> config = AutoConfig.from_pretrained('vit_base_p16')
        >>> model_b = ViTModel(config)
    """
    _support_list = MindFormerBook.get_model_support_list()['vit']

    def __init__(self, config=None):
        """Build the patch embedding, class/position embeddings and the encoder blocks.

        Args:
            config (ViTConfig): model configuration; a default ``ViTConfig()``
                is used when ``config`` is ``None`` (or otherwise falsy).
        """
        config = config if config else ViTConfig()
        super().__init__(config)
        # MoE is considered enabled as soon as more than one expert is configured.
        self.use_moe = (config.moe_config.expert_num > 1)
        parallel_config = config.parallel_config
        dp = parallel_config.data_parallel
        self.global_pool = config.use_mean_pooling

        self.patch_embed = PatchEmbed(img_size=config.image_size, patch_size=config.patch_size,
                                      in_features=config.in_chans, out_features=config.embed_dim,
                                      parallel_config=parallel_config)
        self.cls_tokens = Parameter(
            weight_init.initializer(weight_init.TruncatedNormal(sigma=config.initializer_range),
                                    (1, 1, config.embed_dim)), requires_grad=True)

        # The encoder sequence is every patch token plus one class token.
        num_patches = self.patch_embed.num_patches
        seq_length = num_patches + 1
        self.seq_length = seq_length
        self.num_patches = num_patches
        # NOTE(review): with seq_length == num_patches + 1 this always evaluates to 0.
        self.num_masked = num_patches - seq_length + 1
        self.pos_embed = Parameter(
            weight_init.initializer(weight_init.TruncatedNormal(sigma=config.initializer_range),
                                    (1, seq_length, config.embed_dim)), requires_grad=True)

        # stochastic depth decay rule: rates grow linearly from 0 to drop_path_rate over depth
        hdr = [x.item() for x in np.linspace(0, config.drop_path_rate, config.depth)]
        if self.use_moe:
            parallel_config_args = parallel_config.moe_parallel_config
        else:
            parallel_config_args = parallel_config.dp_mp_config
        self.blocks = nn.CellList([
            Block(hidden_size=config.embed_dim,
                  ffn_hidden_size=config.intermediate_size,
                  seq_length=seq_length,
                  drop_rate=config.drop_rate,
                  attention_dropout_rate=config.attention_dropout_rate,
                  hidden_dropout_rate=hdr[i],
                  layer_norm_eps=config.layer_norm_eps,
                  qkv_bias=config.qkv_bias,
                  init_values=config.init_values,
                  weight_init='XavierUniform',
                  layernorm_compute_type=config.layernorm_compute_type,
                  softmax_compute_type=config.softmax_compute_type,
                  window_size=None,
                  num_heads=config.num_heads,
                  hidden_act=config.hidden_act,
                  post_layernorm_residual=config.post_layernorm_residual,
                  param_init_type=config.param_init_type,
                  parallel_config=parallel_config_args)
            for i in range(config.depth)])

        # Operators; the shard strategies split only the first axis across data-parallel ranks.
        self.add = P.Add().shard(((dp, 1, 1), (1, 1, 1)))
        self.cast = P.Cast()
        self.tile = P.Tile().shard(((dp, 1, 1),))
        self.cat = P.Concat(axis=1)
        self.fc_norm = LayerNorm((config.embed_dim,), eps=1e-6).shard(((dp, 1, 1),))
        self.reduce_mean = P.ReduceMean().shard(((dp, 1, 1),))
        self.dropout = Dropout(keep_prob=(1. - config.drop_rate))
        self.dropout.shard(((dp, 1, 1),))
        self.stride_slice = P.StridedSlice().shard(((dp, 1, 1),))

        # Generic initialisation first, then the depth-dependent rescaling on top of it.
        self.init_weights_vit()
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's output projection weights by ``sqrt(2 * layer_index)``.

        ``layer_index`` is 1-based. Both the attention projection and the
        feed-forward projection of every block are rescaled; with MoE enabled
        the feed-forward projection is reached through ``block.output.ffn``.
        """
        def rescale(param, layer_id):
            values = param.data / (math.sqrt(2.0 * layer_id))
            param.set_data(values)

        for layer_id, block in enumerate(self.blocks, start=1):
            feed_forward = block.output.ffn if self.use_moe else block.output
            rescale(block.attention.projection.weight, layer_id)
            rescale(feed_forward.projection.weight, layer_id)

    def init_weights_vit(self):
        """init weights vit
        ViT weight initialization, original timm impl (for reproducibility)

        * ``Linear``: truncated-normal weights (sigma = ``config.initializer_range``)
          and zero bias when a bias exists.
        * ``LayerNorm``: gamma set to one, beta set to zero.
        * The cell named ``patch_embed.proj``: Xavier-uniform weights.
        """
        for name, cell in self.cells_and_names():
            if isinstance(cell, Linear):
                cell.weight.set_data(weight_init.initializer(
                    weight_init.TruncatedNormal(sigma=self.config.initializer_range),
                    cell.weight.shape,
                    cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
                                                               cell.bias.shape,
                                                               cell.bias.dtype))
            elif isinstance(cell, (LayerNorm, nn.LayerNorm)):
                cell.gamma.set_data(weight_init.initializer(weight_init.One(),
                                                            cell.gamma.shape,
                                                            cell.gamma.dtype))
                cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
                                                           cell.beta.shape,
                                                           cell.beta.dtype))
            # Checked independently of the branches above: the patch projection
            # is selected by name, not by type.
            if name == "patch_embed.proj":
                cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))

    def no_weight_decay(self):
        """Return the names of the parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_tokens'}

    def load_pretrained(self, params_dict):
        """Load ``params_dict`` into this network via ``load_param_into_net`` and return its result."""
        return load_param_into_net(self, params_dict)

    def construct_without_pool(self, image, mask=None):
        """construct of vit without pool

        Args:
            image: input batch; its first dimension is used as the batch size.
            mask: optional mask forwarded unchanged to the patch embedding.

        Returns:
            The token sequence (class token first, then patch tokens) after
            all encoder blocks.
        """
        tokens = self.patch_embed(image, mask)
        batch_size = image.shape[0]
        # Replicate the single learnable class token for every sample and prepend it.
        cls_tokens = self.tile(self.cls_tokens, (batch_size, 1, 1))
        tokens = self.cat((cls_tokens, tokens))
        if self.pos_embed is not None:
            tokens = self.add(tokens, self.pos_embed)
        x = self.dropout(tokens)
        # All-ones attention mask of shape (batch, seq_length, seq_length).
        encoder_input_mask = P.Ones()((batch_size, self.seq_length, self.seq_length), mstype.float32)
        for block in self.blocks:
            x = block(x, encoder_input_mask)
        return x

    def construct(self, image):
        """construct of vit

        Runs the encoder and pools its output: with ``use_mean_pooling`` the
        tokens after the first one are averaged along axis 1 and normalised
        with ``fc_norm``; otherwise only the first token is kept.
        """
        x = self.construct_without_pool(image)
        b, s, c = x.shape
        if self.global_pool:
            # Skip index 0 on the sequence axis, then average the remaining tokens.
            x = self.stride_slice(
                x, (0, 1, 0), (b, s, c), (1, 1, 1)
            )
            x = self.reduce_mean(x, 1)
            out = self.fc_norm(x)
        else:
            # Keep only index 0 on the sequence axis.
            out = self.stride_slice(
                x, (0, 0, 0), (b, 1, c), (1, 1, 1)
            )
        return out
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class ViTForImageClassification(BaseModel):
    """
    Vision Transformer with support for patch or hybrid CNN input stage.
    The supported model name could be selected from ViTConfig.show_support_list().

    Args:
        config (ViTConfig): the config of Vit model.

    Examples:
        >>> # input model name, load model and weights
        >>> model_a = ViTForImageClassification.from_pretrained('vit_base_p16')
        >>> # input config, load model without weights
        >>> from mindformers import AutoConfig
        >>> config = AutoConfig.from_pretrained('vit_base_p16')
        >>> model_b = ViTForImageClassification(config)
    """
    _support_list = MindFormerBook.get_model_support_list()['vit']

    def __init__(self, config=None):
        """Create the ViT backbone, the classification head and the loss.

        Args:
            config (ViTConfig): model configuration; a default ``ViTConfig()``
                is used when ``config`` is ``None`` (or otherwise falsy).
        """
        config = config if config else ViTConfig()
        super().__init__(config)
        self.vit = ViTModel(config)
        # The head is built with float32 compute and cast to float32.
        self.head = Linear(
            config.embed_dim, config.num_classes,
            weight_init=weight_init.TruncatedNormal(sigma=2e-5),
            compute_dtype=mstype.float32).to_float(mstype.float32)
        self.loss = build_loss(class_name=config.loss_type)
        # Called last so that every sub-cell exists before weights are loaded.
        self.load_checkpoint(config)

    def construct(self, image, target=None):
        """construct of vit

        Args:
            image: input batch passed to the ViT backbone.
            target: labels; fed to the loss while training and returned
                untouched otherwise.

        Returns:
            The loss value while training, otherwise the tuple
            ``(head_output, target)``.
        """
        logits = self.head(self.vit(image))
        if not self.training:
            return logits, target
        return self.loss(logits, target)
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class ViTForMaskedImageModeling(BaseModel):
    """
    Vision Transformer with support for patch or hybrid CNN input stage.
    The supported model name could be selected from ViTConfig.show_support_list().

    Args:
        config (ViTConfig): the config of Vit model.

    Examples:
        >>> # input model name, load model and weights
        >>> model_a = ViTForMaskedImageModeling.from_pretrained('vit_base_p16')
        >>> # input config, load model without weights
        >>> from mindformers import AutoConfig
        >>> config = AutoConfig.from_pretrained('vit_base_p16')
        >>> model_b = ViTForMaskedImageModeling(config)
    """
    _support_list = MindFormerBook.get_model_support_list()['vit']

    def __init__(self, config=None):
        """Create the ViT backbone with a mask-aware patch embedding and the pixel decoder.

        Args:
            config (ViTConfig): model configuration; a default ``ViTConfig()``
                is used when ``config`` is ``None`` (or otherwise falsy).
        """
        config = config if config else ViTConfig()
        super().__init__(config)
        self.vit = ViTModel(config)
        # Swap in a patch embedding built with use_mask=True.
        self.vit.patch_embed = PatchEmbed(img_size=config.image_size, patch_size=config.patch_size,
                                          in_features=config.in_chans, out_features=config.embed_dim,
                                          use_mask=True, parallel_config=config.parallel_config)
        # SequentialCell, not CellList: the decoder is invoked as one callable in
        # construct(), and CellList only registers cells passed as a single list.
        self.decoder = nn.SequentialCell([
            nn.Conv2d(
                in_channels=config.embed_dim,
                out_channels=config.encoder_stride ** 2 * config.in_chans,
                kernel_size=1,
            ),
            PixelShuffle(config.encoder_stride),
        ])
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()
        self.expand_dims = P.ExpandDims()
        self.l1_loss = nn.L1Loss(reduction='none')
        # Initialize weights and apply final processing
        self.init_weights_vit()

    def init_weights_vit(self):
        """Re-initialise the weight of every ``nn.Conv2d`` sub-cell with Xavier-uniform values."""
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))

    def construct(self, image, mask=None):
        """construct of vit for MIM

        Args:
            image: input batch passed to the ViT backbone.
            mask: patch mask; forwarded to the patch embedding and, while
                training, used to weight the reconstruction loss. It is
                reshaped to ``(-1, size, size)`` with
                ``size = image_size // patch_size``.

        Returns:
            The reconstructed images when not training, otherwise the masked
            L1 reconstruction loss.
        """
        # Forward the mask so the mask-aware patch embedding actually receives it.
        x = self.vit.construct_without_pool(image, mask)
        b, s, c = x.shape
        # Drop the class token (index 0 on the sequence axis): only the patch
        # tokens can be laid out as a square grid.
        x = self.vit.stride_slice(x, (0, 1, 0), (b, s, c), (1, 1, 1))
        num_tokens = s - 1
        height = width = math.floor(num_tokens ** 0.5)
        # (b, tokens, c) -> (b, c, tokens) -> (b, c, height, width)
        x = self.reshape(self.transpose(x, (0, 2, 1)), (b, c, height, width))
        reconstruct_images = self.decoder(x)
        if not self.training:
            return reconstruct_images

        # Expand the per-patch mask to pixel resolution by repeating each entry
        # patch_size times along both spatial axes, then add a channel axis.
        size = self.config.image_size // self.config.patch_size
        mask = self.reshape(mask, (-1, size, size))
        mask = P.repeat_elements(mask, self.config.patch_size, 1)
        mask = P.repeat_elements(mask, self.config.patch_size, 2)
        mask = self.expand_dims(mask, 1)
        reconstruction_loss = self.l1_loss(image, reconstruct_images)
        # Mask-weighted mean of the L1 error; 1e-5 guards against an all-zero mask.
        masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.in_chans
        return masked_im_loss