# Source code for roksana.search_methods.registry

# roksana/search_methods/registry.py

from typing import Any, Callable, Dict, Optional, Type

from .base_search import SearchMethod
from .gcn_search import GCNSearch
from .gat_search import GATSearch
from .sage_search import SAGESearch

# Registry mapping a lowercase method name to the SearchMethod subclass that
# implements it. Both register_search_method() and get_search_method()
# lowercase the name before touching this dict, so keys must stay lowercase.
SEARCH_METHODS: Dict[str, Type[SearchMethod]] = {
    'gcn': GCNSearch,
    'gat': GATSearch,
    'sage': SAGESearch,
    # Add more predefined search methods here
}

def register_search_method(
    name: str,
) -> Callable[[Type[SearchMethod]], Type[SearchMethod]]:
    """
    Decorator to register a search method class with a given name.

    The name is lowercased before being stored, so registration is
    case-insensitive. Registering under a name that already exists
    replaces the previous entry in ``SEARCH_METHODS``.

    Args:
        name (str): The name to register the search method under.

    Returns:
        Callable: The decorator function. It returns the decorated class
        unchanged, so the class stays usable under its own name.

    Raises:
        ValueError: Raised by the returned decorator if the decorated
            object is not a subclass of ``SearchMethod``.
    """
    # Normalise once, up front, rather than every time the decorator runs.
    key = name.lower()

    def decorator(cls: Type[SearchMethod]) -> Type[SearchMethod]:
        # issubclass() raises TypeError when its first argument is not a
        # class; check isinstance first so every invalid target produces the
        # same documented ValueError.
        if not (isinstance(cls, type) and issubclass(cls, SearchMethod)):
            raise ValueError("Can only register subclasses of SearchMethod")
        SEARCH_METHODS[key] = cls
        return cls

    return decorator
def get_search_method(
    name: str, data: Any, device: Optional[str] = None, **kwargs
) -> SearchMethod:
    """
    Retrieve an instance of the specified search method.

    The name is lowercased before the lookup, matching how
    ``register_search_method`` stores its keys.

    Args:
        name (str): Name of the search method (e.g., 'gcn', 'gat', 'sage').
        data (Any): The graph dataset, passed to the search method class as
            its first positional argument.
        device (str, optional): Device to run the computations on
            (e.g., 'cpu' or 'cuda'). Forwarded unchanged as the ``device``
            keyword argument; defaults to None.
        **kwargs: Additional keyword arguments forwarded to the search
            method class.

    Returns:
        SearchMethod: An instance of the requested search method.

    Raises:
        ValueError: If the specified search method is not registered.
    """
    name = name.lower()
    try:
        search_method_class = SEARCH_METHODS[name]
    except KeyError:
        # Suppress the KeyError context: the ValueError already names the
        # missing method and lists every registered alternative.
        raise ValueError(
            f"Search method '{name}' not found. "
            f"Available methods: {list(SEARCH_METHODS)}"
        ) from None
    return search_method_class(data, device=device, **kwargs)