Merge pull request #236 from GokuMohandas/dev

fixed predict with probs error
This commit is contained in:
Goku Mohandas 2023-08-03 22:33:11 -07:00 committed by GitHub
commit 0cfb704d8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 3 deletions

View File

@ -2,6 +2,7 @@ import json
from typing import Any, Dict, Iterable, List
from urllib.parse import urlparse
import numpy as np
import pandas as pd
import ray
import torch
@ -62,8 +63,6 @@ def predict_with_proba(
"""
preprocessor = predictor.get_preprocessor()
z = predictor.predict(data=df)["predictions"]
import numpy as np
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
results = []
for i, prob in enumerate(y_prob):
@ -130,7 +129,7 @@ def predict(
# Predict
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class)
results = predict_with_proba(df=sample_df, predictor=predictor)
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
return results