Made-With-ML/madewithml/serve.py
2023-09-18 22:03:20 -07:00

77 lines
2.6 KiB
Python

import argparse
import os
from http import HTTPStatus
from typing import Dict
import ray
from fastapi import FastAPI
from ray import serve
from starlette.requests import Request
from madewithml import evaluate, predict
from madewithml.config import MLFLOW_TRACKING_URI, mlflow
# Define application
app = FastAPI(
title="Made With ML",
description="Classify machine learning projects.",
version="0.1",
)
@serve.deployment(num_replicas="1", ray_actor_options={"num_cpus": 8, "num_gpus": 0})
@serve.ingress(app)
class ModelDeployment:
def __init__(self, run_id: str, threshold: int = 0.9):
"""Initialize the model."""
self.run_id = run_id
self.threshold = threshold
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) # so workers have access to model registry
best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
self.predictor = predict.TorchPredictor.from_checkpoint(best_checkpoint)
@app.get("/")
def _index(self) -> Dict:
"""Health check."""
response = {
"message": HTTPStatus.OK.phrase,
"status-code": HTTPStatus.OK,
"data": {},
}
return response
@app.get("/run_id/")
def _run_id(self) -> Dict:
"""Get the run ID."""
return {"run_id": self.run_id}
@app.post("/evaluate/")
async def _evaluate(self, request: Request) -> Dict:
data = await request.json()
results = evaluate.evaluate(run_id=self.run_id, dataset_loc=data.get("dataset"))
return {"results": results}
@app.post("/predict/")
async def _predict(self, request: Request):
data = await request.json()
sample_ds = ray.data.from_items([{"title": data.get("title", ""), "description": data.get("description", ""), "tag": ""}])
results = predict.predict_proba(ds=sample_ds, predictor=self.predictor)
# Apply custom logic
for i, result in enumerate(results):
pred = result["prediction"]
prob = result["probabilities"]
if prob[pred] < self.threshold:
results[i]["prediction"] = "other"
return {"results": results}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--run_id", help="run ID to use for serving.")
parser.add_argument("--threshold", type=float, default=0.9, help="threshold for `other` class.")
args = parser.parse_args()
ray.init(runtime_env={"env_vars": {"GITHUB_USERNAME": os.environ["GITHUB_USERNAME"]}})
serve.run(ModelDeployment.bind(run_id=args.run_id, threshold=args.threshold))