Update to Ray 2.7

This commit is contained in:
GokuMohandas
2023-09-18 22:03:20 -07:00
parent 71b3d50a05
commit b98bd5b1ae
15 changed files with 3484 additions and 2086 deletions

View File

@@ -1,6 +1,7 @@
import datetime
import json
import os
import tempfile
from typing import Tuple
import numpy as np
@@ -10,21 +11,23 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import typer
from ray.air import session
from ray.air.config import (
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.data import Dataset
from ray.train import (
Checkpoint,
CheckpointConfig,
DatasetConfig,
DataConfig,
RunConfig,
ScalingConfig,
)
from ray.air.integrations.mlflow import MLflowLoggerCallback
from ray.data import Dataset
from ray.train.torch import TorchCheckpoint, TorchTrainer
from ray.train.torch import TorchTrainer
from torch.nn.parallel.distributed import DistributedDataParallel
from transformers import BertModel
from typing_extensions import Annotated
from madewithml import data, models, utils
from madewithml import data, utils
from madewithml.config import EFS_DIR, MLFLOW_TRACKING_URI, logger
from madewithml.models import FinetunedLLM
# Initialize Typer CLI app
app = typer.Typer()
@@ -106,18 +109,18 @@ def train_loop_per_worker(config: dict) -> None: # pragma: no cover, tested via
lr = config["lr"]
lr_factor = config["lr_factor"]
lr_patience = config["lr_patience"]
batch_size = config["batch_size"]
num_epochs = config["num_epochs"]
batch_size = config["batch_size"]
num_classes = config["num_classes"]
# Get datasets
utils.set_seeds()
train_ds = session.get_dataset_shard("train")
val_ds = session.get_dataset_shard("val")
train_ds = train.get_dataset_shard("train")
val_ds = train.get_dataset_shard("val")
# Model
llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
model = models.FinetunedLLM(llm=llm, dropout_p=dropout_p, embedding_dim=llm.config.hidden_size, num_classes=num_classes)
model = FinetunedLLM(llm=llm, dropout_p=dropout_p, embedding_dim=llm.config.hidden_size, num_classes=num_classes)
model = train.torch.prepare_model(model)
# Training components
@@ -126,7 +129,8 @@ def train_loop_per_worker(config: dict) -> None: # pragma: no cover, tested via
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=lr_factor, patience=lr_patience)
# Training
batch_size_per_worker = batch_size // session.get_world_size()
num_workers = train.get_context().get_world_size()
batch_size_per_worker = batch_size // num_workers
for epoch in range(num_epochs):
# Step
train_loss = train_step(train_ds, batch_size_per_worker, model, num_classes, loss_fn, optimizer)
@@ -134,9 +138,14 @@ def train_loop_per_worker(config: dict) -> None: # pragma: no cover, tested via
scheduler.step(val_loss)
# Checkpoint
metrics = dict(epoch=epoch, lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
checkpoint = TorchCheckpoint.from_model(model=model)
session.report(metrics, checkpoint=checkpoint)
with tempfile.TemporaryDirectory() as dp:
if isinstance(model, DistributedDataParallel): # cpu
model.module.save(dp=dp)
else:
model.save(dp=dp)
metrics = dict(epoch=epoch, lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
checkpoint = Checkpoint.from_directory(dp)
train.report(metrics, checkpoint=checkpoint)
@app.command()
@@ -183,7 +192,6 @@ def train_model(
num_workers=num_workers,
use_gpu=bool(gpu_per_worker),
resources_per_worker={"CPU": cpu_per_worker, "GPU": gpu_per_worker},
_max_cpu_fraction_per_node=0.8,
)
# Checkpoint config
@@ -201,7 +209,7 @@ def train_model(
)
# Run config
run_config = RunConfig(callbacks=[mlflow_callback], checkpoint_config=checkpoint_config, storage_path=EFS_DIR)
run_config = RunConfig(callbacks=[mlflow_callback], checkpoint_config=checkpoint_config, storage_path=EFS_DIR, local_dir=EFS_DIR)
# Dataset
ds = data.load_data(dataset_loc=dataset_loc, num_samples=train_loop_config["num_samples"])
@@ -210,14 +218,13 @@ def train_model(
train_loop_config["num_classes"] = len(tags)
# Dataset config
dataset_config = {
"train": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
"val": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
}
options = ray.data.ExecutionOptions(preserve_order=True)
dataset_config = DataConfig(datasets_to_split=["train"], execution_options=options)
# Preprocess
preprocessor = data.CustomPreprocessor()
train_ds = preprocessor.fit_transform(train_ds)
preprocessor = preprocessor.fit(train_ds)
train_ds = preprocessor.transform(train_ds)
val_ds = preprocessor.transform(val_ds)
train_ds = train_ds.materialize()
val_ds = val_ds.materialize()
@@ -230,7 +237,7 @@ def train_model(
run_config=run_config,
datasets={"train": train_ds, "val": val_ds},
dataset_config=dataset_config,
preprocessor=preprocessor,
metadata={"class_to_index": preprocessor.class_to_index},
)
# Train