Paul DAMPFHOEFFER committed on
Commit
b5ed9fd
·
1 Parent(s): c800f36

feat: init flash api

Browse files
Files changed (2) hide show
  1. app.py +85 -2
  2. requirements.txt +10 -1
app.py CHANGED
@@ -1,7 +1,90 @@
1
- from fastapi import FastAPI
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
  @app.get("/")
6
  def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import torch
3
+ from PIL import Image
4
+
5
+ from transformers import AriaProcessor, AriaForConditionalGeneration
6
+ from fastapi import FastAPI, Request
7
 
8
app = FastAPI()


@app.get("/")
def greet_json():
    """Health-check endpoint: confirms the API is up and serving."""
    return dict(Hello="World!")
13
+
14
@app.post("/")
async def aria_image_to_text(request: Request):
    """Describe an image with the Aria vision-language model.

    Expects a JSON body: {"image_url": "<http(s) URL of an image>"}.
    Returns {"response": "<generated text>"}, or {"error": ...} without
    touching the model when the URL is missing.
    """
    data = await request.json()
    image_url = data.get("image_url")
    if not image_url:
        # Fail fast instead of crashing inside requests/PIL on a None URL.
        return {"error": "missing 'image_url' in request body"}

    resp = requests.get(image_url, stream=True)
    # Surface HTTP failures explicitly rather than handing PIL an error page.
    resp.raise_for_status()
    image = Image.open(resp.raw)

    model, processor = _load_aria()

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # Model weights are loaded in bfloat16; cast the pixel tensor to match.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs.to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Strip the prompt tokens; decode only the newly generated continuation.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(output_ids, skip_special_tokens=True)
    return {"response": response}


def _load_aria():
    """Load the Aria model + processor once and cache them.

    The original handler re-instantiated the multi-billion-parameter model on
    every request; caching on a function attribute makes every request after
    the first fast while keeping module import side-effect free.
    """
    cached = getattr(_load_aria, "_cache", None)
    if cached is None:
        model_id_or_path = "rhymes-ai/Aria"
        model = AriaForConditionalGeneration.from_pretrained(
            model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
        )
        processor = AriaProcessor.from_pretrained(model_id_or_path)
        cached = (model, processor)
        _load_aria._cache = cached
    return cached
53
+
54
@app.get("/aria-test")
def aria_test():
    """Smoke-test endpoint: run Aria on a fixed COCO sample image."""
    repo = "rhymes-ai/Aria"
    model = AriaForConditionalGeneration.from_pretrained(
        repo, device_map="auto", torch_dtype=torch.bfloat16
    )
    processor = AriaProcessor.from_pretrained(repo)

    sample = requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    )
    image = Image.open(sample.raw)

    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]

    prompt = processor.apply_chat_template(chat, add_generation_prompt=True)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Cast pixels to the model's bfloat16 weight dtype before generation.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs.to(model.device)

    generated = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Decode only the tokens produced after the prompt.
    new_tokens = generated[0][inputs["input_ids"].shape[1]:]
    return {"response": processor.decode(new_tokens, skip_special_tokens=True)}
requirements.txt CHANGED
@@ -1,2 +1,11 @@
1
  fastapi
2
- uvicorn[standard]
 
 
 
 
 
 
 
 
 
 
1
  fastapi
2
+ uvicorn[standard]
3
+ transformers>=4.48.0
4
+ accelerate
5
+ sentencepiece
6
+ torchvision
7
+ requests
8
+ torch
9
+ Pillow
10
+ flash-attn
11
+ grouped_gemm==0.1.6