import torch
import torch.nn as nn


class FinetunedLLM(nn.Module):  # pragma: no cover, torch model
    """Model architecture for a Large Language Model (LLM) that we will fine-tune."""

    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super(FinetunedLLM, self).__init__()
        self.llm = llm  # pretrained transformer encoder backbone
        self.dropout = torch.nn.Dropout(dropout_p)
        self.fc1 = torch.nn.Linear(embedding_dim, num_classes)  # maps the pooled embedding to class logits

    def forward(self, batch):
        ids, masks = batch["ids"], batch["masks"]
        # Assumes the encoder returns a (sequence_output, pooled_output) tuple,
        # e.g., a Hugging Face encoder loaded with return_dict=False.
        seq, pool = self.llm(input_ids=ids, attention_mask=masks)
        z = self.dropout(pool)
        z = self.fc1(z)  # logits, shape (batch_size, num_classes)
        return z
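

# --- Usage sketch (not part of the original file) ---
# A minimal example of wiring the model to a Hugging Face encoder. The
# checkpoint name, dropout_p, and num_classes below are illustrative
# assumptions, not values from the source; `return_dict=False` makes the
# encoder return the (sequence_output, pooled_output) tuple that forward()
# unpacks above.
if __name__ == "__main__":  # pragma: no cover
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    llm = BertModel.from_pretrained("bert-base-uncased", return_dict=False)
    model = FinetunedLLM(
        llm=llm,
        dropout_p=0.5,
        embedding_dim=llm.config.hidden_size,  # 768 for bert-base-uncased
        num_classes=4,
    )

    encoded = tokenizer(["transfer learning with transformers"], return_tensors="pt", padding=True)
    batch = {"ids": encoded["input_ids"], "masks": encoded["attention_mask"]}
    with torch.no_grad():
        logits = model(batch)  # shape: (1, 4)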