updated to Ray 2.7
This commit is contained in:
@@ -5,7 +5,6 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import ray
|
||||
from ray.data import Dataset
|
||||
from ray.data.preprocessor import Preprocessor
|
||||
from sklearn.model_selection import train_test_split
|
||||
from transformers import BertTokenizer
|
||||
|
||||
@@ -135,13 +134,18 @@ def preprocess(df: pd.DataFrame, class_to_index: Dict) -> Dict:
|
||||
return outputs
|
||||
|
||||
|
||||
class CustomPreprocessor:
    """Custom preprocessor class.

    Learns a mapping from each distinct value of the ``tag`` column to an
    integer index (``fit``) and applies the module-level ``preprocess``
    function to a dataset in pandas batches using that mapping
    (``transform``).
    """

    def __init__(self, class_to_index=None):
        """Initialise the label mappings.

        Args:
            class_to_index: Optional mapping of class label -> index. When
                omitted (or empty) the preprocessor starts with an empty
                mapping that ``fit`` fills in later.
        """
        # ``None`` instead of a ``{}`` default avoids declaring a mutable
        # default argument; the ``or {}`` fallback keeps the behaviour the same.
        self.class_to_index = class_to_index or {}
        self.index_to_class = {v: k for k, v in self.class_to_index.items()}

    def fit(self, ds):
        """Build the label <-> index mappings from the ``tag`` column of ``ds``.

        Args:
            ds: Object exposing ``unique(column=...)`` (a Ray ``Dataset`` in
                this project, judging by the file's imports).

        Returns:
            ``self``, so calls can be chained.
        """
        tags = ds.unique(column="tag")
        # Indices follow whatever order ``unique`` returns the tags in.
        self.class_to_index = {tag: i for i, tag in enumerate(tags)}
        self.index_to_class = {v: k for k, v in self.class_to_index.items()}
        return self

    def transform(self, ds):
        """Apply ``preprocess`` to ``ds`` batch by batch.

        Each batch is passed to ``preprocess`` as a pandas DataFrame along
        with the current ``class_to_index`` mapping.

        Args:
            ds: Object exposing ``map_batches`` (a Ray ``Dataset`` in this
                project, judging by the file's imports).

        Returns:
            Whatever ``ds.map_batches`` returns.
        """
        return ds.map_batches(
            preprocess,
            fn_kwargs={"class_to_index": self.class_to_index},
            batch_format="pandas",
        )
|
||||
|
||||
Reference in New Issue
Block a user