import torch
from collections import Counter
from collections.abc import Sequence
[docs]
def compare_original_vs_updated(original_data, updated_data):
original_num_nodes = original_data.num_nodes
updated_num_nodes = updated_data.num_nodes
original_num_edges = original_data.edge_index.size(1)
updated_num_edges = updated_data.edge_index.size(1)
print("____________________________________________")
print("- Compare")
print("____________________________________________")
print("Number of nodes (Original):", original_num_nodes)
print("Number of nodes (Updated):", updated_num_nodes)
print("Number of edges (Original):", original_num_edges)
print("Number of edges (Updated):", updated_num_edges)
print("\nDifference in the number of edges:", original_num_edges - updated_num_edges)
print("____________________________________________")
# Find removed edges
original_edges = original_data.edge_index.t().tolist()
updated_edges = updated_data.edge_index.t().tolist()
# Convert to sets for comparison
original_edges_set = set(map(tuple, original_edges))
updated_edges_set = set(map(tuple, updated_edges))
removed_edges = original_edges_set - updated_edges_set
#if removed_edges:
# print("\nRemoved edges:")
# for edge in removed_edges:
# print(edge)
#else:
# print("\nNo edges have been removed.")
print("____________________________________________")
[docs]
def remove_edges(data, edges_to_remove, inplace=False):
"""
Remove specified edges from the given undirected graph data object.
This function modifies the `data.edge_index` attribute by removing the
specified edges. It can handle either a single list of edges or a list
of lists of edges. Since the graph is undirected, (u, v) and (v, u)
are considered the same edge.
Args:
data: A PyG Data object with an `edge_index` attribute.
edges_to_remove (List[Tuple[int, int]] or List[List[Tuple[int, int]]]):
A collection of edges to remove. For example:
- [(u1, v1), (u2, v2), ...]
- [[(u1, v1), (u2, v2)], [(u3, v3), ...]]
inplace (bool, optional): If True, modifies the input data object in-place. Default is False.
Returns:
The modified data object with the specified edges removed.
"""
if not inplace:
data = data.clone()
# If edges_to_remove is a list of lists, flatten it
if any(isinstance(item, list) for item in edges_to_remove):
edges_to_remove = [edge for sublist in edges_to_remove for edge in sublist]
# Normalize edges so that (u, v) and (v, u) represent the same edge
edges_to_remove = [tuple(sorted(edge)) for edge in edges_to_remove]
# Remove duplicates by converting to a set
edges_to_remove = set(edges_to_remove)
# Convert data.edge_index into a list of edges
edge_list = [(int(u.item()), int(v.item())) for u, v in zip(data.edge_index[0], data.edge_index[1])]
# Filter the edges: keep only those not in edges_to_remove
# Since graph is undirected, we sort each edge before checking
filtered_edges = [e for e in edge_list if tuple(sorted(e)) not in edges_to_remove]
# Convert filtered edges back to a tensor
if len(filtered_edges) > 0:
filtered_edges_tensor = torch.tensor(filtered_edges, dtype=torch.long, device=data.edge_index.device).t()
else:
# If no edges remain, create an empty edge_index
filtered_edges_tensor = torch.empty((2, 0), dtype=torch.long, device=data.edge_index.device)
data.edge_index = filtered_edges_tensor
return data
[docs]
def removed_edges_list_stat(data, removed_edges_list, verbose=True):
"""
Calculate and report statistics about a list of removed edges, including
checking if the reverse direction of these edges exists in the main graph.
This function takes a list of edge lists that represent edges removed during
multiple perturbation operations and aggregates them to determine:
- The total number of removed edges across all operations.
- The number of duplicate edges (edges that have appeared more than once across all operations).
- The number of unique edges that have been removed overall.
- The number of removed edges including their reversed counterpart present
in the main graph (`data.edge_index`).
If `verbose` is True, it prints these statistics and does not return anything.
If `verbose` is Flase, it returns the statistics as a tuple.
Args:
removed_edges_list (List[List[Tuple[int, int]]]): A list of lists,
where each inner list contains tuples representing edges that
were removed in a particular operation.
data: A PyG Data object that contains the main graph edges in `data.edge_index`.
verbose (bool, optional): If True, prints out the statistics. Defaults to True.
Returns:
Tuple[int, int, int, int]:
A tuple containing:
- int: The total number of removed edges.
- int: The number of duplicate edges across all operations.
- int: The number of unique edges removed overall.
- int: The number of removed edges including their reversed counterpart
in the main graph.
"""
# Flatten the list if it is a list of lists
if any(isinstance(item, list) for item in removed_edges_list):
removed_edges_list = [edge for sublist in removed_edges_list for edge in sublist]
# Normalize edges so that (u, v) and (v, u) are treated the same
# for counting unique edges and duplicates.
# This ensures (u, v) and (v, u) become the same tuple (min(u,v), max(u,v)).
removed_edges_list = [tuple(sorted(edge)) for edge in removed_edges_list]
removed_edges_list_num_edges = len(removed_edges_list)
removed_edged_unique = set(removed_edges_list)
removed_edges_unique_num_edges = len(removed_edged_unique)
duplicates = removed_edges_list_num_edges - removed_edges_unique_num_edges
# Create a set of edges from the main graph
# For checking reverse counterparts, we do NOT normalize here, because
# we want to distinguish (u, v) from (v, u).
graph_edges = {(int(u.item()), int(v.item())) for u, v in zip(data.edge_index[0], data.edge_index[1])}
# Count how many removed edges have a reversed counterpart in the main graph.
# We'll use the removed_edged_unique here because we want to check
# the actual directionality of how they were removed.
reversed_counterparts_count = 0
for (u, v) in removed_edged_unique:
# The reversed edge is (v, u)
if (v, u) in graph_edges:
reversed_counterparts_count += 1
if verbose:
print(f'Number of Removed Edges in the list: {removed_edges_list_num_edges}')
print(f'Number of Duplicate Edges: {duplicates}')
print(f'Number of Unique Edges to Remove: {removed_edges_unique_num_edges}')
print(f'Number of Removed Edges with Reversed Counterparts in Graph: {reversed_counterparts_count}')
print(f'Number of Total Removed Edges: {removed_edges_unique_num_edges + reversed_counterparts_count}')
else:
return removed_edges_list_num_edges, duplicates, removed_edges_unique_num_edges, reversed_counterparts_count, removed_edges_unique_num_edges + reversed_counterparts_count