# Text Classification
## 任务描述
文本分类:模型在基于文本对的微调后,可以在给定任意文本对与候选标签列表的情况下,完成对文本对关系的分类,文本对的两个文本之间以-分割。
[相关论文](https://arxiv.org/pdf/1810.04805.pdf) Jacob Devlin, Ming-Wei Chang, et al., BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, 2019.
## 已支持数据集性能
| model | type | datasets | Top1-accuracy | stage | example |
|:------:|:----------------------------------------------------------:|:--------:|:---------------:|:--------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| bert | txtcls_bert_base_uncased | Mnli | 30.9% | pretrain | -- |
| txtcls | txtcls_bert_case_uncased_mnli | Mnli | 84.8% | finetune
eval
predict | [link](../../examples/text_classification/txtcls_bert_base_uncased_finetune_on_mnli.sh)
[link](../../examples/text_classification/txtcls_bert_base_uncased_mnli_eval_on_mnli.sh)
[link](../../examples/text_classification/txtcls_bert_base_uncased_mnli_predict_on_mnli.sh) |
### [Mnli](https://dl.fbaipublicfiles.com/glue/data/MNLI.zip)
- 数据集大小:298M,共431992个样本,3个类别
- 训练集:392702个样本
- 匹配测试集:9796个样本
- 非匹配测试集:9847个样本
- 匹配开发集:9815个样本
- 非匹配开发集:9832个样本
- 数据格式:tsv文件
```bash
数据集目录格式
└─mnli
├─dev
├─test
└─train
```
- 用户可以参考[BERT](https://github.com/google-research/bert)代码仓中的run_classifier.py文件,进行Mnli数据集`TFRecord`格式文件的生成。
## 快速任务接口
### 脚本启动
> 需开发者提前clone工程。
- 请参考[使用脚本启动](https://gitee.com/mindspore/transformer/blob/master/README.md#%E6%96%B9%E5%BC%8F%E4%B8%80clone-%E5%B7%A5%E7%A8%8B%E4%BB%A3%E7%A0%81)
- 脚本运行测试
```shell
# finetune
python run_mindformer.py --config ./configs/txtcls/run_txtcls_bert_base_uncased.yaml --run_mode finetune --load_checkpoint txtcls_bert_base_uncased
# evaluate
python run_mindformer.py --config ./configs/txtcls/run_txtcls_bert_base_uncased.yaml --run_mode eval --load_checkpoint txtcls_bert_base_uncased_mnli
# predict
python run_mindformer.py --config ./configs/txtcls/run_txtcls_bert_base_uncased.yaml --run_mode predict --load_checkpoint txtcls_bert_base_uncased_mnli --predict_data [TEXT]
```
### 调用API启动
- Trainer接口开启训练/评估/推理:
```python
from mindformers import MindFormerBook
from mindformers.trainer import Trainer
# 显示Trainer的模型支持列表
MindFormerBook.show_trainer_support_model_list("text_classification")
# INFO - Trainer support model list for txt_classification task is:
# INFO - ['txtcls_bert_base_uncased']
# INFO - -------------------------------------
# 初始化trainer
trainer = Trainer(task='text_classification',
model='txtcls_bert_base_uncased',
train_dataset='./mnli/train',
eval_dataset='./mnli/eval')
# 测试数据,该input_data有两个测试案例,即两个文本对,单个文本对的两个文本之间用-分割
input_data = ["The new rights are nice enough-Everyone really likes the newest benefits ",
"i don't know um do you do a lot of camping-I know exactly."]
#方式1:使用现有的预训练权重进行finetune, 并使用finetune获得的权重进行eval和推理
trainer.train(resume_or_finetune_from_checkpoint="txtcls_bert_base_uncased",
do_finetune=True)
trainer.evaluate(eval_checkpoint=True)
trainer.predict(predict_checkpoint=True, input_data=input_data, top_k=1)
# 方式2: 从obs下载训练好的权重并进行eval和推理
trainer.evaluate()
# INFO - Top1 Accuracy=84.8%
trainer.predict(input_data=input_data, top_k=1)
# INFO - output result is [[{'label': 'neutral', 'score': 0.9714198708534241}],
# [{'label': 'contradiction', 'score': 0.9967639446258545}]]
```
- pipeline接口开启快速推理
```python
from mindformers.pipeline import TextClassificationPipeline
from mindformers import AutoTokenizer, BertForMultipleChoice, AutoConfig
input_data = ["The new rights are nice enough-Everyone really likes the newest benefits ",
"i don't know um do you do a lot of camping-I know exactly."]
tokenizer = AutoTokenizer.from_pretrained('txtcls_bert_base_uncased_mnli')
txtcls_mnli_config = AutoConfig.from_pretrained('txtcls_bert_base_uncased_mnli')
# Because batch_size parameter is required when bert model is created, and pipeline
# function deals with samples one by one, the batch_size parameter is seted one.
txtcls_mnli_config.batch_size = 1
model = BertForMultipleChoice(txtcls_mnli_config)
txtcls_pipeline = TextClassificationPipeline(task='text_classification',
model=model,
tokenizer=tokenizer,
max_length=model.config.seq_length,
padding="max_length")
results = txtcls_pipeline(input_data, top_k=1)
print(results)
# 输出
# [[{'label': 'neutral', 'score': 0.9714198708534241}], [{'label': 'contradiction', 'score': 0.9967639446258545}]]
```