ML for Developers
This commit is contained in:
20
tests/model/conftest.py
Normal file
20
tests/model/conftest.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import pytest
|
||||
from ray.train.torch.torch_predictor import TorchPredictor
|
||||
|
||||
from madewithml import predict
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--run-id", action="store", default=None, help="Run ID of model to use.")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def run_id(request):
|
||||
return request.config.getoption("--run-id")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def predictor(run_id):
|
||||
best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
|
||||
predictor = TorchPredictor.from_checkpoint(best_checkpoint)
|
||||
return predictor
|
||||
Reference in New Issue
Block a user