Files
dsp15/inspect_preds.py
T
2026-05-21 01:21:51 +08:00

47 lines
1.8 KiB
Python

import pandas as pd
import torch
from transformers import ByT5Tokenizer, T5ForConditionalGeneration
# Print raw ByT5 NLI predictions next to the gold labels for the first rows
# of the benchmark CSV, so the model's free-text output can be eyeballed.

# --- Configuration -----------------------------------------------------------
MODEL_PATH = "./byt5-taglish-nli-final-v3"
BENCHMARK_CSV = "benchmark_dataset.csv"
NUM_SAMPLES = 50  # rows taken from the top of the benchmark for inspection
# ByT5 tokenizes raw bytes, so this caps the prompt length in bytes.
MAX_INPUT_TOKENS = 512
# "contradiction" is 13 bytes -> ~13 byte tokens (+ EOS); 25 leaves headroom
# and avoids the generation-length warning.
MAX_NEW_TOKENS = 25
LABEL_MAP = {0: "entailment", 1: "neutral", 2: "contradiction"}


def build_prompt(premise, hypothesis) -> str:
    """Return the model prompt for one premise/hypothesis pair.

    This must reproduce the exact prompt format used during fine-tuning;
    any change to the wording changes what the model sees. (An earlier
    format, ``"nli premise: ... hypothesis: ..."``, is no longer used.)
    """
    return (
        f"Context: {premise} Statement: {hypothesis} "
        "Question: Does the context entail, contradict, or remain neutral "
        "to the statement? Answer:"
    )


def label_name(label) -> str:
    """Map a benchmark label to its NLI class name.

    Accepts integer ids 0/1/2 (including numpy integers, and floats such as
    ``1.0`` that pandas produces for a numeric column containing missing
    values), digit strings such as ``"1"``, or the class name itself.
    Anything else (unknown ids, NaN, other text) maps to ``"unknown"``.
    """
    if isinstance(label, str):
        text = label.strip().lower()
        if text in LABEL_MAP.values():
            return text
        if not text.isdigit():
            return "unknown"
        label = int(text)
    return LABEL_MAP.get(label, "unknown")


def normalize_prediction(text: str) -> str:
    """Strip whitespace and lower-case a decoded prediction for comparison."""
    return text.strip().lower()


def main() -> None:
    """Load the model and benchmark, then print prediction vs. gold per row."""
    # 1. Load the fine-tuned model and tokenizer.
    print(f"Loading model from {MODEL_PATH}...")
    # local_files_only=True skips the Hugging Face Hub lookup, which fails
    # on the cluster because of DNS issues.
    tokenizer = ByT5Tokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
    model = T5ForConditionalGeneration.from_pretrained(
        MODEL_PATH, local_files_only=True
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # inference only: make sure dropout is disabled

    # 2. Load the first NUM_SAMPLES rows of the benchmark dataset.
    print("Loading benchmark dataset...")
    df_test = pd.read_csv(BENCHMARK_CSV)
    sample_df = df_test.head(NUM_SAMPLES)

    # 3. Generate and decode predictions.
    print("\n" + "=" * 50)
    print("RAW MODEL PREDICTIONS")
    print("=" * 50)

    n_exact = 0
    for sample_no, (_, row) in enumerate(sample_df.iterrows(), start=1):
        input_text = build_prompt(row["s1"], row["s2"])
        # NOTE(review): truncation drops bytes from the END of the prompt, so
        # a premise longer than ~MAX_INPUT_TOKENS bytes also cuts off the
        # "Question ... Answer:" suffix. Confirm this matches training.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=MAX_INPUT_TOKENS,
            truncation=True,
        ).to(device)
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
        # Decode the generated byte tokens back into readable text.
        raw_prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
        true_label = label_name(row["label"])
        if normalize_prediction(raw_prediction) == true_label:
            n_exact += 1

        print(f"\nSample {sample_no}:")
        print(f"Input: {input_text}")
        print(f"Expected: '{true_label}'")
        print(f"Model Says: '{raw_prediction}'")

    print("\n" + "=" * 50)
    print(f"Exact label matches: {n_exact}/{len(sample_df)}")


if __name__ == "__main__":
    main()