File size: 3,583 Bytes
e500bb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import json
import uvicorn
from controlnet_aux import LineartDetector
from fastapi import FastAPI, Request
from utils import assemble_response, torch_gc
from PIL import Image
from models.ssm import GroundedSAM, DensePose, Graphonomy
from models.vfm import BLIP, InpaintControlNet

# Command-line interface for the launcher; parsed only in the __main__ section.
parser = argparse.ArgumentParser(description='Start Server')
parser.add_argument(
    '-P', '--port',
    type=int,
    default=8123,
    help='Port',
)
parser.add_argument(
    '-H', '--host',
    type=str,
    default='0.0.0.0',
    help='Host IP',
)

# ASGI application that the route handlers in this module register on.
app = FastAPI()


@app.post("/blip")
async def blip(request: Request):
    """Answer a visual question about an image using the BLIP model.

    The JSON body is read for ``image_path`` and ``question`` via ``dict.get``,
    so a missing key is passed to the model as ``None``. The model's answer is
    wrapped by ``assemble_response`` under the ``'response'`` key.

    ``model_blip`` is the module-level instance created in the ``__main__``
    section; it is only read here, so no ``global`` declaration is needed.
    """
    # request.json() already returns plain Python objects; a dumps/loads
    # round-trip would only re-parse the same data.
    payload = await request.json()
    image_path = payload.get('image_path')
    question = payload.get('question')
    try:
        answer = model_blip.vqa(image_path, question)
        return assemble_response({'response': answer})
    finally:
        # Run the project's cleanup helper even when inference raises.
        # torch_gc comes from utils; presumably it frees cached GPU memory, so confirm.
        torch_gc(model_blip.device)


@app.post("/controlnet")
async def controlnet(request: Request):
    """Run the inpainting ControlNet pipeline on the request parameters.

    Every key of the JSON body is forwarded unchanged as a keyword argument to
    ``model_cn``. Its return value is wrapped by ``assemble_response`` under the
    ``'response'`` key.

    ``model_cn`` is the module-level instance created in the ``__main__``
    section; it is only read here.
    """
    # request.json() already yields parsed objects; no dumps/loads round-trip.
    payload = await request.json()
    kwargs = dict(payload)
    try:
        return assemble_response({'response': model_cn(**kwargs)})
    finally:
        # Run the cleanup helper (from utils) even if the pipeline raises.
        torch_gc(model_cn.device)


@app.post("/lineart")
async def lineart(request: Request):
    """Extract line art from an image file and save it beside the input.

    The JSON body must contain ``input_image`` (a file path; a missing key
    raises ``KeyError``). All body keys are forwarded as keyword arguments to
    ``model_lineart``, with ``input_image`` replaced by the decoded RGB image.
    The returned image is saved to ``<root>-lineart<ext>``, where ``root`` and
    ``ext`` come from ``os.path.splitext`` of the input path. That output path
    is returned via ``assemble_response`` under ``'response'``.
    """
    import os

    payload = await request.json()
    kwargs = dict(payload)
    input_path = kwargs['input_image']

    # Split off only the final extension, so dots in directory names
    # (e.g. "./data/img.png") are preserved. An extension-less input still
    # gets a distinct output name instead of overwriting the source file.
    root, ext = os.path.splitext(input_path)
    result_path = f"{root}-lineart{ext}"

    # convert() loads the pixels and returns a new image, so the source file
    # handle can be closed immediately.
    with Image.open(input_path) as source_image:
        kwargs['input_image'] = source_image.convert('RGB')

    lineart_image = model_lineart(**kwargs)
    lineart_image.save(result_path)
    return assemble_response({'response': result_path})


@app.post("/graph")
async def graph(request: Request):
    """Run Graphonomy human parsing on the image at ``image_path``.

    ``image_path`` is read from the JSON body via ``dict.get`` (``None`` if
    absent). The model output is wrapped by ``assemble_response`` under the
    ``"response"`` key. ``model_graph`` is the module-level instance created
    in the ``__main__`` section.
    """
    # request.json() already yields parsed objects; no dumps/loads round-trip.
    payload = await request.json()
    image_path = payload.get('image_path')
    try:
        response = model_graph(image_path=image_path)
        return assemble_response({"response": response})
    finally:
        # Run the cleanup helper (from utils) even if parsing raises.
        torch_gc(model_graph.device)


@app.post("/densepose")
async def dense_post(request: Request):
    """Run DensePose on the image at ``image_path``.

    ``image_path`` is read from the JSON body via ``dict.get`` (``None`` if
    absent). The model output is wrapped by ``assemble_response`` under the
    ``"response"`` key. ``model_dense`` is the module-level instance created
    in the ``__main__`` section.
    """
    # request.json() already yields parsed objects; no dumps/loads round-trip.
    payload = await request.json()
    image_path = payload.get('image_path')
    try:
        response = model_dense(image_path=image_path)
        return assemble_response({"response": response})
    finally:
        # Run the cleanup helper (from utils) even if inference raises.
        torch_gc(model_dense.device)


@app.post("/segment")
async def segment(request: Request):
    """Run the GroundedSAM segmentation model with the request parameters.

    Every key of the JSON body is forwarded as a keyword argument to
    ``model_matting``. Its return value is wrapped by ``assemble_response``
    under the ``'response'`` key. ``model_matting`` is the module-level instance
    created in the ``__main__`` section.
    """
    # request.json() already yields parsed objects; no dumps/loads round-trip.
    payload = await request.json()
    try:
        result_path = model_matting(**dict(payload))
        return assemble_response({'response': result_path})
    finally:
        # Run the cleanup helper (from utils) even if segmentation raises.
        torch_gc(model_matting.device)


if __name__ == '__main__':
    args = parser.parse_args()
    # Every device-aware model below is placed on the same device.
    device = 'cuda'

    # Vision foundation models. The route handlers read these module-level
    # names at request time, so they must be bound before the server starts.
    model_blip = BLIP(device=device)
    model_cn = InpaintControlNet(device=device)
    model_lineart = LineartDetector.from_pretrained("./checkpoints/Annotators")

    # Human parsing models.
    model_graph = Graphonomy(device=device)
    model_dense = DensePose(device=device)
    model_matting = GroundedSAM(device=device)

    # Serve the app object in-process with a single worker.
    uvicorn.run(app, host=args.host, port=args.port, workers=1)