Source code for graphix.graphsim.rxgraphviews

"""Node list class for RXGraphState."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Iterator


class NodeList:
    """Node list class for RXGraphState.

    In rustworkx, node data is stored in a tuple (node_num, node_data),
    and adding/removing nodes by node_num is not supported.
    This class defines a node list with node_num as key.
    """

    def __init__(
        self,
        node_nums: list[int] | None = None,
        node_datas: list[dict] | None = None,
        node_indices: list[int] | None = None,
    ):
        """Initialize a node list.

        Parameters
        ----------
        node_nums : list[int] | None
            Node numbers used as keys.
        node_datas : list[dict] | None
            Data of each node, paired with ``node_nums`` by list position.
        node_indices : list[int] | None
            Index of each node, paired with ``node_nums`` by list position.

        Raises
        ------
        ValueError
            If the three lists do not have the same length.
        """
        if node_indices is None:
            node_indices = []
        if node_datas is None:
            node_datas = []
        if node_nums is None:
            node_nums = []
        if not (len(node_nums) == len(node_datas) and len(node_nums) == len(node_indices)):
            raise ValueError("node_nums, node_datas and node_indices must have the same length")
        self.nodes = set(node_nums)
        # Pair data with node numbers by position, exactly like add_nodes_from does.
        # (Looking the data up via the node *index* only worked when indices were 0..n-1.)
        self.num_to_data = dict(zip(node_nums, node_datas))
        self.num_to_idx = dict(zip(node_nums, node_indices))
        self.idx_to_num = dict(zip(node_indices, node_nums))

    def __contains__(self, nnum: int) -> bool:
        """Return `True` if the node `nnum` belongs to the list, `False` otherwise."""
        return nnum in self.nodes

    def __getitem__(self, nnum: int) -> Any:
        """Return the data associated to node `nnum`."""
        return self.num_to_data[nnum]

    def __len__(self) -> int:
        """Return the number of nodes."""
        return len(self.nodes)

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over nodes."""
        return iter(self.nodes)

    # TODO: This is not an evaluable __repr__. Define __str__ instead?
    def __repr__(self) -> str:
        """Return a string representation for the node list."""
        return "NodeList" + str(list(self.nodes))

    def get_node_index(self, nnum: int) -> int:
        """Return the index of the node `nnum`.

        Raises `KeyError` if `nnum` is not in the list.
        """
        return self.num_to_idx[nnum]

    def add_node(self, nnum: int, ndata: dict, nidx: int) -> None:
        """Add a node to the list.

        Raises
        ------
        ValueError
            If a node numbered `nnum` already exists.
        """
        if nnum in self.num_to_data:
            raise ValueError(f"Node {nnum} already exists")
        self.nodes.add(nnum)
        self.num_to_data[nnum] = ndata
        self.num_to_idx[nnum] = nidx
        self.idx_to_num[nidx] = nnum

    def add_nodes_from(self, node_nums: list[int], node_datas: list[dict], node_indices: list[int]) -> None:
        """Add nodes to the list, silently skipping node numbers already present.

        Raises
        ------
        ValueError
            If the three lists do not have the same length.
        """
        if not (len(node_nums) == len(node_datas) and len(node_nums) == len(node_indices)):
            raise ValueError("node_nums, node_datas and node_indices must have the same length")
        for nnum, ndata, nidx in zip(node_nums, node_datas, node_indices):
            if nnum in self.nodes:
                continue
            self.add_node(nnum, ndata, nidx)

    def remove_node(self, nnum: int) -> None:
        """Remove a node from the list.

        Raises
        ------
        ValueError
            If the node `nnum` does not exist.
        """
        if nnum not in self.num_to_data:
            raise ValueError(f"Node {nnum} does not exist")
        self.nodes.remove(nnum)
        del self.num_to_data[nnum]
        idx = self.num_to_idx.pop(nnum)
        del self.idx_to_num[idx]

    def remove_nodes_from(self, node_nums: list[int]) -> None:
        """Remove nodes from the list, silently skipping unknown node numbers."""
        for nnum in node_nums:
            if nnum not in self.nodes:
                continue
            self.remove_node(nnum)
class EdgeList:
    """Edge list class for RXGraphState.

    In rustworkx, edge data is stored in a tuple (parent, child, edge_data),
    and adding/removing edges by (parent, child) is not supported.
    This class defines an edge list with (parent, child) as key.
    """

    def __init__(
        self,
        edge_nums: list[tuple[int, int]] | None = None,
        edge_datas: list[dict] | None = None,
        edge_indices: list[int] | None = None,
    ):
        """Initialize an edge list.

        Parameters
        ----------
        edge_nums : list[tuple[int, int]] | None
            Edges as (parent, child) node-number pairs, used as keys.
        edge_datas : list[dict] | None
            Data of each edge, paired with ``edge_nums`` by list position.
        edge_indices : list[int] | None
            Index of each edge, paired with ``edge_nums`` by list position.

        Raises
        ------
        ValueError
            If the three lists do not have the same length.
        """
        if edge_indices is None:
            edge_indices = []
        if edge_datas is None:
            edge_datas = []
        if edge_nums is None:
            edge_nums = []
        if not (len(edge_nums) == len(edge_datas) and len(edge_nums) == len(edge_indices)):
            raise ValueError("edge_nums, edge_datas and edge_indices must have the same length")
        self.edges = set(edge_nums)
        # Pair data with edges by position, exactly like add_edges_from does.
        # (Looking the data up via the edge *index* only worked when indices were 0..n-1.)
        self.num_to_data = dict(zip(edge_nums, edge_datas))
        self.num_to_idx = dict(zip(edge_nums, edge_indices))
        # Maps a node number to the set of edges incident to it.
        self.nnum_to_edges: dict[int, set[tuple[int, int]]] = {}
        for enum in edge_nums:
            self._link(enum)

    def _link(self, enum: tuple[int, int]) -> None:
        """Record `enum` in the incidence sets of both of its endpoints."""
        for nnum in (enum[0], enum[1]):
            self.nnum_to_edges.setdefault(nnum, set()).add(enum)

    def _unlink(self, enum: tuple[int, int]) -> None:
        """Drop `enum` from the incidence sets of both of its endpoints."""
        self.nnum_to_edges[enum[0]].remove(enum)
        # Both endpoints of a self-loop share one set; removing twice would raise KeyError.
        if enum[1] != enum[0]:
            self.nnum_to_edges[enum[1]].remove(enum)

    def __contains__(self, enum: tuple[int, int]) -> bool:
        """Return `True` if the edge `enum` belongs to the list, `False` otherwise."""
        return enum in self.edges

    def __getitem__(self, enum: tuple[int, int]) -> Any:
        """Return the data associated to edge `enum`."""
        return self.num_to_data[enum]

    def __len__(self) -> int:
        """Return the number of edges."""
        return len(self.edges)

    def __iter__(self) -> Iterator[tuple[int, int]]:
        """Return an iterator over edges."""
        return iter(self.edges)

    # TODO: This is not an evaluable __repr__. Define __str__ instead?
    def __repr__(self) -> str:
        """Return a string representation for the edge list."""
        return "EdgeList" + str(list(self.edges))

    def get_edge_index(self, enum: tuple[int, int]) -> int:
        """Return the index of the edge `enum`.

        Raises `KeyError` if `enum` is not in the list.
        """
        return self.num_to_idx[enum]

    def add_edge(self, enum: tuple[int, int], edata: dict, eidx: int) -> None:
        """Add an edge to the list.

        Raises
        ------
        ValueError
            If the edge `enum` already exists.
        """
        if enum in self.num_to_data:
            raise ValueError(f"Edge {enum} already exists")
        self.edges.add(enum)
        self.num_to_data[enum] = edata
        self.num_to_idx[enum] = eidx
        self._link(enum)

    def add_edges_from(self, edge_nums: list[tuple[int, int]], edge_datas: list[dict], edge_indices: list[int]) -> None:
        """Add edges to the list, silently skipping edges already present.

        Raises
        ------
        ValueError
            If the three lists do not have the same length.
        """
        if not (len(edge_nums) == len(edge_datas) and len(edge_nums) == len(edge_indices)):
            raise ValueError("edge_nums, edge_datas and edge_indices must have the same length")
        for enum, edata, eidx in zip(edge_nums, edge_datas, edge_indices):
            if enum in self.edges:
                continue
            self.add_edge(enum, edata, eidx)

    def remove_edge(self, enum: tuple[int, int]) -> None:
        """Remove an edge from the list.

        Raises
        ------
        ValueError
            If the edge `enum` does not exist.
        """
        if enum not in self.num_to_data:
            raise ValueError(f"Edge {enum} does not exist")
        self.edges.remove(enum)
        del self.num_to_data[enum]
        del self.num_to_idx[enum]
        self._unlink(enum)

    def remove_edges_from(self, edge_nums: list[tuple[int, int]]) -> None:
        """Remove edges from the list, silently skipping unknown edges."""
        for enum in edge_nums:
            if enum not in self.edges:
                continue
            self.remove_edge(enum)

    def remove_edges_by_node(self, nnum: int) -> None:
        """Remove all edges connected to the node `nnum`.

        Does nothing if `nnum` has no recorded edges. Afterwards the incidence
        set of `nnum` (if any) is empty, so no stale edges remain for it.
        """
        # Iterate over a snapshot: remove_edge mutates the incidence set of `nnum`.
        for enum in list(self.nnum_to_edges.get(nnum, ())):
            self.remove_edge(enum)