# Source code for roksana.attack_methods.random

"""
RandomAttack Module
--------------------
This module implements the RandomAttack class for adversarial attacks on graphs by randomly removing edges 
connected to selected nodes.

Classes:
    - RandomAttack: A class to perform random edge removal attacks on a graph dataset.
"""

from .base_attack import BaseAttack
import torch
from torch_geometric.utils import remove_self_loops, to_undirected
import random
from typing import Any, Tuple, List


class RandomAttack(BaseAttack):
    """
    RandomAttack Class
    -------------------
    Implements an adversarial attack that randomly removes edges connected to
    specified nodes in a graph.

    Attributes:
        data (Any): The graph dataset passed at construction time.
        params (dict): All keyword arguments passed at construction time.
        device (torch.device): The ``device`` keyword argument if given,
            otherwise CUDA when available and CPU otherwise. ``attack`` does
            not use it; results stay on the input graph's device.
    """

    def __init__(self, data: Any, **kwargs):
        """
        Initialize the RandomAttack method.

        Args:
            data (Any): The graph dataset.
            **kwargs: Additional parameters for the attack. ``device`` is read
                explicitly; every keyword argument is stored in ``self.params``.
        """
        self.data = data
        self.params = kwargs
        self.device = kwargs.get('device', torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    def attack(self, data: Any, selected_nodes: torch.Tensor) -> Tuple[Any, List[Tuple[int, int]]]:
        """
        Perform the random edge removal attack.

        For each entry of ``selected_nodes`` (in order, duplicates included),
        one edge that is still present and has that node as its source or
        target is picked with :func:`random.choice` and removed. Nodes with no
        remaining incident edge are skipped. Afterwards all self-loops are
        stripped from the result, including self-loops already present in
        ``data``.

        Args:
            data (Any): The graph dataset. Must expose ``edge_index`` and
                ``clone()``. If it has a non-None ``edge_attr``, that tensor is
                filtered together with ``edge_index`` along its first dimension.
            selected_nodes (torch.Tensor): Node(s) whose edges are targeted.
                May be a 0-d or 1-d tensor, a Python int, or a sequence of ints.

        Returns:
            Tuple[Any, List[Tuple[int, int]]]:
                - A clone of ``data`` with the updated ``edge_index`` (and
                  ``edge_attr`` when present).
                - The removed edges as ``(src, dst)`` tuples, in removal order.
        """
        # Flatten to a plain list of ints so ints, sequences and 0-d/1-d
        # tensors are all accepted.
        nodes = torch.as_tensor(selected_nodes).reshape(-1).tolist()

        edge_index = data.edge_index
        edge_pairs = edge_index.t().tolist()
        # keep[i] becomes False once column i of edge_index has been removed.
        keep = [True] * len(edge_pairs)
        removed_edges: List[Tuple[int, int]] = []

        for node in nodes:
            # Candidates are listed in column order with one random.choice per
            # node, so a fixed random seed yields a reproducible selection.
            candidates = [
                i for i, (src, dst) in enumerate(edge_pairs)
                if keep[i] and (src == node or dst == node)
            ]
            if not candidates:
                continue
            chosen = random.choice(candidates)
            keep[chosen] = False
            src, dst = edge_pairs[chosen]
            removed_edges.append((src, dst))

        # NOTE(review): if the graph stores undirected edges as both (u, v) and
        # (v, u) columns, only the chosen direction is removed — confirm intended.

        # Masking the original tensor (instead of rebuilding it from a Python
        # list) keeps its device and dtype, and gives a (2, 0) tensor when
        # every edge has been removed.
        keep_mask = torch.tensor(keep, dtype=torch.bool, device=edge_index.device)
        new_edge_index = edge_index[:, keep_mask]

        # Keep per-edge attributes aligned with the surviving edge columns.
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None:
            edge_attr = edge_attr[keep_mask]

        # Strip self-loops from the whole result (edge_attr filtered alongside).
        new_edge_index, edge_attr = remove_self_loops(new_edge_index, edge_attr)

        updated_data = data.clone()
        updated_data.edge_index = new_edge_index
        if edge_attr is not None:
            updated_data.edge_attr = edge_attr
        return updated_data, removed_edges