updated to Ray 2.7
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from ray.train.torch.torch_predictor import TorchPredictor
|
||||
|
||||
from madewithml import predict
|
||||
from madewithml.predict import TorchPredictor
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import ray
|
||||
|
||||
from madewithml import predict
|
||||
|
||||
|
||||
def get_label(text, predictor):
|
||||
df = pd.DataFrame({"title": [text], "description": "", "tag": "other"})
|
||||
z = predictor.predict(data=df)["predictions"]
|
||||
preprocessor = predictor.get_preprocessor()
|
||||
label = predict.decode(np.stack(z).argmax(1), preprocessor.index_to_class)[0]
|
||||
return label
|
||||
sample_ds = ray.data.from_items([{"title": text, "description": "", "tag": "other"}])
|
||||
results = predict.predict_proba(ds=sample_ds, predictor=predictor)
|
||||
return results[0]["prediction"]
|
||||
|
||||
Reference in New Issue
Block a user