Merge pull request #239 from GokuMohandas/dev
adding dotenv for credentials management
This commit is contained in:
commit
71b3d50a05
42
README.md
42
README.md
@ -101,6 +101,17 @@ We'll start by setting up our cluster with the environment and compute configura
|
||||
|
||||
</details>
|
||||
|
||||
### Credentials
|
||||
```bash
|
||||
touch .env
|
||||
```
|
||||
```bash
|
||||
# Inside .env
|
||||
GITHUB_USERNAME="CHANGE_THIS_TO_YOUR_USERNAME" # ← CHANGE THIS
|
||||
```bash
|
||||
source .env
|
||||
```
|
||||
|
||||
### Git setup
|
||||
|
||||
Create a repository by following these instructions: [Create a new repository](https://github.com/new) → name it `Made-With-ML` → Toggle `Add a README file` (**very important** as this creates a `main` branch) → Click `Create repository` (scroll down)
|
||||
@ -109,7 +120,7 @@ Now we're ready to clone the repository that has all of our code:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/GokuMohandas/Made-With-ML.git .
|
||||
git remote set-url origin https://github.com/GITHUB_USERNAME/Made-With-ML.git # <-- CHANGE THIS to your username
|
||||
git remote set-url origin https://github.com/$GITHUB_USERNAME/Made-With-ML.git # <-- CHANGE THIS to your username
|
||||
git checkout -b dev
|
||||
```
|
||||
|
||||
@ -317,15 +328,7 @@ python madewithml/predict.py predict \
|
||||
python madewithml/serve.py --run_id $RUN_ID
|
||||
```
|
||||
|
||||
While the application is running, we can use it via cURL, Python, etc.:
|
||||
|
||||
```bash
|
||||
# via cURL
|
||||
curl -X POST -H "Content-Type: application/json" -d '{
|
||||
"title": "Transfer learning with transformers",
|
||||
"description": "Using transformers for transfer learning on text classification tasks."
|
||||
}' http://127.0.0.1:8000/predict
|
||||
```
|
||||
Once the application is running, we can use it via cURL, Python, etc.:
|
||||
|
||||
```python
|
||||
# via Python
|
||||
@ -341,13 +344,6 @@ python madewithml/predict.py predict \
|
||||
ray stop # shutdown
|
||||
```
|
||||
|
||||
```bash
|
||||
export HOLDOUT_LOC="https://raw.githubusercontent.com/GokuMohandas/Made-With-ML/main/datasets/holdout.csv"
|
||||
curl -X POST -H "Content-Type: application/json" -d '{
|
||||
"dataset_loc": "https://raw.githubusercontent.com/GokuMohandas/Made-With-ML/main/datasets/holdout.csv"
|
||||
}' http://127.0.0.1:8000/evaluate
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details open>
|
||||
@ -362,15 +358,7 @@ curl -X POST -H "Content-Type: application/json" -d '{
|
||||
python madewithml/serve.py --run_id $RUN_ID
|
||||
```
|
||||
|
||||
While the application is running, we can use it via cURL, Python, etc.:
|
||||
|
||||
```bash
|
||||
# via cURL
|
||||
curl -X POST -H "Content-Type: application/json" -d '{
|
||||
"title": "Transfer learning with transformers",
|
||||
"description": "Using transformers for transfer learning on text classification tasks."
|
||||
}' http://127.0.0.1:8000/predict
|
||||
```
|
||||
Once the application is running, we can use it via cURL, Python, etc.:
|
||||
|
||||
```python
|
||||
# via Python
|
||||
@ -399,7 +387,7 @@ export RUN_ID=$(python madewithml/predict.py get-best-run-id --experiment-name $
|
||||
pytest --run-id=$RUN_ID tests/model --verbose --disable-warnings
|
||||
|
||||
# Coverage
|
||||
python3 -m pytest --cov madewithml --cov-report html
|
||||
python3 -m pytest tests/code --cov madewithml --cov-report html --disable-warnings
|
||||
```
|
||||
|
||||
## Production
|
||||
|
3
madewithml/__init__.py
Normal file
3
madewithml/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
@ -1,5 +1,6 @@
|
||||
# config.py
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@ -10,9 +11,10 @@ import pretty_errors # NOQA: F401 (imported but unused)
|
||||
ROOT_DIR = Path(__file__).parent.parent.absolute()
|
||||
LOGS_DIR = Path(ROOT_DIR, "logs")
|
||||
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
EFS_DIR = Path(f"/efs/shared_storage/madewithml/{os.environ.get('GITHUB_USERNAME', '')}")
|
||||
|
||||
# Config MLflow
|
||||
MODEL_REGISTRY = Path("/tmp/mlflow")
|
||||
MODEL_REGISTRY = Path(f"{EFS_DIR}/mlflow")
|
||||
Path(MODEL_REGISTRY).mkdir(parents=True, exist_ok=True)
|
||||
MLFLOW_TRACKING_URI = "file://" + str(MODEL_REGISTRY.absolute())
|
||||
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
|
||||
|
@ -125,7 +125,7 @@ def predict(
|
||||
# Load components
|
||||
best_checkpoint = get_best_checkpoint(run_id=run_id)
|
||||
predictor = TorchPredictor.from_checkpoint(best_checkpoint)
|
||||
preprocessor = predictor.get_preprocessor()
|
||||
# preprocessor = predictor.get_preprocessor()
|
||||
|
||||
# Predict
|
||||
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "other"}])
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import os
|
||||
from http import HTTPStatus
|
||||
from typing import Dict
|
||||
|
||||
@ -75,5 +76,5 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--run_id", help="run ID to use for serving.")
|
||||
parser.add_argument("--threshold", type=float, default=0.9, help="threshold for `other` class.")
|
||||
args = parser.parse_args()
|
||||
ray.init()
|
||||
ray.init(runtime_env={"env_vars": {"GITHUB_USERNAME": os.environ["GITHUB_USERNAME"]}})
|
||||
serve.run(ModelDeployment.bind(run_id=args.run_id, threshold=args.threshold))
|
||||
|
@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -23,7 +24,7 @@ from transformers import BertModel
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from madewithml import data, models, utils
|
||||
from madewithml.config import MLFLOW_TRACKING_URI, logger
|
||||
from madewithml.config import EFS_DIR, MLFLOW_TRACKING_URI, logger
|
||||
|
||||
# Initialize Typer CLI app
|
||||
app = typer.Typer()
|
||||
@ -200,10 +201,7 @@ def train_model(
|
||||
)
|
||||
|
||||
# Run config
|
||||
run_config = RunConfig(
|
||||
callbacks=[mlflow_callback],
|
||||
checkpoint_config=checkpoint_config,
|
||||
)
|
||||
run_config = RunConfig(callbacks=[mlflow_callback], checkpoint_config=checkpoint_config, storage_path=EFS_DIR)
|
||||
|
||||
# Dataset
|
||||
ds = data.load_data(dataset_loc=dataset_loc, num_samples=train_loop_config["num_samples"])
|
||||
@ -252,5 +250,5 @@ def train_model(
|
||||
if __name__ == "__main__": # pragma: no cover, application
|
||||
if ray.is_initialized():
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
ray.init(runtime_env={"env_vars": {"GITHUB_USERNAME": os.environ["GITHUB_USERNAME"]}})
|
||||
app()
|
||||
|
@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
import ray
|
||||
import typer
|
||||
@ -19,7 +20,7 @@ from ray.tune.search.hyperopt import HyperOptSearch
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from madewithml import data, train, utils
|
||||
from madewithml.config import MLFLOW_TRACKING_URI, logger
|
||||
from madewithml.config import EFS_DIR, MLFLOW_TRACKING_URI, logger
|
||||
|
||||
# Initialize Typer CLI app
|
||||
app = typer.Typer()
|
||||
@ -117,10 +118,7 @@ def tune_models(
|
||||
experiment_name=experiment_name,
|
||||
save_artifact=True,
|
||||
)
|
||||
run_config = RunConfig(
|
||||
callbacks=[mlflow_callback],
|
||||
checkpoint_config=checkpoint_config,
|
||||
)
|
||||
run_config = RunConfig(callbacks=[mlflow_callback], checkpoint_config=checkpoint_config, storage_path=EFS_DIR)
|
||||
|
||||
# Hyperparameters to start with
|
||||
initial_params = json.loads(initial_params)
|
||||
@ -178,5 +176,5 @@ def tune_models(
|
||||
if __name__ == "__main__": # pragma: no cover, application
|
||||
if ray.is_initialized():
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
ray.init(runtime_env={"env_vars": {"GITHUB_USERNAME": os.environ["GITHUB_USERNAME"]}})
|
||||
app()
|
||||
|
File diff suppressed because one or more lines are too long
@ -8,6 +8,7 @@ numpy==1.24.3
|
||||
numpyencoder==0.3.0
|
||||
pandas==2.0.1
|
||||
pretty-errors==1.2.25
|
||||
python-dotenv==1.0.0
|
||||
ray[air]==2.6.0
|
||||
scikit-learn==1.2.2
|
||||
snorkel==0.9.9
|
||||
|
Loading…
Reference in New Issue
Block a user