Merge pull request #236 from GokuMohandas/dev

fixed predict with probs error
This commit is contained in:
Goku Mohandas 2023-08-03 22:33:11 -07:00 committed by GitHub
commit 0cfb704d8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,7 @@ import json
from typing import Any, Dict, Iterable, List from typing import Any, Dict, Iterable, List
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np
import pandas as pd import pandas as pd
import ray import ray
import torch import torch
@ -62,8 +63,6 @@ def predict_with_proba(
""" """
preprocessor = predictor.get_preprocessor() preprocessor = predictor.get_preprocessor()
z = predictor.predict(data=df)["predictions"] z = predictor.predict(data=df)["predictions"]
import numpy as np
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy() y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
results = [] results = []
for i, prob in enumerate(y_prob): for i, prob in enumerate(y_prob):
@ -130,7 +129,7 @@ def predict(
# Predict # Predict
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}]) sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class) results = predict_with_proba(df=sample_df, predictor=predictor)
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2)) logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
return results return results