Made-With-ML/madewithml/serve.py

import argparse
import os
from http import HTTPStatus
from typing import Dict

import ray
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request

from madewithml import evaluate, predict
from madewithml.config import MLFLOW_TRACKING_URI, mlflow
# Define application
app = FastAPI(
    title="Made With ML",
    description="Classify machine learning projects.",
    version="0.1",
)


# One replica of a CPU-only deployment (8 CPUs reserved per replica, no GPUs)
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 8, "num_gpus": 0})
@serve.ingress(app)
class ModelDeployment:
    def __init__(self, run_id: str, threshold: float = 0.9):
        """Initialize the model."""
        self.run_id = run_id
        self.threshold = threshold
        mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)  # so workers have access to model registry
        best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
        self.predictor = predict.TorchPredictor.from_checkpoint(best_checkpoint)

    @app.get("/")
    def _index(self) -> Dict:
        """Health check."""
        response = {
            "message": HTTPStatus.OK.phrase,
            "status-code": HTTPStatus.OK,
            "data": {},
        }
        return response

    @app.get("/run_id/")
    def _run_id(self) -> Dict:
        """Get the run ID."""
        return {"run_id": self.run_id}

    @app.post("/evaluate/")
    async def _evaluate(self, request: Request) -> Dict:
        """Evaluate the model on a labeled dataset (expects a {"dataset": <location>} payload)."""
        data = await request.json()
        results = evaluate.evaluate(run_id=self.run_id, dataset_loc=data.get("dataset"))
        return {"results": results}

    @app.post("/predict/")
    async def _predict(self, request: Request) -> Dict:
        """Predict the tag for a project, given its title and description."""
        data = await request.json()
        sample_ds = ray.data.from_items([{"title": data.get("title", ""), "description": data.get("description", ""), "tag": ""}])
        results = predict.predict_proba(ds=sample_ds, predictor=self.predictor)

        # Apply custom logic: fall back to the "other" class when the predicted
        # tag's probability is below the threshold.
        for i, result in enumerate(results):
            pred = result["prediction"]
            prob = result["probabilities"]
            if prob[pred] < self.threshold:
                results[i]["prediction"] = "other"
        return {"results": results}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_id", help="run ID to use for serving.")
    parser.add_argument("--threshold", type=float, default=0.9, help="threshold for `other` class.")
    args = parser.parse_args()
    ray.init(runtime_env={"env_vars": {"GITHUB_USERNAME": os.environ["GITHUB_USERNAME"]}})
    serve.run(ModelDeployment.bind(run_id=args.run_id, threshold=args.threshold))
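
# A minimal client sketch for exercising the running service. Assumptions: the
# script was launched with a valid MLflow run ID (see the CLI args above), the
# `requests` library is installed, and Ray Serve's HTTP proxy is listening on
# its default port 8000. The payload values are placeholders.
#
#   import requests
#   payload = {
#       "title": "Transfer learning with transformers",
#       "description": "Using transformers for transfer learning on text classification problems.",
#   }
#   print(requests.post("http://127.0.0.1:8000/predict/", json=payload).json())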