From d7f28223d52ea81e174345c177c921c79f830db0 Mon Sep 17 00:00:00 2001 From: GokuMohandas Date: Thu, 3 Aug 2023 22:32:49 -0700 Subject: [PATCH] fixed predict with probs error --- madewithml/predict.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/madewithml/predict.py b/madewithml/predict.py index c530488..b97e9f8 100644 --- a/madewithml/predict.py +++ b/madewithml/predict.py @@ -2,6 +2,7 @@ import json from typing import Any, Dict, Iterable, List from urllib.parse import urlparse +import numpy as np import pandas as pd import ray import torch @@ -62,8 +63,6 @@ def predict_with_proba( """ preprocessor = predictor.get_preprocessor() z = predictor.predict(data=df)["predictions"] - import numpy as np - y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy() results = [] for i, prob in enumerate(y_prob): @@ -130,7 +129,7 @@ def predict( # Predict sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}]) - results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class) + results = predict_with_proba(df=sample_df, predictor=predictor) logger.info(json.dumps(results, cls=NumpyEncoder, indent=2)) return results