fclong's picture
Upload 396 files
8ebda9e
raw
history blame
4.99 kB
import argparse
from fengshen import UbertPipelines
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '6'
def main():
total_parser = argparse.ArgumentParser("TASK NAME")
total_parser = UbertPipelines.pipelines_args(total_parser)
args = total_parser.parse_args()
# 设置一些训练要使用到的参数
args.pretrained_model_path = 'IDEA-CCNL/Erlangshen-Ubert-110M-Chinese' #预训练模型的路径,我们提供的预训练模型存放在HuggingFace上
args.default_root_dir = './' #默认主路径,用来放日志、tensorboard等
args.max_epochs = 5
args.gpus = 1
args.batch_size = 1
# 只需要将数据处理成为下面数据的 json 样式就可以一键训练和预测,下面只是提供了一条示例样本
train_data = [
{
"task_type": "抽取任务",
"subtask_type": "实体识别",
"text": "彭小军认为,国内银行现在走的是台湾的发卡模式,先通过跑马圈地再在圈的地里面选择客户,",
"choices": [
{"entity_type": "地址", "label": 0, "entity_list": [
{"entity_name": "台湾", "entity_type": "地址", "entity_idx": [[15, 16]]}]},
{"entity_type": "书名", "label": 0, "entity_list": []},
{"entity_type": "公司", "label": 0, "entity_list": []},
{"entity_type": "游戏", "label": 0, "entity_list": []},
{"entity_type": "政府机构", "label": 0, "entity_list": []},
{"entity_type": "电影名称", "label": 0, "entity_list": []},
{"entity_type": "人物姓名", "label": 0, "entity_list": [
{"entity_name": "彭小军", "entity_type": "人物姓名", "entity_idx": [[0, 2]]}]},
{"entity_type": "组织机构", "label": 0, "entity_list": []},
{"entity_type": "岗位职位", "label": 0, "entity_list": []},
{"entity_type": "旅游景点", "label": 0, "entity_list": []}
],
"id": 0}
]
dev_data = [
{
"task_type": "抽取任务",
"subtask_type": "实体识别",
"text": "就天涯网推出彩票服务频道是否是业内人士所谓的打政策“擦边球”,记者近日对此事求证彩票监管部门。",
"choices": [
{"entity_type": "地址", "label": 0, "entity_list": []},
{"entity_type": "书名", "label": 0, "entity_list": []},
{"entity_type": "公司", "label": 0, "entity_list": [
{"entity_name": "天涯网", "entity_type": "公司", "entity_idx": [[1, 3]]}]},
{"entity_type": "游戏", "label": 0, "entity_list": []},
{"entity_type": "政府机构", "label": 0, "entity_list": []},
{"entity_type": "电影名称", "label": 0, "entity_list": []},
{"entity_type": "人物姓名", "label": 0, "entity_list": []},
{"entity_type": "组织机构", "label": 0, "entity_list": [
{"entity_name": "彩票监管部门", "entity_type": "组织机构", "entity_idx": [[40, 45]]}]},
{"entity_type": "岗位职位", "label": 0, "entity_list": [
{"entity_name": "记者", "entity_type": "岗位职位", "entity_idx": [[31, 32]]}]},
{"entity_type": "旅游景点", "label": 0, "entity_list": []}
],
"id": 0}
]
test_data = [
{
"task_type": "抽取任务",
"subtask_type": "实体识别",
"text": "这也让很多业主据此认为,雅清苑是政府公务员挤对了国家的经适房政策。",
"choices": [
{"entity_type": "地址", "label": 0, "entity_list": [
{"entity_name": "雅清苑", "entity_type": "地址", "entity_idx": [[12, 14]]}]},
{"entity_type": "书名", "label": 0, "entity_list": []},
{"entity_type": "公司", "label": 0, "entity_list": []},
{"entity_type": "游戏", "label": 0, "entity_list": []},
{"entity_type": "政府机构", "label": 0, "entity_list": []},
{"entity_type": "电影名称", "label": 0, "entity_list": []},
{"entity_type": "人物姓名", "label": 0, "entity_list": []},
{"entity_type": "组织机构", "label": 0, "entity_list": []},
{"entity_type": "岗位职位", "label": 0, "entity_list": [
{"entity_name": "公务员", "entity_type": "岗位职位", "entity_idx": [[18, 20]]}]},
{"entity_type": "旅游景点", "label": 0, "entity_list": []}
],
"id": 0},
]
model = UbertPipelines(args)
model.fit(train_data, dev_data)
result = model.predict(test_data)
for line in result:
print(line)
if __name__ == "__main__":
main()