Source code for roksana.evaluation.evaluator

# roksana/evaluation/evaluator.py

from typing import List, Tuple, Dict, Any
import csv
import json
import os
import torch
from .metrics import hit_at_k, recall_at_k, demotion_value

class Evaluator:
    """
    Evaluator class to assess the impact of attack methods on search strategies.

    For every query node, the evaluator does the following:

    - Runs the "before attack" and "after attack" search methods.
    - Computes Hit@k and Recall@k against the query's gold set for every
      configured k.
    - Computes a demotion value from the rank of the query node itself in
      both result lists.
    - Writes one CSV row per (query, k) pair, followed by an averages row.
    """

    def __init__(self, search_method_before, search_method_after, k_values: List[int] = [5, 10, 20]):
        """
        Initialize the Evaluator.

        Args:
            search_method_before: Instance of SearchMethod before attack.
            search_method_after: Instance of SearchMethod after attack.
            k_values (List[int], optional): List of k values for Hit@k and Recall@k.
                Defaults to [5, 10, 20]. The values are copied, so neither the
                default list nor a caller's list is shared with this instance.
        """
        self.search_before = search_method_before
        self.search_after = search_method_after
        # Copy: the default list object is created once at definition time and
        # would otherwise be shared (and mutable) across all instances.
        self.k_values = list(k_values)
        # Result dicts from every evaluate() call on this instance.
        self.results: List[Dict[str, Any]] = []

    @staticmethod
    def _rank_of(retrieved, node) -> int:
        """Return the 1-based position of ``node`` in ``retrieved``.

        If ``node`` is absent, return ``len(retrieved) + 1``.
        """
        try:
            return retrieved.index(node) + 1
        except ValueError:
            return len(retrieved) + 1

    @staticmethod
    def _retrieve(search_method, query, top_k: int):
        """Search with the features of node ``query``; return the first result entry.

        The features are read from ``search_method.data.x``. The assumption that
        ``search`` returns a batch (one list per query) is taken from the original
        "Extract the first list" comment; TODO confirm against SearchMethod.
        """
        retrieved_list = search_method.search(
            query_features=search_method.data.x[query],
            top_k=top_k,
        )
        return retrieved_list[0]

    def evaluate(
        self,
        queries: List[int],
        gold_sets: List[List[int]],
        results_dir: str = 'results',
        filename: str = 'evaluation_results.csv'
    ) -> None:
        """
        Perform evaluation on the given queries and save the results.

        The CSV file is overwritten. It contains one row per (query, k) pair,
        then an "Average" row computed over the rows written by THIS call. The
        averages row is omitted when no rows were produced. Result dicts are
        also appended to ``self.results``, which accumulates across calls.

        Args:
            queries (List[int]): List of query node indices.
            gold_sets (List[List[int]]): List of gold sets corresponding to each query.
            results_dir (str, optional): Directory to save the results file.
                Defaults to 'results'. An empty string means the current directory.
            filename (str, optional): Name of the results file.
                Defaults to 'evaluation_results.csv'.

        Raises:
            ValueError: If ``queries`` and ``gold_sets`` differ in length, or if
                ``k_values`` is empty.
        """
        if len(queries) != len(gold_sets):
            # zip() would otherwise silently drop unmatched queries/gold sets.
            raise ValueError(
                f"queries and gold_sets must have the same length "
                f"(got {len(queries)} and {len(gold_sets)})"
            )
        if not self.k_values:
            raise ValueError("k_values must contain at least one k")

        # exist_ok avoids a check-then-create race. os.makedirs('') would raise,
        # so an empty results_dir is treated as the current directory.
        if results_dir:
            os.makedirs(results_dir, exist_ok=True)
        filepath = os.path.join(results_dir, filename)

        # Column names double as the keys of each result dict.
        header = [
            'query_node', 'k',
            'Hit@k_before_attack', 'Hit@k_after_attack',
            'Recall@k_before_attack', 'Recall@k_after_attack',
            'Demotion_value'
        ]
        # A single search of depth max(k) serves every k.
        top_k = max(self.k_values)
        run_results: List[Dict[str, Any]] = []

        with open(filepath, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(header)

            for query, gold_set in zip(queries, gold_sets):
                retrieved_before = self._retrieve(self.search_before, query, top_k)
                retrieved_after = self._retrieve(self.search_after, query, top_k)

                # Rank of the query node itself in each list (absent -> len + 1).
                # A positive value means it moved further down after the attack.
                demotion = (
                    self._rank_of(retrieved_after, query)
                    - self._rank_of(retrieved_before, query)
                )

                for k in self.k_values:
                    result = {
                        'query_node': query,
                        'k': k,
                        'Hit@k_before_attack': hit_at_k(retrieved_before, gold_set, k),
                        'Hit@k_after_attack': hit_at_k(retrieved_after, gold_set, k),
                        'Recall@k_before_attack': recall_at_k(retrieved_before, gold_set, k),
                        'Recall@k_after_attack': recall_at_k(retrieved_after, gold_set, k),
                        'Demotion_value': demotion
                    }
                    self.results.append(result)
                    run_results.append(result)
                    writer.writerow([result[column] for column in header])

            # Average only over rows in this file (not earlier runs), and skip
            # the row entirely when there is nothing to average.
            if run_results:
                num_results = len(run_results)
                averages = [
                    sum(r[column] for r in run_results) / num_results
                    for column in header[2:]
                ]
                writer.writerow(['Average', ''] + averages)

        print(f"Evaluation results saved to {filepath}")

    def get_all_results(self) -> List[Dict[str, Any]]:
        """
        Retrieve all evaluation results.

        Returns:
            List[Dict[str, Any]]: The instance's internal list of result dicts,
            accumulated across every evaluate() call (not a copy).
        """
        return self.results