Merge pull request #236 from GokuMohandas/dev
fixed predict with probs error
This commit is contained in:
commit
0cfb704d8b
@ -2,6 +2,7 @@ import json
|
|||||||
from typing import Any, Dict, Iterable, List
|
from typing import Any, Dict, Iterable, List
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
@ -62,8 +63,6 @@ def predict_with_proba(
|
|||||||
"""
|
"""
|
||||||
preprocessor = predictor.get_preprocessor()
|
preprocessor = predictor.get_preprocessor()
|
||||||
z = predictor.predict(data=df)["predictions"]
|
z = predictor.predict(data=df)["predictions"]
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
|
y_prob = torch.tensor(np.stack(z)).softmax(dim=1).numpy()
|
||||||
results = []
|
results = []
|
||||||
for i, prob in enumerate(y_prob):
|
for i, prob in enumerate(y_prob):
|
||||||
@ -130,7 +129,7 @@ def predict(
|
|||||||
|
|
||||||
# Predict
|
# Predict
|
||||||
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
|
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
|
||||||
results = predict_with_proba(df=sample_df, predictor=predictor, index_to_class=preprocessor.index_to_class)
|
results = predict_with_proba(df=sample_df, predictor=predictor)
|
||||||
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
|
logger.info(json.dumps(results, cls=NumpyEncoder, indent=2))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user