import torch from transformers.file_utils import is_tf_available, is_torch_available from transformers import BertTokenizerFast, BertForSequenceClassification from transformers import Trainer, TrainingArguments import numpy as np import random from sklearn.datasets import fetch_20newsgroups from sklearn.model_selection import train_test_split
设置seed
通过固定随机种子(seed),可以让多次运行得到相同、可复现的结果。代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
def set_seed(seed: int):
    """Seed every random number generator used by this script so runs are repeatable.

    Seeds Python's ``random`` module and NumPy unconditionally, then PyTorch
    (CPU and all CUDA devices) and TensorFlow when those libraries are available.

    Args:
        seed: The seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    if is_torch_available():
        torch.manual_seed(seed)
        # Safe to call even if CUDA is not available; it is then a no-op.
        torch.cuda.manual_seed_all(seed)
    if is_tf_available():
        import tensorflow as tf

        # Importing TensorFlow alone does not seed it; set its global seed explicitly.
        tf.random.set_seed(seed)
# Hyper-parameters and checkpoint/evaluation schedule for the Trainer.
training_args = TrainingArguments(
    output_dir="./results",          # output directory for checkpoints
    num_train_epochs=1,              # number of training epochs
    per_device_train_batch_size=15,  # training batch size per device
    per_device_eval_batch_size=15,   # evaluation batch size per device
    # warmup_steps=30,  # optional: number of initial steps over which the learning
    #                   # rate ramps up to its maximum, which can help early convergence
    weight_decay=0.01,               # strength of weight decay
    logging_dir="./logs",            # directory where training logs are written
    # Reload the best checkpoint once training ends. "Best" is judged by the
    # evaluation loss by default; set `metric_for_best_model` to use accuracy
    # or another metric instead.
    load_best_model_at_end=True,
    logging_steps=100,               # log every 100 steps
    # Save a checkpoint every 100 steps. With load_best_model_at_end=True this
    # must line up with the evaluation schedule below.
    save_steps=100,
    # Evaluate on a step schedule; eval_steps is not set, so evaluation runs
    # every logging_steps steps.
    eval_strategy="steps",
)
# Wire the model, the training_args configuration, both dataset splits and the
# metric callback into a single Trainer instance.
trainer = Trainer(
    model=model,                      # Transformers model to fine-tune
    args=training_args,               # TrainingArguments defined earlier
    train_dataset=train_dataset,      # split used for optimisation
    eval_dataset=valid_dataset,       # split used for periodic evaluation
    compute_metrics=compute_metrics,  # callback that turns predictions into metrics
)
接下来训练模型并在验证集上进行评估,最后保存微调后的模型权重和分词器(由于设置了 load_best_model_at_end=True,训练结束时加载的是最佳权重):
1 2 3 4 5 6 7 8 9 10
# Fine-tune the model, then run one evaluation pass over the evaluation split.
trainer.train()
trainer.evaluate()

# Write the fine-tuned weights and the tokenizer into the same directory so
# both can later be reloaded from one path.
model_path = "20newsgroups-bert-base-uncased"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
若出现报错:`ImportError: Using the Trainer with PyTorch requires accelerate>=0.20.1: Please run pip install transformers[torch] or pip install accelerate -U`,解决方法:运行 `pip install accelerate -U`。
from transformers import BertForSequenceClassification, BertTokenizerFast from sklearn.model_selection import train_test_split from sklearn.datasets import fetch_20newsgroups
# Reload the fine-tuned classifier and its tokenizer from model_path.
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# crashing inside `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BertForSequenceClassification.from_pretrained(
    model_path, num_labels=len(target_names)
).to(device)
tokenizer = BertTokenizerFast.from_pretrained(model_path)
def get_prediction(text):
    """Classify one document and return the name of its predicted class.

    Relies on the module-level ``model``, ``tokenizer``, ``target_names`` and
    ``max_length`` defined elsewhere in this script.

    Args:
        text: The raw document to classify.

    Returns:
        The entry of ``target_names`` for the class with the highest probability.
    """
    # Tokenize (truncating to max_length) and put the tensors on the model's device,
    # so this works whether the model was loaded onto the GPU or the CPU.
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(model.device)
    # Pure inference: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the class dimension gives probabilities; argmax picks the top class.
    probs = outputs[0].softmax(1)
    return target_names[probs.argmax().item()]
# Spot-check the fine-tuned classifier on three unseen passages and print the
# predicted label for each one.
for text in (
    """With the pace of smartphone evolution moving so fast, there's always something waiting in the wings. No sooner have you spied the latest handset, that there's anticipation for the next big thing. Here we look at those phones that haven't yet launched, the upcoming phones for 2021. We'll be updating this list on a regular basis, with those device rumours we think are credible and exciting.""",
    """ A black hole is a place in space where gravity pulls so much that even light can not get out. The gravity is so strong because matter has been squeezed into a tiny space. This can happen when a star is dying. Because no light can get out, people can't see black holes. They are invisible. Space telescopes with special tools can help find black holes. The special tools can see how stars that are very close to black holes act differently than other stars. """,
    """ Coronavirus disease (COVID-19) is an infectious disease caused by a newly discovered coronavirus. Most people infected with the COVID-19 virus will experience mild to moderate respiratory illness and recover without requiring special treatment. Older people, and those with underlying medical problems like cardiovascular disease, diabetes, chronic respiratory disease, and cancer are more likely to develop serious illness. """,
):
    print(get_prediction(text))