fixed predict with probs error

This commit is contained in:
GokuMohandas 2023-08-03 22:32:49 -07:00
parent 841fb4dc36
commit d7f28223d5

View File

@ -2,6 +2,7 @@ import json
from typing import Any, Dict, Iterable, List from typing import Any, Dict, Iterable, List
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np
import pandas as pd import pandas as pd
import ray import ray
import torch import torch
@ -62,8 +63,6 @@ def predict_with_proba(
""" """
preprocessor = predictor.get_preprocessor() preprocessor = predictor.get_preprocessor()
z = predictor.predict(data=df)["predictions"] z = predictor.predict(data=df)["predictions"]
import numpy as np
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy() y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
results = [] results = []
for i, prob in enumerate(y_prob): for i, prob in enumerate(y_prob):
@ -130,7 +129,7 @@ def predict(
# Predict # Predict
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}]) sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class) results = predict_with_proba(df=sample_df, predictor=predictor)
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2)) logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
return results return results