mindformers.trainer.ZeroShotImageClassificationTrainer
- class mindformers.trainer.ZeroShotImageClassificationTrainer(model_name: Optional[str] = None)
Trainer for the zero-shot image classification task.
- Args:
model_name (Optional[str]): The name of the model used by the task trainer, e.g. "clip_vit_b_32". Default: None.
- Raises:
NotImplementedError: If the train, evaluate, or predict method is not implemented.
- Examples:
>>> from mindformers import ZeroShotImageClassificationTrainer
>>> trainer = ZeroShotImageClassificationTrainer(model_name="clip_vit_b_32")
>>> trainer.evaluate()
>>> trainer.predict()
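Zero-shot classification scores an image against an arbitrary set of candidate labels instead of a fixed set of trained classes. The sketch below illustrates the CLIP-style scoring commonly used for this task (models such as clip_vit_b_32): normalize the image embedding and one text embedding per label, take cosine similarities, scale them, and apply a softmax. It is a plain-Python illustration under stated assumptions, not the mindformers implementation. The helper names are hypothetical, the fixed logit_scale of 100.0 is an assumption (CLIP learns this scale during training), and in practice the embeddings would come from the model's image and text encoders.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vec]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_probs(image_emb: Sequence[float],
                    text_embs: Sequence[Sequence[float]],
                    logit_scale: float = 100.0) -> List[float]:
    """Return one probability per candidate label (hypothetical CLIP-style helper).

    logit_scale is an assumed constant here; CLIP learns it during training.
    """
    img = l2_normalize(image_emb)
    logits = []
    for text_emb in text_embs:
        txt = l2_normalize(text_emb)
        cosine = sum(a * b for a, b in zip(img, txt))
        logits.append(logit_scale * cosine)
    return softmax(logits)


# Toy 2-D embeddings: the image points almost the same way as the first label.
probs = zero_shot_probs([1.0, 0.0], [[0.9, 0.1], [0.0, 1.0]])
best = max(range(len(probs)), key=probs.__getitem__)
print(best)  # prints 0
```

Subtracting the maximum logit inside softmax keeps large scaled similarities from overflowing math.exp.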
- evaluate(config: Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]] = None, network: Optional[Union[Cell, BaseModel]] = None, dataset: Optional[Union[BaseDataset, GeneratorDataset]] = None, callbacks: Optional[Union[Callback, List[Callback]]] = None, compute_metrics: Optional[Union[dict, set]] = None, **kwargs)
Run the evaluation task for the ZeroShotImageClassification trainer, that is, evaluate the network.
The trainer interface is used to quickly start general tasks. It also allows users to customize the network, dataset, callbacks, and compute_metrics.
- Args:
- config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]):
The task config used to configure the dataset, hyperparameters, optimizer, etc. It accepts a config dict or a MindFormerConfig, TrainingArguments, or ConfigArguments instance. Default: None.
- network (Optional[Union[Cell, BaseModel]]): The network for the trainer.
It accepts a model name, a BaseModel, or a MindSpore Cell. Default: None.
- dataset (Optional[Union[BaseDataset, GeneratorDataset]]): The evaluation dataset.
It accepts a real dataset path, a BaseDataset, or a MindSpore Dataset. Default: None.
- callbacks (Optional[Union[Callback, List[Callback]]]): The callback function(s) to run.
It accepts a single MindSpore Callback or a list of Callbacks. Default: None.
- compute_metrics (Optional[Union[dict, set]]): The evaluation metrics.
It accepts a dict or set of MindSpore Metric instances. Default: None.
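The reference above leaves the choice of compute_metrics open. As an illustration only, zero-shot classifiers are often evaluated with top-k accuracy over the per-label scores. The sketch below computes that metric in plain Python. top_k_accuracy is a hypothetical helper, not a mindformers or MindSpore API, and which metric this trainer uses by default is not stated in this reference.

```python
from typing import Sequence


def top_k_accuracy(logits: Sequence[Sequence[float]],
                   labels: Sequence[int],
                   k: int = 1) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes.

    Hypothetical helper for illustration; not part of mindformers or MindSpore.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    hits = 0
    for row, label in zip(logits, labels):
        # Class indices ordered from highest to lowest score, truncated to k.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)


# Three samples scored against three candidate labels.
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 2]
top1 = top_k_accuracy(scores, labels, k=1)  # the second sample is a top-1 miss
top2 = top_k_accuracy(scores, labels, k=2)  # every true label is in its top 2
```

Top-1 counts only exact hits, while top-2 also credits the second sample because its true label ranks second.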
- predict(config: Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]] = None, input_data: Optional[Union[GeneratorDataset, Tensor, ndarray, Image, str, list]] = None, network: Optional[Union[Cell, BaseModel]] = None, tokenizer: Optional[BaseTokenizer] = None, image_processor: Optional[BaseImageProcessor] = None, **kwargs)
Run the prediction task for the ZeroShotImageClassification trainer, that is, run inference with the network.
- Args:
- config (Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]]):
The task config used to configure the dataset, hyperparameters, optimizer, etc. It accepts a config dict or a MindFormerConfig, TrainingArguments, or ConfigArguments instance. Default: None.
- network (Optional[Union[Cell, BaseModel]]): The network for the trainer.
It accepts a model name, a BaseModel, or a MindSpore Cell. Default: None.
- input_data (Optional[Union[GeneratorDataset, Tensor, np.ndarray, Image, str, list]]):
The input data for prediction. It accepts a real dataset path, a BaseDataset, or a MindSpore Dataset. Default: None.
- tokenizer (Optional[BaseTokenizer]): The tokenizer used for text processing. Default: None.
- image_processor (Optional[BaseImageProcessor]): The image processor used for image processing. Default: None.
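In zero-shot prediction the tokenizer handles the candidate-label text and the image_processor prepares the image. The sketch below shows two surrounding steps that are typical of CLIP-style zero-shot prediction: wrapping bare class names in a prompt template before tokenization, and ranking labels by predicted probability. The template "a photo of a {}." follows the convention from the CLIP paper; build_prompts and rank_labels are hypothetical helpers, and whether this trainer applies any prompt template is an assumption, not something this reference states.

```python
from typing import List, Sequence, Tuple


def build_prompts(candidate_labels: Sequence[str],
                  template: str = "a photo of a {}.") -> List[str]:
    """Wrap bare class names in a natural-language prompt (hypothetical helper).

    The default template is an assumed CLIP-style convention.
    """
    return [template.format(label) for label in candidate_labels]


def rank_labels(candidate_labels: Sequence[str],
                probs: Sequence[float]) -> List[Tuple[str, float]]:
    """Pair each label with its probability, sorted from most to least likely."""
    if len(candidate_labels) != len(probs):
        raise ValueError("one probability is required per candidate label")
    pairs = list(zip(candidate_labels, probs))
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


prompts = build_prompts(["cat", "dog"])
ranking = rank_labels(["cat", "dog", "bird"], [0.2, 0.7, 0.1])
```

Prompting with a full phrase rather than a bare word tends to match the sentence-like captions a CLIP text encoder was trained on, which is why the template is a separate, overridable parameter here.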