Professional Documents
Culture Documents
FastAPI
Giới Thiệu
- - - - X
Mục đích của tài liệu này là cung cấp cho người đọc một phương pháp để
xây dựng và phát triển API sử dụng framework FastAPI.
Ví dụ được sử dụng trong tài liệu này là API đang được sử dụng cho dự
án KSCLOSET, source code tại đây.
Lưu ý: Có nhiều cách khác nhau để xây dựng và phát triển API sử dụng FastAPI,
tài liệu không nhằm mục đích thể hiện đây là cách làm tốt nhất và tối ưu
nhất, tùy theo mục đích sử dụng và dự án mà cách xây dựng và phát triển có
thể khác.
“
3
-FastAPI documentation
”
Cấu Trúc API
- - - - X
API sẽ có các chức năng chính sau
[GET] /version : xem thông tin phiên bản API và version của các thư viện.
[POST] /prediction/url: nhận input là 1 url, nếu URL là link ảnh thì tiến
hành xử lý và trả về kết quả là top 10 ảnh gần nhất.
[POST] /prediction/urls: nhận vào nhiều url, xử lý và trả về kết quả là top10
ảnh gần nhất của từng URL.
Sử dụng GET hoặc POST tùy vào từng API khác nhau. Ở đây, việc sử dụng
POST cho 3 phương pháp prediction/url, prediction/urls,
prediction/upload để đảm bảo thông tin gửi đến server được bảo mật,
tức là người gửi phải đưa thông tin muốn API xử lý vào trong body của
POST request.
Để thuận lợi phát triển API với mục đích có thể mở rộng trong tương
lai, cấu trúc gọn gàng, chúng ta sẽ sử dụng `APIRouter`[ref], có thể
coi router như các module con của API, mỗi module sẽ đảm nhận một
nhiệm vụ khác nhau tùy vào request của người sử dụng
[GET] /version
fastAPI_tutorial
├── app
├── __init__.py
└── main.py
fastAPI_tutorial
├── app
│ └── api
│ └── routers
│ └── get_version_api.py
├── __init__.py
└── main.py
1 import fastapi
2 router = fastapi.APIRouter()
@router.get("/version", response_class=JSONResponse)
5
async def get_api_version():
6
v = {"api_global_version": imageq_stamp}
return JSONResponse(status_code=200, content=v)
3 -4 : import các thư viện liên quan, JSONResponse sẽ convert kết quả trả về
thành dạng json, imageq_stamp là thông tin version (sẽ được mô tả ở phần
dưới).
Quan trọng: Đọc thêm về async def/ def để hiểu hơn về việc async/non-
async được sử dụng trong FastAPI, nên đọc vì tác giả giải thích rất
dễ hiểu
fastAPI_tutorial
├── app
│ ├── config
│ │ └── version.py
│ └── routers
│ └── api
│ └── get_version_api.py
├── __init__.py
└── main.py
Để sử dụng router đã khởi tạo, tiến hành import router trong file
main.py
5 project Kscloset")
app.include_router(get_version_api.router)
6
if __name__ == "__main__":
7 uvicorn.run(app, debug=True, host="0.0.0.0", port=8080)
8
[POST] /prediction/url
source code trên gitlab tại đây
router = fastapi.APIRouter()
5
from app.models.get_model_predict import get_predict
6 from app.response.response_input_output import PredictionOut, UrlIn
7 from app.services.axiom_services import upload_artifacts_job
8 from app.services.export_services import export_result
9
from app.processing_data.prepare_data_from_url import
10
get_data_from_url, prepare_data
from app.models import load_config_and_models
11
Các thư viện/file được import sẽ được mô tả rõ hơn ở phần CHÚ THÍCH
13: tạo async def handle_url sẽ nhận vào item là class với các tham số đã
được định nghĩa sẵn trong UrlIn.
BackgroundTasks: built-in của FastAPI, thực hiện các task background.
request: thông tin từ request của user( như là url, port,...), mục đích để
lưu ra log.
14-16 : tính toán processing time để lưu ra log
17: item là instance của UrlIn, truy cập mà user gửi vào thông qua item.url,
hàm get_data_from_url sẽ verify url, sau đó tiến hành download ảnh từ url.
18: function get_predict nhận input là ảnh đã được download, các models sod /
yolov5/ extract features để tính toán image similarity.
12 @router.post("/prediction/url", response_model=PredictionOut)
13 async def handle_url(item: UrlIn, background_tasks: BackgroundTasks,
request: Request):
"""
Get an url of image and perform predict the image similarity
10
try:
processing_time_log = dict()
start_t: float = time.time()
14
15 logger.info("This is routers/url_api")
16 logger.info(f"Input url: {item.url}")
17 image = get_data_from_url(item.url)
logger.info(f"result: \n {results}")
processing_time=processing_time_log,
)
20
Trong file main.py, giống như ở bước tạo version. Sử dụng include_router để
thêm router từ file url_api.py
if __name__ == "__main__":
uvicorn.run(app, debug=True, host="0.0.0.0", port=8080)
Đã có thể sử dụng API để xử lý request với link ảnh gửi vào từ user.
12
[POST] /prediction/urls
Khi xử lý request có nhiều URL, việc tạo router cũng tương tự như
method /prediction/url ở trên. Tuy nhiên điểm khác biệt ở đây là việc
sử dụng multi-threading trong quá trình tải ảnh (get_data_from_url) và
(get_predict)
13
1 @router.post("/prediction/urls")
async def handle_urls(
items: ListUrlIn, background_tasks: BackgroundTasks, request:
Request = None
):
"""
Handle multiple urls
:param items: Tuple of url, example: ('urls', ['url1', 'url2'])
:param background_tasks: Doing something in the background
:param request:
:return:
"""
start = time.time()
urls = [url for url in items.urls]
pil_images = get_data_from_url(items.urls)
2
3 try:
logger.info(f"input url: {urls}")
logger.info(f"max_workers: {max_workers}")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
fn = partial(
get_predict,
**{
"model_sod": load_config_and_models.model_sod,
"model_global": load_config_and_models.model_global,
"yolo_extractor":
load_config_and_models.yolo_extractor,
"sod_extractor":
load_config_and_models.sod_extractor,
"limit": items.limit,
},
)
response = list(executor.map(fn, pil_images))
except Exception as e:
# something wrong, we will return it here
raise HTTPException(status_code=403, detail=str(e))
14
if __name__ == "__main__":
# in production, don't forget to change reload => False, debug => False
uvicorn.run(app, debug=True, host="0.0.0.0", port=8080)
[POST] /prediction/upload
Đối với /prediction/upload, API sẽ cho phép người dùng upload ảnh để
tiến hành detect image similarity.
16
@router.post("/prediction/upload", response_model=PredictionOut)
async def handle_file_upload(
items: UploadLimit = Depends(),
file: UploadFile = File(...),
):
try:
processing_time_log = {}
start_t = time.time()
# read file from sender
data = await file.read()
# prepare the correct input (convert to image)
image = prepare_data(data)
get_data_t = time.time()
processing_time_log["get_data_t"] = get_data_t - start_t
except Exception as e:
# something wrong, we will return it here
logger.exception(e)
raise HTTPException(status_code=403, detail=str(e))
Sau khi include_router thì chạy lại API để thấy chức năng upload ảnh.
if __name__ == "__main__":
uvicorn.run(app, debug=True, host="0.0.0.0", port=8080)
Kết quả sau khi thêm upload router, end-user có thể test API thông qua button
Browse... để upload ảnh.
18
[GET] /reload
reload router có nhiệm vụ download file json từ server axiom.
File json chứa đường dẫn để download weights đã được train trên data
mới từ CMS.
@router.get("/reload", response_class=JSONResponse)
async def reload_models_route():
download_DMM_conf(
axiom_db=axiom_db,
env=ENV_DEPLOYMENT,
search_version=search_version,
save_path=os.path.dirname(DMM_CONF_PATH),
)
(
load_config_and_models.model_sod,
load_config_and_models.model_global,
load_config_and_models.yolo_extractor,
load_config_and_models.sod_extractor,
) = load_feature_extraction_models(load_all=True)
logger.info("Model reloaded successfully")
return {"Update_status": "Process completed"}
19
CHÚ THÍCH
Ví dụ: UrlIn yêu cầu url phải là str và limit là int. Nếu url được
truyền vào khi sử dụng URL không phải là str thì sẽ có thông báo lỗi.
PredictionOut khi trả về kết quả của API thì cần phải là List[Dict],
nếu sai cũng sẽ gây lỗi.
class UrlsOut(BaseModel):
prediction: Optional[Dict] = None
class UrlIn(BaseModel):
url: str = None
limit: int = 10
class ListUrlIn(BaseModel):
urls: Optional[List[str]] = Query(None)
limit: int = 10
20
class UploadLimit(BaseModel):
limit: int = 10
get_data_from_url
download ảnh từ url, convert từ dạng bytes sang PIL Image, source_code
if isinstance(image_url, list):
results = {}
futures = [session_futures.get(url, verify=False) for url in
image_url]
for future in as_completed(futures):
resp = future.result()
url = resp.request.url
results[url] = resp.content
else:
logger.exception(f"URL is invalid! Please check URL {image_url}")
return {}
# pil_image = prepare_data(image_bytes)
return [prepare_data(image_bytes) for image_bytes in images_bytes]
get_predict
Source_code
21
Returns
-------
top 10 prediction of the similarities
"""
model_sod = kwargs.get("model_sod", None)
model_global = kwargs.get("model_global", None)
yolo_extractor = kwargs.get("yolo_extractor", None)
sod_extractor = kwargs.get("sod_extractor", None)
limit = kwargs.get("limit", 10)
if limit < 1:
limit = 10
try:
results = yolo_results = global_results = sod_results = []
search_t = time.time()
# people localization using Yolov5
list_regions = yolo_extractor.predict(image, img_size=512)
sod_time = time.time() - search_t
start_search = time.time()
# keep only instance with score >= 0.95
# if none of them >= 0.95, then performing global search
if list_regions:
# Found people in the image
# get result with score >= 0.95
yolo_results = [
22
if len(use_yolo_results) == 1:
# only 1 instance with score >= 0.95
logger.info("Results from 1 instance of Yolo")
results = restructure_output(
instance_results=use_yolo_results[0], limit=limit
)
elif len(use_yolo_results) > 1:
# more than 2 instances with score >= 0.95
logger.info("Results from more than 2 instances of
Yolo")
results = restructure_output_multi_instances(
use_yolo_results, limit=limit
)
if not results:
# if results from localization is empty or score less than 0.95
# then apply global search
image_resize = (
pil_resize(image[0]) if isinstance(image, list) else
pil_resize(image)
)
global_results = model_global.search_topk(image_resize, limit)
if global_results["0"][1] < 0.95:
# if global search score < 0.95 then call SOD
_, _, _, regions =
sod_extractor.extract(img=image_resize)
# get topk from all regions of SOD then get the max
sod_all_results = [
model_sod.search_topk(region, limit, use_dmm=True)
for region in regions
]
sod_results = max(sod_all_results, key=lambda result:
result["0"][1])
23
if not results:
results = combine_all_results(
yolo_results, global_results, sod_results, limit=limit
)
except Exception as e:
logger.exception(e)
get_data_from_url
source_code
if isinstance(image_url, list):
results = {}
futures = [session_futures.get(url, verify=False) for url in
image_url]
for future in as_completed(futures):
resp = future.result()
url = resp.request.url
results[url] = resp.content
else:
logger.exception(f"URL is invalid! Please check URL {image_url}")
return {}
return [prepare_data(image_bytes) for image_bytes in images_bytes]
prepare_data
def prepare_data(contents):
try:
logger.info("Preparing data for model similarity")
pil_image = Image.open(io.BytesIO(contents)).convert("RGB")
# pil_image = pil_resize(pil_image)
return pil_image
except Exception as e:
logger.exception(f"Cannot prepare data, detail: {e}")