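"""A fine-tuned LLM classifier: a SciBERT encoder with a dropout + linear
classification head, plus inference helpers and save/load utilities."""
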
import json
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel


class FinetunedLLM(nn.Module):
    """Pretrained LLM encoder with a dropout + linear classification head."""

    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super().__init__()
        self.llm = llm
        self.dropout_p = dropout_p
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.dropout = nn.Dropout(dropout_p)
        self.fc1 = nn.Linear(embedding_dim, num_classes)

    def forward(self, batch):
        ids, masks = batch["ids"], batch["masks"]
        # With return_dict=False (see load below), BertModel returns a tuple of
        # (sequence_output, pooled_output); we classify on the pooled output.
        seq, pool = self.llm(input_ids=ids, attention_mask=masks)
        z = self.dropout(pool)
        z = self.fc1(z)
        return z

    @torch.inference_mode()
    def predict(self, batch):
        """Return hard class predictions (argmax over logits) for a batch."""
        self.eval()
        z = self(batch)
        y_pred = torch.argmax(z, dim=1).cpu().numpy()
        return y_pred

    @torch.inference_mode()
    def predict_proba(self, batch):
        """Return class probabilities (softmax over logits) for a batch."""
        self.eval()
        z = self(batch)
        y_probs = F.softmax(z, dim=1).cpu().numpy()
        return y_probs

    def save(self, dp):
        """Save constructor args and model weights under directory dp."""
        with open(Path(dp, "args.json"), "w") as fp:
            contents = {
                "dropout_p": self.dropout_p,
                "embedding_dim": self.embedding_dim,
                "num_classes": self.num_classes,
            }
            json.dump(contents, fp, indent=4, sort_keys=False)
        torch.save(self.state_dict(), os.path.join(dp, "model.pt"))

    @classmethod
    def load(cls, args_fp, state_dict_fp):
        """Rebuild the model from saved constructor args and a state dict."""
        with open(args_fp, "r") as fp:
            kwargs = json.load(fp=fp)
        llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
        model = cls(llm=llm, **kwargs)
        model.load_state_dict(torch.load(state_dict_fp, map_location=torch.device("cpu")))
        return model
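
# Minimal usage sketch (not part of the original module). Assumptions: a
# Hugging Face tokenizer for "allenai/scibert_scivocab_uncased" produces the
# "ids"/"masks" tensors forward() expects, and the hyperparameters below
# (dropout_p=0.5, num_classes=4) are illustrative placeholders, not the
# project's actual training config.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
    llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
    model = FinetunedLLM(
        llm=llm,
        dropout_p=0.5,  # assumed value
        embedding_dim=llm.config.hidden_size,  # 768 for SciBERT
        num_classes=4,  # assumed value
    )

    encoded = tokenizer(["transfer learning with transformers"], return_tensors="pt", padding=True)
    batch = {"ids": encoded["input_ids"], "masks": encoded["attention_mask"]}
    print(model.predict(batch))        # hard label indices, shape (1,)
    print(model.predict_proba(batch))  # per-class probabilities, shape (1, num_classes)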