"""Print raw predictions from the fine-tuned ByT5 Taglish NLI checkpoint.

Loads the local model and tokenizer, runs ``model.generate`` on the first
``NUM_SAMPLES`` rows of ``benchmark_dataset.csv`` (columns ``s1``, ``s2``,
``label``), and prints each prompt next to the expected label and the
decoded model output. An exact-match count is printed at the end.
"""
import pandas as pd
import torch
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

# 1. Load the fine-tuned model and tokenizer.
model_path = "./byt5-taglish-nli-final-v3"
print(f"Loading model from {model_path}...")

# local_files_only=True loads straight from disk and never contacts the
# Hugging Face Hub, which sidesteps the cluster's DNS issues.
tokenizer = ByT5Tokenizer.from_pretrained(model_path, local_files_only=True)
model = T5ForConditionalGeneration.from_pretrained(model_path, local_files_only=True)

# Use a GPU when one is visible; inputs are moved to the same device below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # make inference mode explicit (disables dropout)

# 2. Load the first NUM_SAMPLES rows of the benchmark dataset.
NUM_SAMPLES = 50
print("Loading benchmark dataset...")
df_test = pd.read_csv("benchmark_dataset.csv")
sample_df = df_test.head(NUM_SAMPLES)

# 3. Generate and decode predictions.
print("\n" + "=" * 50)
print("RAW MODEL PREDICTIONS")
print("=" * 50)

# Maps the CSV's integer `label` values to the label strings printed as
# "Expected". NOTE(review): assumed to match the target strings used in
# training; confirm against the training script.
label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

num_exact = 0
for sample_num, (_, row) in enumerate(sample_df.iterrows(), start=1):
    # Recreate the exact prompt format used during training.
    # Alternative prompt format, kept for reference:
    # input_text = f"nli premise: {row['s1']} hypothesis: {row['s2']}"
    input_text = (
        f"Context: {row['s1']} Statement: {row['s2']} "
        "Question: Does the context entail, contradict, or remain neutral "
        "to the statement? Answer:"
    )

    # ByT5 works on UTF-8 bytes, so max_length is a byte budget. Truncation
    # drops from the end by default, which would cut the trailing
    # "Question ... Answer:" cue on very long inputs.
    # NOTE(review): confirm training truncated the same way.
    inputs = tokenizer(
        input_text, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    # The longest label, "contradiction", is 13 byte tokens. An explicit
    # limit of 25 new tokens leaves headroom for EOS/stray output instead of
    # relying on the default generation length.
    outputs = model.generate(**inputs, max_new_tokens=25)

    # Decode the generated byte ids back into text, dropping special tokens.
    raw_prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    true_label = label_map.get(row["label"], "unknown")
    if raw_prediction.strip().lower() == true_label:
        num_exact += 1

    print(f"\nSample {sample_num}:")
    print(f"Input: {input_text}")
    print(f"Expected: '{true_label}'")
    print(f"Model Says: '{raw_prediction}'")

print(f"\nExact matches: {num_exact}/{len(sample_df)}")