Made-With-ML/madewithml/models.py

import torch
import torch.nn as nn

class FinetunedLLM(nn.Module):  # pragma: no cover, torch model
    """Model architecture for a Large Language Model (LLM) that we will fine-tune."""

    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super(FinetunedLLM, self).__init__()
        self.llm = llm  # pretrained transformer backbone
        self.dropout = torch.nn.Dropout(dropout_p)
        self.fc1 = torch.nn.Linear(embedding_dim, num_classes)  # classification head

    def forward(self, batch):
        ids, masks = batch["ids"], batch["masks"]
        # The backbone is expected to return a (sequence_output, pooled_output)
        # tuple (e.g. a Hugging Face model loaded with return_dict=False).
        seq, pool = self.llm(input_ids=ids, attention_mask=masks)
        z = self.dropout(pool)  # regularize the pooled representation
        z = self.fc1(z)
        return z  # logits of shape (batch_size, num_classes)
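

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch. It assumes a Hugging Face BertModel backbone
    # loaded with return_dict=False, so that calling it yields the
    # (sequence_output, pooled_output) tuple that forward() unpacks. The
    # checkpoint name, dropout_p, and num_classes below are illustrative
    # choices, not values prescribed by this module.
    from transformers import BertModel, BertTokenizer

    llm = BertModel.from_pretrained("bert-base-uncased", return_dict=False)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = FinetunedLLM(
        llm=llm,
        dropout_p=0.5,
        embedding_dim=llm.config.hidden_size,
        num_classes=4,
    )

    encoded = tokenizer(["fine-tuning a transformer"], padding=True, return_tensors="pt")
    batch = {"ids": encoded["input_ids"], "masks": encoded["attention_mask"]}
    logits = model(batch)  # shape: (batch_size, num_classes)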