updated to Ray 2.7
This commit is contained in:
@@ -8,13 +8,13 @@ import ray
|
||||
import ray.train.torch # NOQA: F401 (imported but unused)
|
||||
import typer
|
||||
from ray.data import Dataset
|
||||
from ray.train.torch.torch_predictor import TorchPredictor
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
from snorkel.slicing import PandasSFApplier, slicing_function
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from madewithml import predict, utils
|
||||
from madewithml.config import logger
|
||||
from madewithml.predict import TorchPredictor
|
||||
|
||||
# Initialize Typer CLI app
|
||||
app = typer.Typer()
|
||||
@@ -133,8 +133,8 @@ def evaluate(
|
||||
y_true = np.stack([item["targets"] for item in values])
|
||||
|
||||
# y_pred
|
||||
z = predictor.predict(data=ds.to_pandas())["predictions"]
|
||||
y_pred = np.stack(z).argmax(1)
|
||||
predictions = preprocessed_ds.map_batches(predictor).take_all()
|
||||
y_pred = np.array([d["output"] for d in predictions])
|
||||
|
||||
# Metrics
|
||||
metrics = {
|
||||
|
||||
Reference in New Issue
Block a user