|
import base64
import shutil
from io import BytesIO

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

from vtoonify_model import cacartoon1
|
|
|
# Single application instance; the routes and the static mount below attach to it.
app = FastAPI()

# Use the GPU when PyTorch reports one is available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Constructed once at import time so every request reuses the same model object.
model = cacartoon1(device=DEVICE)
|
|
|
def generate_cartoon(image_bytes: bytes) -> bytes:
    """Run the cartoon model on an encoded image and return the result as PNG.

    Args:
        image_bytes: Raw contents of an image file in any format that
            ``PIL.Image.open`` can read.

    Returns:
        The model output encoded as PNG bytes.

    Note:
        Assumes ``model.generate_cartoon`` returns an object exposing a
        PIL-style ``save(fp, format=...)`` method -- confirm against the
        model implementation.
    """
    image = Image.open(BytesIO(image_bytes))
    cartoon_image = model.generate_cartoon(image)
    with BytesIO() as output:
        cartoon_image.save(output, format="PNG")
        # Read the buffer while it is still open; getvalue() on a closed
        # BytesIO raises ValueError.
        return output.getvalue()
|
|
|
@app.post("/upload/")
async def upload_image(file: UploadFile = File(...)):
    """Cartoonize an uploaded image and return it as base64-encoded PNG.

    Args:
        file: The multipart upload containing the source image.

    Returns:
        A dict ``{"result": <str>}`` whose value is the base64 (ASCII)
        encoding of the PNG bytes produced by ``generate_cartoon``.
    """
    contents = await file.read()
    result_bytes = generate_cartoon(contents)
    # Raw PNG bytes are not valid UTF-8 (the signature starts with 0x89), so
    # they cannot be serialized into a JSON body as-is. Base64 keeps the
    # {"result": ...} response shape while making the payload serializable.
    return {"result": base64.b64encode(result_bytes).decode("ascii")}
|
|
|
# The explicit "/" route has to be registered before the static mount: routes
# are matched in registration order, and a mount at "/" matches every path, so
# registering it first would make this handler unreachable.
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page as HTML."""
    return FileResponse(path="/app/AB/index.html", media_type="text/html")


# Any request not handled by a route above is served from the "AB" directory.
app.mount("/", StaticFiles(directory="AB", html=True), name="static")
|
|