"""Evaluate a fine-tuned ByT5 NLI checkpoint on NewsPH-NLI as a binary task.

The model is prompted with its 3-class prompt (entail / contradict / neutral)
and the generated text is collapsed to two classes: 0 for entailment and 1 for
everything else. Accuracy and macro-F1 are then computed against the ``label``
column of the CSV.
"""

import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

# Binary class ids produced by the collapse step.
ENTAILMENT = 0
NOT_ENTAILMENT = 1

# Columns read from the dataset CSV (update here if your file differs).
REQUIRED_COLUMNS = ("s1", "s2", "label")


def build_prompt(context, statement):
    """Return the 3-class NLI prompt for one context/statement pair."""
    return (
        f"Context: {context} Statement: {statement} "
        "Question: Does the context entail, contradict, or remain neutral "
        "to the statement? Answer:"
    )


def collapse_to_binary(raw_prediction):
    """Map a lower-cased generated answer to a binary label.

    Returns ``ENTAILMENT`` (0) when the text contains "entailment"; anything
    else ('neutral', 'contradiction', or unexpected output) is treated as
    ``NOT_ENTAILMENT`` (1).
    """
    # NOTE(review): this is a substring match, so an output such as
    # "not entailment" would also count as entailment -- confirm the model
    # never emits that form.
    return ENTAILMENT if "entailment" in raw_prediction else NOT_ENTAILMENT


# 1. Load the fine-tuned checkpoint from local files only
model_path = "./byt5-taglish-nli-final-v3"
print(f"Loading model from {model_path}...")
tokenizer = ByT5Tokenizer.from_pretrained(model_path, local_files_only=True)
model = T5ForConditionalGeneration.from_pretrained(model_path, local_files_only=True)

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 2. Load the NewsPH-NLI dataset
# (Make sure to update the filename and column names if they differ)
print("Loading NewsPH-NLI dataset...")
df_newsph = pd.read_csv("newsph_nli.csv")

# Fail right after loading instead of on the first row of the loop.
missing_columns = [col for col in REQUIRED_COLUMNS if col not in df_newsph.columns]
if missing_columns:
    raise KeyError(f"newsph_nli.csv is missing required column(s): {missing_columns}")

true_labels = []
pred_labels = []

print("Generating predictions...")
# Iterate the three columns in lock step; tqdm needs an explicit total because
# zip() has no length.
examples = zip(df_newsph["s1"], df_newsph["s2"], df_newsph["label"])
for context, statement, label in tqdm(examples, total=len(df_newsph)):
    # Use the exact 3-class prompt the model was trained on.
    # NOTE(review): ByT5 tokenizes at the byte level, so max_length=512 caps
    # the prompt at roughly 512 bytes and truncation drops the END of the
    # prompt (the question / "Answer:" suffix) for long pairs -- confirm this
    # matches how the model was trained.
    inputs = tokenizer(
        build_prompt(context, statement),
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(device)

    # Generate the raw text prediction and normalise it for matching
    outputs = model.generate(**inputs, max_new_tokens=25)
    raw_prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()

    # 3. The binary collapse: entailment -> 0, everything else -> 1
    true_labels.append(label)
    pred_labels.append(collapse_to_binary(raw_prediction))

# 4. Calculate Final Metrics
# NOTE(review): assumes the 'label' column is already encoded as
# 0 = entailment, 1 = not entailment -- verify against the CSV.
acc = accuracy_score(true_labels, pred_labels)
f1 = f1_score(true_labels, pred_labels, average='macro')

print("\n" + "="*50)
print("NewsPH-NLI Binary Evaluation Results")
print("="*50)
print(f"Accuracy: {acc * 100:.2f}%")
print(f"F1 Macro: {f1:.4f}")