roksana.search_methods package

Submodules

roksana.search_methods.registry module

roksana.search_methods.registry.get_search_method(name: str, data: Any, device: str = None, **kwargs) → SearchMethod

Retrieve an instance of the specified search method.

Parameters:
  • name (str) – Name of the search method (e.g., ‘gcn’, ‘gat’, ‘sage’).

  • data (Any) – The graph dataset.

  • device (str, optional) – Device to run the computations on (‘cpu’ or ‘cuda’).

  • **kwargs – Additional keyword arguments for the search method.

Returns:

An instance of the requested search method.

Return type:

SearchMethod

Raises:

ValueError – If the specified search method is not registered.

roksana.search_methods.registry.register_search_method(name: str)

Decorator to register a search method class with a given name.

Parameters:

name (str) – The name to register the search method under.

Returns:

The decorator function.

Return type:

Callable
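The decorator and lookup described above follow a common registry pattern. The sketch below is a minimal stand-in written for illustration, not the package's actual implementation: the `_REGISTRY` dictionary, the `DummySearch` class, and the error message text are assumptions.

```python
from typing import Any, Callable, Dict, Type

# Hypothetical module-level registry; the real package may store this differently.
_REGISTRY: Dict[str, Type] = {}


def register_search_method(name: str) -> Callable[[Type], Type]:
    """Return a class decorator that records the class under `name`."""
    def decorator(cls: Type) -> Type:
        _REGISTRY[name] = cls
        return cls  # hand the class back unchanged so it stays usable
    return decorator


def get_search_method(name: str, data: Any, device: str = None, **kwargs) -> Any:
    """Look up `name` and instantiate the registered class."""
    if name not in _REGISTRY:
        raise ValueError(f"Search method '{name}' is not registered.")
    return _REGISTRY[name](data, device=device, **kwargs)


@register_search_method("dummy")  # "dummy" is an illustrative name only
class DummySearch:
    def __init__(self, data: Any, device: str = None, **kwargs):
        self.data = data
        self.device = device
        self.options = kwargs  # extra keyword arguments are forwarded as-is


method = get_search_method("dummy", data=[1, 2, 3], device="cpu", top=5)
```

Because the decorator returns the class unchanged, a registered class can still be imported and instantiated directly, and the name lookup is only an additional entry point.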

roksana.search_methods.search_methods module

class roksana.search_methods.search_methods.SearchMethod(data: Data, device: str | None = None)

Bases: ABC

Abstract base class for search methods.

abstract __init__(data: Data, device: str | None = None)

Initialize the search method with the given dataset.

Parameters:
  • data (Data) – The graph dataset.

  • device (str, optional) – Device to run the computations on (‘cpu’ or ‘cuda’).

abstract search(query_features: Tensor, top_k: int = 10) → List[int]

Perform a search with the given query features.

Parameters:
  • query_features (torch.Tensor) – Feature vector of the query node.

  • top_k (int, optional) – Number of most similar nodes to retrieve. Defaults to 10.

Returns:

List of node indices sorted by similarity to the query.

Return type:

List[int]
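To illustrate the contract of the abstract interface, the sketch below defines a local stand-in for `SearchMethod` and a hypothetical subclass, `CosineSearch`, that ranks nodes by cosine similarity of their raw features. It uses plain Python lists instead of `torch.Tensor` so that it runs without dependencies; both the subclass and the choice of cosine similarity are assumptions made for this example.

```python
import math
from abc import ABC, abstractmethod
from typing import List, Sequence


class SearchMethod(ABC):
    """Local stand-in mirroring the documented interface."""

    @abstractmethod
    def __init__(self, data, device=None):
        ...

    @abstractmethod
    def search(self, query_features, top_k: int = 10) -> List[int]:
        ...


class CosineSearch(SearchMethod):
    """Hypothetical subclass: ranks nodes by cosine similarity of raw features."""

    def __init__(self, data: Sequence[Sequence[float]], device=None):
        self.features = [list(row) for row in data]
        self.device = device or "cpu"

    @staticmethod
    def _cosine(a: Sequence[float], b: Sequence[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
        return dot / norm if norm else 0.0  # treat zero vectors as dissimilar

    def search(self, query_features: Sequence[float], top_k: int = 10) -> List[int]:
        scores = [self._cosine(query_features, row) for row in self.features]
        # Sort node indices from most to least similar, then keep the first top_k.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return ranked[:top_k]


nodes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
searcher = CosineSearch(nodes)
result = searcher.search([1.0, 0.1], top_k=2)
```

Since both `__init__` and `search` are abstract, a subclass has to override both before it can be instantiated.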

Module contents

class roksana.search_methods.GATSearch(data: Any, device: str = None, hidden_channels: int = 64, heads: int = 8, epochs: int = 200, lr: float = 0.005)

Bases: SearchMethod

Search method using Graph Attention Networks (GAT).

__init__(data: Any, device: str = None, hidden_channels: int = 64, heads: int = 8, epochs: int = 200, lr: float = 0.005)

Initialize and train the GAT model.

Parameters:
  • data (Any) – The graph dataset.

  • device (str, optional) – Device to run the computations on (‘cpu’ or ‘cuda’).

  • hidden_channels (int, optional) – Number of hidden channels in GAT layers.

  • heads (int, optional) – Number of attention heads in GAT layers.

  • epochs (int, optional) – Number of training epochs.

  • lr (float, optional) – Learning rate for the optimizer.

evaluate() → float

Evaluate the model’s accuracy on the training set.

Returns:

Training accuracy.

Return type:

float
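Training accuracy is the fraction of training nodes whose predicted class matches the label. The sketch below shows that arithmetic on plain Python sequences. The `training_accuracy` helper and the boolean `train_mask` argument are assumptions for illustration; how the package selects training nodes is not stated here.

```python
from typing import Sequence


def training_accuracy(
    predictions: Sequence[int],
    labels: Sequence[int],
    train_mask: Sequence[bool],
) -> float:
    """Fraction of masked (training) nodes whose prediction equals the label."""
    selected = [(p, y) for p, y, m in zip(predictions, labels, train_mask) if m]
    if not selected:
        return 0.0  # no training nodes: this sketch defines accuracy as 0.0
    correct = sum(1 for p, y in selected if p == y)
    return correct / len(selected)


# Three training nodes, two of them predicted correctly.
acc = training_accuracy([0, 1, 1, 2], [0, 1, 0, 2], [True, True, True, False])
```

Accuracy measured on the training set says how well the model fits the nodes it was trained on, not how well it generalizes to held-out nodes.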

get_node_embeddings() → Tensor

Generate node embeddings by passing the data through the model.

Returns:

Node embeddings.

Return type:

torch.Tensor

search(query_features: Tensor, top_k: int = 10) → List[int]

Perform a search with the given query features using GAT embeddings.

Parameters:
  • query_features (torch.Tensor) – Feature vector of the query node.

  • top_k (int, optional) – Number of most similar nodes to retrieve. Defaults to 10.

Returns:

List of node indices sorted by similarity to the query.

Return type:

List[int]

train_model()

Train the GAT model on the dataset. Assumes that the dataset has a ‘y’ attribute for node labels.

class roksana.search_methods.GCNSearch(data: Any, device: str = None, hidden_channels: int = 64, epochs: int = 200, lr: float = 0.01)

Bases: SearchMethod

Search method using Graph Convolutional Networks (GCN).

__init__(data: Any, device: str = None, hidden_channels: int = 64, epochs: int = 200, lr: float = 0.01)

Initialize and train the GCN model.

Parameters:
  • data (Any) – The graph dataset.

  • device (str, optional) – Device to run the computations on (‘cpu’ or ‘cuda’).

  • hidden_channels (int, optional) – Number of hidden channels in GCN layers.

  • epochs (int, optional) – Number of training epochs.

  • lr (float, optional) – Learning rate for the optimizer.

evaluate() → float

Evaluate the model’s accuracy on the training set.

Returns:

Training accuracy.

Return type:

float

get_node_embeddings() → Tensor

Generate node embeddings by passing the data through the model.

Returns:

Node embeddings.

Return type:

torch.Tensor

search(query_features: Tensor, top_k: int = 10) → List[List[int]]

Perform a search with the given query features using GCN embeddings.

Parameters:
  • query_features (torch.Tensor) – Feature tensor of the query nodes, shape [num_queries, feature_dim] or [feature_dim].

  • top_k (int, optional) – Number of most similar nodes to retrieve per query. Defaults to 10.

Returns:

List of lists containing node indices sorted by similarity to each query.

Return type:

List[List[int]]
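Accepting either a single vector of shape [feature_dim] or a batch of shape [num_queries, feature_dim] usually means promoting the single vector to a batch of one, so that the return value is always a list of lists. The sketch below shows that shape handling with plain Python lists and dot-product scoring. The `batched_top_k` helper and the scoring function are assumptions for illustration, not the library's code.

```python
from typing import List, Sequence


def batched_top_k(
    embeddings: Sequence[Sequence[float]],
    queries,
    top_k: int = 10,
) -> List[List[int]]:
    """Return one ranked list of node indices per query."""
    # Promote a single 1-D query to a batch containing one query.
    if queries and not isinstance(queries[0], (list, tuple)):
        queries = [queries]

    results: List[List[int]] = []
    for query in queries:
        # Dot product between the query and every node embedding.
        scores = [sum(q * e for q, e in zip(query, emb)) for emb in embeddings]
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        results.append(ranked[:top_k])
    return results


embeddings = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
single = batched_top_k(embeddings, [1.0, 0.0], top_k=2)
batch = batched_top_k(embeddings, [[1.0, 0.0], [0.0, 1.0]], top_k=2)
```

With this convention a single query still yields a nested list, so callers index the first element to get its ranking.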

train_model()

Train the GCN model on the dataset. Assumes that the dataset has a ‘y’ attribute for node labels.

class roksana.search_methods.SAGESearch(data: Any, device: str = None, hidden_channels: int = 64, epochs: int = 200, lr: float = 0.01)

Bases: SearchMethod

Search method using GraphSAGE.

__init__(data: Any, device: str = None, hidden_channels: int = 64, epochs: int = 200, lr: float = 0.01)

Initialize and train the GraphSAGE model.

Parameters:
  • data (Any) – The graph dataset.

  • device (str, optional) – Device to run the computations on (‘cpu’ or ‘cuda’).

  • hidden_channels (int, optional) – Number of hidden channels in SAGE layers.

  • epochs (int, optional) – Number of training epochs.

  • lr (float, optional) – Learning rate for the optimizer.

evaluate() → float

Evaluate the model’s accuracy on the training set.

Returns:

Training accuracy.

Return type:

float

get_node_embeddings() → Tensor

Generate node embeddings by passing the data through the model.

Returns:

Node embeddings.

Return type:

torch.Tensor

search(query_features: Tensor, top_k: int = 10) → List[int]

Perform a search with the given query features using GraphSAGE embeddings.

Parameters:
  • query_features (torch.Tensor) – Feature vector of the query node.

  • top_k (int, optional) – Number of most similar nodes to retrieve. Defaults to 10.

Returns:

List of node indices sorted by similarity to the query.

Return type:

List[int]

train_model()

Train the GraphSAGE model on the dataset. Assumes that the dataset has a ‘y’ attribute for node labels.

class roksana.search_methods.SearchMethod(data: Any, device: str = None, **kwargs)

Bases: ABC

Abstract base class for search methods.

abstract __init__(data: Any, device: str = None, **kwargs)

Initialize the search method with the given dataset.

Parameters:
  • data (Any) – The graph dataset.

  • device (str, optional) – Device to run the computations on (‘cpu’ or ‘cuda’).

abstract search(query_features: Any, top_k: int = 10) → List[int]

Perform a search with the given query features.

Parameters:
  • query_features (Any) – Feature vector of the query node.

  • top_k (int, optional) – Number of most similar nodes to retrieve. Defaults to 10.

Returns:

List of node indices sorted by similarity to the query.

Return type:

List[int]

roksana.search_methods.get_search_method(name: str, data: Any, device: str = None, **kwargs) → SearchMethod

Retrieve an instance of the specified search method.

Parameters:
  • name (str) – Name of the search method (e.g., ‘gcn’, ‘gat’, ‘sage’).

  • data (Any) – The graph dataset.

  • device (str, optional) – Device to run the computations on (‘cpu’ or ‘cuda’).

  • **kwargs – Additional keyword arguments for the search method.

Returns:

An instance of the requested search method.

Return type:

SearchMethod

Raises:

ValueError – If the specified search method is not registered.