fixed predict with probs error
This commit is contained in:
parent
841fb4dc36
commit
d7f28223d5
@ -2,6 +2,7 @@ import json
|
||||
from typing import Any, Dict, Iterable, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import ray
|
||||
import torch
|
||||
@ -62,8 +63,6 @@ def predict_with_proba(
|
||||
"""
|
||||
preprocessor = predictor.get_preprocessor()
|
||||
z = predictor.predict(data=df)["predictions"]
|
||||
import numpy as np
|
||||
|
||||
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
|
||||
results = []
|
||||
for i, prob in enumerate(y_prob):
|
||||
@ -130,7 +129,7 @@ def predict(
|
||||
|
||||
# Predict
|
||||
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
|
||||
results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class)
|
||||
results = predict_with_proba(df=sample_df, predictor=predictor)
|
||||
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
|
||||
return results
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user