first commit
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
import pandas as pd
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, T5ForConditionalGeneration
|
||||
from sklearn.metrics import accuracy_score, f1_score
|
||||
from tqdm import tqdm
|
||||
|
||||
# 1. Load your frozen, peak model
|
||||
model_path = "./byt5-taglish-nli-final-v3"
|
||||
print(f"Loading model from {model_path}...")
|
||||
|
||||
tokenizer = ByT5Tokenizer.from_pretrained(model_path, local_files_only=True)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_path, local_files_only=True)
|
||||
|
||||
# Push to your Tesla V100 GPU for fast inference
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
|
||||
# 2. Load the NewsPH-NLI dataset
|
||||
# (Make sure to update the filename and column names if they differ)
|
||||
print("Loading NewsPH-NLI dataset...")
|
||||
df_newsph = pd.read_csv("newsph_nli.csv")
|
||||
|
||||
true_labels = []
|
||||
pred_labels = []
|
||||
|
||||
print("Generating predictions...")
|
||||
# tqdm gives you a nice progress bar so you aren't staring at a blank screen
|
||||
for index, row in tqdm(df_newsph.iterrows(), total=len(df_newsph)):
|
||||
|
||||
# Use the exact 3-class conversational prompt your model trained on
|
||||
input_text = f"Context: {row['s1']} Statement: {row['s2']} Question: Does the context entail, contradict, or remain neutral to the statement? Answer:"
|
||||
|
||||
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)
|
||||
|
||||
# Generate raw text prediction
|
||||
outputs = model.generate(**inputs, max_new_tokens=25)
|
||||
raw_prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
|
||||
|
||||
# ==========================================
|
||||
# 3. The Binary Collapse Logic
|
||||
# ==========================================
|
||||
if "entailment" in raw_prediction:
|
||||
binary_pred = 0
|
||||
else:
|
||||
# If it says 'neutral', 'contradiction', or hallucinates, it falls under "Not Entailment" (1)
|
||||
binary_pred = 1
|
||||
|
||||
true_labels.append(row['label'])
|
||||
pred_labels.append(binary_pred)
|
||||
|
||||
# 4. Calculate Final Metrics
|
||||
acc = accuracy_score(true_labels, pred_labels)
|
||||
f1 = f1_score(true_labels, pred_labels, average='macro')
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("NewsPH-NLI Binary Evaluation Results")
|
||||
print("="*50)
|
||||
print(f"Accuracy: {acc * 100:.2f}%")
|
||||
print(f"F1 Macro: {f1:.4f}")
|
||||
Reference in New Issue
Block a user