first commit
This commit is contained in:
@@ -0,0 +1,46 @@
|
||||
import pandas as pd
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
# 1. Load the fine-tuned ByT5 checkpoint together with its tokenizer.
model_path = "./byt5-taglish-nli-final-v3"
print(f"Loading model from {model_path}...")

# Both loaders are restricted to files already on disk via
# local_files_only=True; the original note says this works around DNS
# problems on the cluster.
tokenizer = ByT5Tokenizer.from_pretrained(
    model_path,
    local_files_only=True,
)
model = T5ForConditionalGeneration.from_pretrained(
    model_path,
    local_files_only=True,
)
|
||||
|
||||
# 2. Load the benchmark dataset and keep only its first NUM_SAMPLES rows.
# The sample size is a named constant so this comment and the code cannot
# drift apart (the previous comment said 10 while the code took 50).
NUM_SAMPLES = 50

print("Loading benchmark dataset...")
df_test = pd.read_csv("benchmark_dataset.csv")
sample_df = df_test.head(NUM_SAMPLES)
|
||||
|
||||
# 3. Run the model on every sampled row and print its raw output next to
#    the expected label.
banner = "=" * 50
print("\n" + banner)
print("RAW MODEL PREDICTIONS")
print(banner)

# Maps the integer values of the dataset's 'label' column to label names.
label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

for index, row in sample_df.iterrows():
    # The prompt is meant to reproduce the format used during training
    # (per the original author's note), so its wording must not change.
    input_text = f"Context: {row['s1']} Statement: {row['s2']} Question: Does the context entail, contradict, or remain neutral to the statement? Answer:"

    # Encode the prompt as PyTorch tensors, truncating at 512 tokens.
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )

    # Cap the output length explicitly; the original note says this
    # silences a generation warning. The longest label, "contradiction",
    # is 13 characters, so 25 new tokens leaves headroom.
    generated_ids = model.generate(**encoded, max_new_tokens=25)

    # Only one prompt was submitted, so decode the first (and only) sequence.
    raw_prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Labels missing from label_map are reported as "unknown".
    true_label = label_map.get(row['label'], "unknown")

    print(f"\nSample {index + 1}:")
    print(f"Input: {input_text}")
    print(f"Expected: '{true_label}'")
    print(f"Model Says: '{raw_prediction}'")
|
||||
Reference in New Issue
Block a user