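# Evaluate the local ByT5 Taglish NLI checkpoint on NewsPH-NLI, collapsing its
# 3-class answers into a binary Entailment / Not Entailment decision.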
import re

import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

# 1. Pick the inference device, then load the frozen checkpoint onto it.
# A CUDA GPU (a Tesla V100 in the original setup) is used when visible;
# otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_path = "./byt5-taglish-nli-final-v3"
print(f"Loading model from {model_path}...")

# local_files_only=True: the checkpoint must already exist on disk; nothing
# is downloaded.
tokenizer = ByT5Tokenizer.from_pretrained(model_path, local_files_only=True)
model = T5ForConditionalGeneration.from_pretrained(model_path, local_files_only=True)
# NOTE(review): from_pretrained is expected to return the model in eval mode
# (dropout off); add an explicit model.eval() if that ever changes.
model.to(device)

# 2. Read the NewsPH-NLI evaluation file.
# The script reads the columns 's1' (context / premise), 's2' (statement /
# hypothesis) and 'label'; change the path or names if your copy differs.
# NOTE(review): gold labels are assumed to use 0 = Entailment and
# 1 = Not Entailment, matching the binary collapse - confirm for your file.
print("Loading NewsPH-NLI dataset...")
df_newsph = pd.read_csv("newsph_nli.csv")

# Gold and predicted labels, filled in by the prediction loop.
true_labels, pred_labels = [], []

# ==========================================
# 3. Prompting, batched generation and the Binary Collapse Logic
# ==========================================

# Maximum encoder length (in tokens) for every prompt.
MAX_INPUT_TOKENS = 512
# Prompts sent to generate() in one call. 1 reproduces one-example-at-a-time
# decoding; lower it if the GPU runs out of memory.
BATCH_SIZE = 16

# "not entailment" / "non-entailment" contain the substring "entailment" but
# must not be scored as Entailment.
_NEGATED_ENTAILMENT = re.compile(r"\b(?:not|non)[\s_-]*entail")
# Answers the 3-class prompt asks for. An output mentioning none of them is
# reported as unrecognized (it is still scored as Not Entailment).
_KNOWN_ANSWERS = ("entailment", "neutral", "contradiction")


def _build_prompt(context, statement, max_tokens=MAX_INPUT_TOKENS):
    """Return the 3-class NLI prompt, shortening only the context if needed.

    The tokenizer's own truncation keeps the first ``max_tokens`` tokens, so a
    long context would push the statement and the trailing "Answer:" cue out
    of the input and the model would answer without seeing the hypothesis.
    ByT5 spends one token per UTF-8 byte (plus a closing ``</s>``), so the
    context is clipped to the byte budget left after the fixed parts. When
    everything fits, the result equals the plain f-string prompt.

    If the statement alone exceeds the budget, the context becomes empty and
    the tokenizer's truncation remains the backstop.

    NOTE(review): if training inputs were right-truncated instead, very long
    examples are now formatted differently from training - confirm.
    """
    head = "Context: "
    tail = (
        f" Statement: {statement} Question: Does the context entail,"
        " contradict, or remain neutral to the statement? Answer:"
    )
    context = f"{context}"  # same str conversion the f-string prompt applies
    budget = max_tokens - 1 - len(head.encode("utf-8")) - len(tail.encode("utf-8"))
    context_bytes = context.encode("utf-8")
    if len(context_bytes) > budget:
        # errors="ignore" drops a multi-byte character split by the cut.
        context = context_bytes[: max(budget, 0)].decode("utf-8", errors="ignore")
    return f"{head}{context}{tail}"


def _collapse_to_binary(raw_prediction):
    """Map a generated answer to 0 (Entailment) or 1 (Not Entailment).

    'neutral', 'contradiction', negated entailment and hallucinated text all
    fall under Not Entailment (1).
    """
    text = raw_prediction.strip().lower()
    if "entailment" in text and not _NEGATED_ENTAILMENT.search(text):
        return 0
    return 1


print("Generating predictions...")

premises = df_newsph["s1"].tolist()
statements = df_newsph["s2"].tolist()
gold_labels = df_newsph["label"].tolist()
n_unrecognized = 0

# Progress bar advances per example, not per batch.
with tqdm(total=len(df_newsph)) as progress:
    for start in range(0, len(df_newsph), BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch_prompts = [
            _build_prompt(context, statement)
            for context, statement in zip(premises[start:stop], statements[start:stop])
        ]
        # padding=True brings the batch to a common length; the returned
        # attention_mask reaches generate() through **inputs.
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            max_length=MAX_INPUT_TOKENS,
            truncation=True,
            padding=True,
        ).to(device)

        # Generate raw text predictions for the whole batch
        outputs = model.generate(**inputs, max_new_tokens=25)
        raw_predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        for raw_prediction in raw_predictions:
            if not any(answer in raw_prediction.lower() for answer in _KNOWN_ANSWERS):
                n_unrecognized += 1
            pred_labels.append(_collapse_to_binary(raw_prediction))
        true_labels.extend(gold_labels[start:stop])
        progress.update(len(batch_prompts))

print(
    "Unrecognized model outputs (scored as Not Entailment): "
    f"{n_unrecognized}/{len(pred_labels)}"
)

# 4. Score the binary predictions against the gold labels.
acc, f1 = (
    accuracy_score(true_labels, pred_labels),
    f1_score(true_labels, pred_labels, average='macro'),
)

# Summary banner, one value per line.
print("\n" + "=" * 50, "NewsPH-NLI Binary Evaluation Results", "=" * 50, sep="\n")
print(f"Accuracy: {acc * 100:.2f}%", f"F1 Macro: {f1:.4f}", sep="\n")