updated to Ray 2.7

This commit is contained in:
GokuMohandas
2023-09-18 22:03:20 -07:00
parent 71b3d50a05
commit b98bd5b1ae
15 changed files with 3484 additions and 2086 deletions

View File

@@ -5,7 +5,6 @@ import numpy as np
import pandas as pd
import ray
from ray.data import Dataset
from ray.data.preprocessor import Preprocessor
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer
@@ -135,13 +134,18 @@ def preprocess(df: pd.DataFrame, class_to_index: Dict) -> Dict:
return outputs
class CustomPreprocessor(Preprocessor):
class CustomPreprocessor:
"""Custom preprocessor class."""
def _fit(self, ds):
def __init__(self, class_to_index={}):
self.class_to_index = class_to_index or {} # mutable defaults
self.index_to_class = {v: k for k, v in self.class_to_index.items()}
def fit(self, ds):
tags = ds.unique(column="tag")
self.class_to_index = {tag: i for i, tag in enumerate(tags)}
self.index_to_class = {v: k for k, v in self.class_to_index.items()}
return self
def _transform_pandas(self, batch): # could also do _transform_numpy
return preprocess(batch, class_to_index=self.class_to_index)
def transform(self, ds):
return ds.map_batches(preprocess, fn_kwargs={"class_to_index": self.class_to_index}, batch_format="pandas")