updated to Ray 2.7

This commit is contained in:
GokuMohandas
2023-09-18 22:03:20 -07:00
parent 71b3d50a05
commit b98bd5b1ae
15 changed files with 3484 additions and 2086 deletions

View File

@@ -1,19 +1,20 @@
import json
from pathlib import Path
from typing import Any, Dict, Iterable, List
from urllib.parse import urlparse
import numpy as np
import pandas as pd
import ray
import torch
import typer
from numpyencoder import NumpyEncoder
from ray.air import Result
from ray.train.torch import TorchPredictor
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from typing_extensions import Annotated
from madewithml.config import logger, mlflow
from madewithml.data import CustomPreprocessor
from madewithml.models import FinetunedLLM
from madewithml.utils import collate_fn
# Initialize Typer CLI app
app = typer.Typer()
@@ -48,25 +49,51 @@ def format_prob(prob: Iterable, index_to_class: Dict) -> Dict:
return d
def predict_with_proba(
df: pd.DataFrame,
predictor: ray.train.torch.torch_predictor.TorchPredictor,
class TorchPredictor:
def __init__(self, preprocessor, model):
self.preprocessor = preprocessor
self.model = model
self.model.eval()
def __call__(self, batch):
results = self.model.predict(collate_fn(batch))
return {"output": results}
def predict_proba(self, batch):
results = self.model.predict_proba(collate_fn(batch))
return {"output": results}
def get_preprocessor(self):
return self.preprocessor
@classmethod
def from_checkpoint(cls, checkpoint):
metadata = checkpoint.get_metadata()
preprocessor = CustomPreprocessor(class_to_index=metadata["class_to_index"])
model = FinetunedLLM.load(Path(checkpoint.path, "args.json"), Path(checkpoint.path, "model.pt"))
return cls(preprocessor=preprocessor, model=model)
def predict_proba(
ds: ray.data.dataset.Dataset,
predictor: TorchPredictor,
) -> List: # pragma: no cover, tested with inference workload
"""Predict tags (with probabilities) for input data from a dataframe.
Args:
df (pd.DataFrame): dataframe with input features.
predictor (ray.train.torch.torch_predictor.TorchPredictor): loaded predictor from a checkpoint.
predictor (TorchPredictor): loaded predictor from a checkpoint.
Returns:
List: list of predicted labels.
"""
preprocessor = predictor.get_preprocessor()
z = predictor.predict(data=df)["predictions"]
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
preprocessed_ds = preprocessor.transform(ds)
outputs = preprocessed_ds.map_batches(predictor.predict_proba)
y_prob = np.array([d["output"] for d in outputs.take_all()])
results = []
for i, prob in enumerate(y_prob):
tag = decode([z[i].argmax()], preprocessor.index_to_class)[0]
tag = preprocessor.index_to_class[prob.argmax()]
results.append({"prediction": tag, "probabilities": format_prob(prob, preprocessor.index_to_class)})
return results
@@ -125,11 +152,10 @@ def predict(
# Load components
best_checkpoint = get_best_checkpoint(run_id=run_id)
predictor = TorchPredictor.from_checkpoint(best_checkpoint)
# preprocessor = predictor.get_preprocessor()
# Predict
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
results = predict_with_proba(df=sample_df, predictor=predictor)
sample_ds = ray.data.from_items([{"title": title, "description": description, "tag": "other"}])
results = predict_proba(ds=sample_ds, predictor=predictor)
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
return results