Spaces:
Runtime error
Runtime error
import os | |
import mmcv | |
import numpy as np | |
from app.routers.image import inferenceImage | |
import pytest | |
import pytest | |
from fastapi.testclient import TestClient | |
from fastapi.routing import APIRoute | |
from app.main import app | |
def endpoints():
    """Return the path of every ``APIRoute`` registered on ``app``, in registration order."""
    return [route.path for route in app.routes if isinstance(route, APIRoute)]
@pytest.fixture
def client():
    """Yield a ``TestClient`` for ``app`` using base URL ``http://0.0.0.0:3000``.

    Test methods receive it by declaring a ``client`` parameter. Without the
    ``@pytest.fixture`` decorator pytest cannot resolve that parameter.
    The client is closed once the requesting test finishes.
    """
    test_client = TestClient(app, "http://0.0.0.0:3000")
    try:
        yield test_client
    finally:
        # Release the client's session/transport even if the test failed.
        test_client.close()
class TestImageRoute:
    """Tests for ``inferenceImage`` and the ``/image`` upload route."""

    # Reference image, read once at class-definition time from the current
    # working directory ('demo.jpg' must exist there).
    img = mmcv.imread('demo.jpg')
    # Absolute route URL; requests below use a path relative to the client's base URL.
    url = "http://0.0.0.0:3000/image"

    @staticmethod
    def _post_image(client, path, query="", filename="demo.jpg"):
        """POST the file at *path* as multipart field ``file`` to the image route.

        Args:
            client: the ``TestClient`` fixture.
            path: file to upload (str or path-like).
            query: query string appended to the route, e.g. ``"?threshold=1&raw=True"``.
            filename: filename reported in the multipart part.

        Returns:
            The client's response object. The uploaded file handle is closed
            before returning.
        """
        headers = {'accept': 'application/json'}
        with open(path, 'rb') as fh:
            files = [('file', (filename, fh, 'image/jpeg'))]
            return client.request("POST", "image" + query, headers=headers, data={}, files=files)

    def test_inferenceImage(self):
        # raw=True: expect non-empty, equally long bbox and label sequences.
        bboxes, labels = inferenceImage(mmcv.imread('demo.jpg'), 0.3, True)
        assert len(bboxes) > 0
        assert len(labels) == len(bboxes)
        # raw=False: expect an ndarray image with the same shape as the input.
        result = inferenceImage(self.img, 0.3, False)
        assert isinstance(result, np.ndarray)
        assert result.shape == self.img.shape

    def test_ImageAPI(self, client):
        response = self._post_image(client, 'demo.jpg')
        # Check the status first so an error payload fails with a clear message
        # instead of an image-decoding error.
        assert response.status_code == 200
        result = mmcv.imfrombytes(response.content)
        assert result.shape == self.img.shape

    def test_ImageAPI_one_channel_array(self, client, tmp_path):
        # NOTE(review): ndarray.dump writes a pickle, not a JPEG. This uploads
        # undecodable bytes, so it really checks that invalid image data is
        # rejected — confirm that is the intent of "one channel".
        bad_file = tmp_path / "one_channel.jpg"
        np.zeros((1, 640, 640)).dump(str(bad_file))
        response = self._post_image(client, bad_file)
        assert response.status_code != 200

    def test_ImageAPIWithThresHold(self, client):
        # threshold=1 requires 100% confidence, which presumably no detection reaches.
        response = self._post_image(client, 'demo.jpg', "?threshold=1&raw=True")
        assert response.status_code == 200
        assert len(response.json()["labels"]) == 0

        threshold = 0.4
        response = self._post_image(client, 'demo.jpg', f"?threshold={threshold}&raw=True")
        assert response.status_code == 200
        # Assumes each bbox is [x1, y1, x2, y2, score] — TODO confirm against the route.
        for bbox in response.json()['bboxes']:
            assert bbox[4] >= threshold