File size: 2,301 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import uvicorn
import click
import argparse
import json
from importlib import import_module
from fastapi import FastAPI, WebSocket
from starlette.middleware.cors import CORSMiddleware
from utils import user_config, api_logger, setup_logger, RequestDataStructure

# The command line takes exactly one argument: the name of the config file
# (e.g. text_classification.json). Every other setting is read from that
# file; nothing else is passed on the command line.
total_parser = argparse.ArgumentParser(prog="API")
total_parser.add_argument("config_path", type=str)
args = total_parser.parse_args()

# Hand the parsed CLI arguments to the shared user_config object
# (presumably it reads the file named by config_path — see utils).
user_config.setup_config(args)

# Configure api_logger according to the user configuration.
setup_logger(api_logger, user_config)

# Resolve the pipeline implementation at runtime: the module
# fengshen.pipelines.<pipeline_type> must expose a class named `Pipeline`.
pipeline_class = import_module(
    "fengshen.pipelines." + user_config.pipeline_type
).Pipeline
model_settings = user_config.model_settings
# The pipeline receives its settings as an argparse.Namespace via `args`.
model_args = argparse.Namespace(**model_settings)
pipeline = pipeline_class(args=model_args, model=user_config.model_name)


# Create the FastAPI application. The OpenAPI schema is published under the
# configured API prefix.
app = FastAPI(
    title=user_config.PROJECT_NAME,
    openapi_url=f"{user_config.API_PREFIX_STR}/openapi.json",
)


# API routes.
# TODO: handle the other HTTP methods; only POST has been wired up so far.
# POST covers most "send text in, get a result back" tasks.
# The configured method is normalised (stripped and upper-cased) so that
# config values such as "post" or " POST " are accepted as well.
if str(user_config.API_method).strip().upper() == "POST":
    @app.post(user_config.API_path, tags=user_config.API_tags)
    async def fengshen_post(data: RequestDataStructure):
        """Run the loaded pipeline on ``data.input_text`` and return its result.

        NOTE(review): ``pipeline(...)`` is called synchronously inside an
        ``async`` handler, so the event loop cannot process other requests
        while an inference runs. Declaring the handler with plain ``def``
        (FastAPI then uses a threadpool) would only be safe if the pipeline
        tolerates concurrent calls — confirm before changing.
        """
        input_text = data.input_text
        # Record every incoming request text.
        api_logger.info(input_text)
        return pipeline(input_text)
else:
    # No prediction route gets registered in this case. Report it through the
    # configured logger and include the offending value so that the
    # misconfiguration is visible.
    api_logger.warning(
        f"only support POST method, got API_method={user_config.API_method!r}"
    )



# Install the CORS middleware only when at least one allowed origin is
# configured; each origin is converted to a plain string first.
if user_config.BACKEND_CORS_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=list(map(str, user_config.BACKEND_CORS_ORIGINS)),
        allow_credentials=user_config.allow_credentials,
        allow_methods=user_config.allow_methods,
        allow_headers=user_config.allow_headers,
    )


if __name__ == "__main__":
    # Once the server is up, open host:port/docs in a browser to browse the
    # interactive API documentation and send simple test requests,
    # e.g. 127.0.0.1:8990/docs
    uvicorn.run(
        app,
        host=user_config.SERVER_HOST,
        port=user_config.SERVER_PORT,
    )