ML for Developers
This commit is contained in:
20
tests/model/conftest.py
Normal file
20
tests/model/conftest.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import pytest
|
||||
from ray.train.torch.torch_predictor import TorchPredictor
|
||||
|
||||
from madewithml import predict
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--run-id", action="store", default=None, help="Run ID of model to use.")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def run_id(request):
|
||||
return request.config.getoption("--run-id")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def predictor(run_id):
|
||||
best_checkpoint = predict.get_best_checkpoint(run_id=run_id)
|
||||
predictor = TorchPredictor.from_checkpoint(best_checkpoint)
|
||||
return predictor
|
||||
65
tests/model/test_behavioral.py
Normal file
65
tests/model/test_behavioral.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
import utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_a, input_b, label",
|
||||
[
|
||||
(
|
||||
"Transformers applied to NLP have revolutionized machine learning.",
|
||||
"Transformers applied to NLP have disrupted machine learning.",
|
||||
"natural-language-processing",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invariance(input_a, input_b, label, predictor):
|
||||
"""INVariance via verb injection (changes should not affect outputs)."""
|
||||
label_a = utils.get_label(text=input_a, predictor=predictor)
|
||||
label_b = utils.get_label(text=input_b, predictor=predictor)
|
||||
assert label_a == label_b == label
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input, label",
|
||||
[
|
||||
(
|
||||
"ML applied to text classification.",
|
||||
"natural-language-processing",
|
||||
),
|
||||
(
|
||||
"ML applied to image classification.",
|
||||
"computer-vision",
|
||||
),
|
||||
(
|
||||
"CNNs for text classification.",
|
||||
"natural-language-processing",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_directional(input, label, predictor):
|
||||
"""DIRectional expectations (changes with known outputs)."""
|
||||
prediction = utils.get_label(text=input, predictor=predictor)
|
||||
assert label == prediction
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input, label",
|
||||
[
|
||||
(
|
||||
"Natural language processing is the next big wave in machine learning.",
|
||||
"natural-language-processing",
|
||||
),
|
||||
(
|
||||
"MLOps is the next big wave in machine learning.",
|
||||
"mlops",
|
||||
),
|
||||
(
|
||||
"This is about graph neural networks.",
|
||||
"other",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mft(input, label, predictor):
|
||||
"""Minimum Functionality Tests (simple input/output pairs)."""
|
||||
prediction = utils.get_label(text=input, predictor=predictor)
|
||||
assert label == prediction
|
||||
12
tests/model/utils.py
Normal file
12
tests/model/utils.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from madewithml import predict
|
||||
|
||||
|
||||
def get_label(text, predictor):
|
||||
df = pd.DataFrame({"title": [text], "description": "", "tag": "other"})
|
||||
z = predictor.predict(data=df)["predictions"]
|
||||
preprocessor = predictor.get_preprocessor()
|
||||
label = predict.decode(np.stack(z).argmax(1), preprocessor.index_to_class)[0]
|
||||
return label
|
||||
Reference in New Issue
Block a user