updated to Ray 2.7

This commit is contained in:
GokuMohandas
2023-09-18 22:03:20 -07:00
parent 71b3d50a05
commit b98bd5b1ae
15 changed files with 3484 additions and 2086 deletions

View File

@@ -54,5 +54,7 @@ def test_preprocess(df, class_to_index):
 def test_fit_transform(dataset_loc, preprocessor):
     ds = data.load_data(dataset_loc=dataset_loc)
-    preprocessor.fit_transform(ds)
+    preprocessor = preprocessor.fit(ds)
+    preprocessed_ds = preprocessor.transform(ds)
     assert len(preprocessor.class_to_index) == 4
+    assert ds.count() == preprocessed_ds.count()

View File

@@ -4,6 +4,7 @@ from pathlib import Path
 import numpy as np
 import pytest
 import torch
+from ray.train.torch import get_device
 from madewithml import utils
@@ -42,9 +43,9 @@ def test_collate_fn():
     }
     processed_batch = utils.collate_fn(batch)
     expected_batch = {
-        "ids": torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.int32),
-        "masks": torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int32),
-        "targets": torch.tensor([3, 1], dtype=torch.int64),
+        "ids": torch.as_tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.int32, device=get_device()),
+        "masks": torch.as_tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int32, device=get_device()),
+        "targets": torch.as_tensor([3, 1], dtype=torch.int64, device=get_device()),
     }
     for k in batch:
         assert torch.allclose(processed_batch[k], expected_batch[k])

View File

@@ -1,7 +1,7 @@
 import pytest
-from ray.train.torch.torch_predictor import TorchPredictor
 from madewithml import predict
+from madewithml.predict import TorchPredictor
 def pytest_addoption(parser):

View File

@@ -1,12 +1,9 @@
-import numpy as np
-import pandas as pd
+import ray
 from madewithml import predict
 def get_label(text, predictor):
-    df = pd.DataFrame({"title": [text], "description": "", "tag": "other"})
-    z = predictor.predict(data=df)["predictions"]
-    preprocessor = predictor.get_preprocessor()
-    label = predict.decode(np.stack(z).argmax(1), preprocessor.index_to_class)[0]
-    return label
+    sample_ds = ray.data.from_items([{"title": text, "description": "", "tag": "other"}])
+    results = predict.predict_proba(ds=sample_ds, predictor=predictor)
+    return results[0]["prediction"]