# File: dsp15/eval_newsph_small.py
# Last modified: 2026-05-21 01:21:51 +08:00
# (60 lines, 2.2 KiB, Python)

import pandas as pd
import torch
from transformers import ByT5Tokenizer, T5ForConditionalGeneration
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
# 1. Load the frozen, best-performing checkpoint from local disk.
model_path = "./byt5-taglish-nli-final-v2"
print(f"Loading model from {model_path}...")
# local_files_only=True: never reach out to the Hugging Face Hub; the
# checkpoint directory must already exist locally.
load_kwargs = {"local_files_only": True}
tokenizer = ByT5Tokenizer.from_pretrained(model_path, **load_kwargs)
model = T5ForConditionalGeneration.from_pretrained(model_path, **load_kwargs)

# Run on the GPU (e.g. the Tesla V100) when CUDA is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
# 2. Load the NewsPH-NLI evaluation data.
# NOTE(review): the prediction step reads the columns 's1', 's2' and 'label';
# update the filename / column names if your copy of the CSV differs.
dataset_csv = "newsph_nli.csv"
print("Loading NewsPH-NLI dataset...")
df_newsph = pd.read_csv(dataset_csv)

# Gold and predicted labels, filled row-by-row by the prediction step and
# consumed by the metrics step.
true_labels, pred_labels = [], []
# 3. Generate predictions and collapse them to NewsPH-NLI's binary labels.

# Exact 3-class conversational prompt the model was fine-tuned on; the text
# must stay byte-identical to the training prompt.
PROMPT_PREFIX = "Context: "
PROMPT_QUESTION = (
    " Question: Does the context entail, contradict, or remain neutral"
    " to the statement? Answer:"
)
MAX_INPUT_TOKENS = 512
MAX_NEW_TOKENS = 25
# Rows per generate() call; lower this if the device runs out of memory.
BATCH_SIZE = 32


def build_prompt(context, statement, max_tokens=MAX_INPUT_TOKENS):
    """Return the fine-tuning prompt for one (context, statement) pair.

    ByT5 tokenizes raw UTF-8 bytes (one token per byte) and appends a single
    ``</s>`` token, so the prompt must fit in ``max_tokens - 1`` bytes.
    Letting the tokenizer truncate the whole string would cut from the right
    and silently drop the statement and the question for long premises,
    making the prediction independent of the hypothesis. Instead only the
    context is shortened so the statement and question always survive.
    """
    context = str(context)
    suffix = f" Statement: {statement}{PROMPT_QUESTION}"
    fixed_bytes = len(PROMPT_PREFIX.encode("utf-8")) + len(suffix.encode("utf-8"))
    budget = max_tokens - 1 - fixed_bytes
    context_bytes = context.encode("utf-8")
    # If even the statement alone overflows (budget < 0), leave the context
    # as-is and fall back to the tokenizer's own truncation.
    if 0 <= budget < len(context_bytes):
        # errors="ignore" drops a multi-byte character split by the cut.
        context = context_bytes[:budget].decode("utf-8", errors="ignore")
    return f"{PROMPT_PREFIX}{context}{suffix}"


def collapse_to_binary(raw_prediction):
    """Map generated 3-class label text to the binary scheme.

    Returns 0 ("entailment") when the text mentions entailment, and 1
    ("not entailment") for 'neutral', 'contradiction', explicitly negated
    forms such as 'not entailment' / 'non-entailment', or any hallucinated
    output.
    """
    text = raw_prediction.strip().lower()
    negations = ("not entail", "non-entail", "non entail", "nonentail")
    if any(neg in text for neg in negations):
        return 1
    return 0 if "entailment" in text else 1


def predict_batch(prompts):
    """Greedily decode label text for a batch of prompts (stripped, lower-cased)."""
    # Right-padding plus the attention mask keeps padded rows from
    # influencing the real tokens, so batching does not change the task.
    encoded = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        max_length=MAX_INPUT_TOKENS,
        truncation=True,
    ).to(device)
    with torch.inference_mode():
        generated = model.generate(**encoded, max_new_tokens=MAX_NEW_TOKENS)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return [text.strip().lower() for text in decoded]


print("Generating predictions...")
prompts = [build_prompt(s1, s2) for s1, s2 in zip(df_newsph["s1"], df_newsph["s2"])]
# NOTE(review): assumes the CSV's 'label' column already uses
# 0 = entailment, 1 = not entailment (the mapping collapse_to_binary emits)
# -- confirm against the dataset.
true_labels.extend(df_newsph["label"].tolist())
# tqdm shows progress over batches so long runs aren't silent.
for start in tqdm(range(0, len(prompts), BATCH_SIZE)):
    batch_predictions = predict_batch(prompts[start:start + BATCH_SIZE])
    pred_labels.extend(collapse_to_binary(text) for text in batch_predictions)
# 4. Score the collapsed binary predictions against the gold labels.
acc = accuracy_score(true_labels, pred_labels)
# Macro F1 weights the entailment and not-entailment classes equally.
f1 = f1_score(true_labels, pred_labels, average="macro")

rule = "=" * 50
print(f"\n{rule}")
print("NewsPH-NLI Binary Evaluation Results")
print(rule)
print(f"Accuracy: {acc * 100:.2f}%")
print(f"F1 Macro: {f1:.4f}")