updated to Ray 2.7
This commit is contained in:
@@ -54,5 +54,7 @@ def test_preprocess(df, class_to_index):
|
||||
|
||||
def test_fit_transform(dataset_loc, preprocessor):
|
||||
ds = data.load_data(dataset_loc=dataset_loc)
|
||||
preprocessor.fit_transform(ds)
|
||||
preprocessor = preprocessor.fit(ds)
|
||||
preprocessed_ds = preprocessor.transform(ds)
|
||||
assert len(preprocessor.class_to_index) == 4
|
||||
assert ds.count() == preprocessed_ds.count()
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from ray.train.torch import get_device
|
||||
|
||||
from madewithml import utils
|
||||
|
||||
@@ -42,9 +43,9 @@ def test_collate_fn():
|
||||
}
|
||||
processed_batch = utils.collate_fn(batch)
|
||||
expected_batch = {
|
||||
"ids": torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.int32),
|
||||
"masks": torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int32),
|
||||
"targets": torch.tensor([3, 1], dtype=torch.int64),
|
||||
"ids": torch.as_tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.int32, device=get_device()),
|
||||
"masks": torch.as_tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int32, device=get_device()),
|
||||
"targets": torch.as_tensor([3, 1], dtype=torch.int64, device=get_device()),
|
||||
}
|
||||
for k in batch:
|
||||
assert torch.allclose(processed_batch[k], expected_batch[k])
|
||||
|
||||
Reference in New Issue
Block a user