updated to Ray 2.7

This commit is contained in:
GokuMohandas
2023-09-18 22:03:20 -07:00
parent 71b3d50a05
commit b98bd5b1ae
15 changed files with 3484 additions and 2086 deletions

View File

@@ -8,13 +8,13 @@ import ray
import ray.train.torch # NOQA: F401 (imported but unused)
import typer
from ray.data import Dataset
from ray.train.torch.torch_predictor import TorchPredictor
from sklearn.metrics import precision_recall_fscore_support
from snorkel.slicing import PandasSFApplier, slicing_function
from typing_extensions import Annotated
from madewithml import predict, utils
from madewithml.config import logger
from madewithml.predict import TorchPredictor
# Initialize Typer CLI app
app = typer.Typer()
@@ -133,8 +133,8 @@ def evaluate(
y_true = np.stack([item["targets"] for item in values])
# y_pred
z = predictor.predict(data=ds.to_pandas())["predictions"]
y_pred = np.stack(z).argmax(1)
predictions = preprocessed_ds.map_batches(predictor).take_all()
y_pred = np.array([d["output"] for d in predictions])
# Metrics
metrics = {