Source code for graphix.opengraph

"""Class for open graph states."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar

import networkx as nx

from graphix.flow._find_cflow import find_cflow
from graphix.flow._find_gpflow import AlgebraicOpenGraph, PlanarAlgebraicOpenGraph, compute_correction_matrix
from graphix.flow.core import GFlow, PauliFlow
from graphix.fundamentals import AbstractMeasurement, AbstractPlanarMeasurement

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable, Mapping, Sequence

    from graphix.flow.core import CausalFlow
    from graphix.measurements import Measurement
    from graphix.pattern import Pattern

# Type variables used to parametrise `OpenGraph` by the type of its measurement labels.
# TODO: Consider moving these definitions to graphix.fundamentals and graphix.measurements;
# they are currently redefined in graphix.flow._find_gpflow, which is not very elegant.
# `_M_co`: covariant type variable bounded by `AbstractMeasurement` (any measurement label).
_M_co = TypeVar("_M_co", bound=AbstractMeasurement, covariant=True)
# `_PM_co`: covariant type variable bounded by `AbstractPlanarMeasurement` (planar measurement labels only).
_PM_co = TypeVar("_PM_co", bound=AbstractPlanarMeasurement, covariant=True)


@dataclass(frozen=True)
class OpenGraph(Generic[_M_co]):
    """An immutable dataclass providing a representation of open graph states.

    Attributes
    ----------
    graph : networkx.Graph[int]
        The underlying resource-state graph. Nodes represent qubits and edges represent the application of :math:`CZ` gate on the linked nodes.
    input_nodes : Sequence[int]
        An ordered sequence of node labels corresponding to the open graph inputs.
    output_nodes : Sequence[int]
        An ordered sequence of node labels corresponding to the open graph outputs.
    measurements : Mapping[int, _M_co]
        A mapping between the non-output nodes of the open graph (``key``) and their corresponding measurement label (``value``). Measurement labels can be specified as `Measurement`, `Plane` or `Axis` instances.

    Notes
    -----
    The inputs and outputs of `OpenGraph` instances in Graphix are defined as ordered sequences of node labels. This contrasts the usual definition of open graphs in the literature, where inputs and outputs are unordered sets of nodes labels. This restriction facilitates the interplay with `Pattern` objects, where the order of input and output nodes represents a choice of Hilbert space basis.

    Example
    -------
    >>> import networkx as nx
    >>> from graphix.fundamentals import Plane
    >>> from graphix.opengraph import OpenGraph
    >>> from graphix.measurements import Measurement
    >>>
    >>> graph = nx.Graph([(0, 1), (1, 2)])
    >>> measurements = {i: Measurement(0.5 * i, Plane.XY) for i in range(2)}
    >>> input_nodes = [0]
    >>> output_nodes = [2]
    >>> og = OpenGraph(graph, input_nodes, output_nodes, measurements)
    """

    graph: nx.Graph[int]
    input_nodes: Sequence[int]
    output_nodes: Sequence[int]
    measurements: Mapping[int, _M_co]

    def __post_init__(self) -> None:
        """Validate the correctness of the open graph.

        Raises
        ------
        ValueError
            If measured, input or output nodes are not nodes of the graph, if an output node is measured,
            if a non-output node is not measured, or if the inputs or the outputs contain duplicates.
        """
        all_nodes = set(self.graph.nodes)
        inputs = set(self.input_nodes)
        outputs = set(self.output_nodes)

        if not set(self.measurements).issubset(all_nodes):
            raise ValueError("All measured nodes must be part of the graph's nodes.")
        if not inputs.issubset(all_nodes):
            raise ValueError("All input nodes must be part of the graph's nodes.")
        if not outputs.issubset(all_nodes):
            raise ValueError("All output nodes must be part of the graph's nodes.")
        if outputs & self.measurements.keys():
            raise ValueError("Output nodes cannot be measured.")
        if all_nodes - outputs != self.measurements.keys():
            raise ValueError("All non-output nodes must be measured.")
        # Converting to a set drops repeated labels, so a length mismatch reveals duplicates.
        if len(inputs) != len(self.input_nodes):
            raise ValueError("Input nodes contain duplicates.")
        if len(outputs) != len(self.output_nodes):
            raise ValueError("Output nodes contain duplicates.")

    # TODO: Up docstrings and generalise to any type
    def isclose(
        self: OpenGraph[Measurement], other: OpenGraph[Measurement], rel_tol: float = 1e-09, abs_tol: float = 0.0
    ) -> bool:
        """Return `True` if two open graphs implement approximately the same unitary operator.

        Ensures the structure of the graphs are the same and all
        measurement angles are sufficiently close.

        This doesn't check they are equal up to an isomorphism.

        Parameters
        ----------
        other : OpenGraph[Measurement]
            Open graph to compare with `self`.
        rel_tol : float
            Relative tolerance forwarded to the `isclose` method of each measurement.
        abs_tol : float
            Absolute tolerance forwarded to the `isclose` method of each measurement.

        Returns
        -------
        bool
            `True` if both open graphs have equal graphs, equal input and output sequences, the same measured
            nodes, and pairwise close measurements; `False` otherwise.
        """
        if not nx.utils.graphs_equal(self.graph, other.graph):
            return False

        if self.input_nodes != other.input_nodes or self.output_nodes != other.output_nodes:
            return False

        if set(self.measurements.keys()) != set(other.measurements.keys()):
            return False

        return all(
            m.isclose(other.measurements[node], rel_tol=rel_tol, abs_tol=abs_tol)
            for node, m in self.measurements.items()
        )

    def to_pattern(self: OpenGraph[Measurement]) -> Pattern:
        """Extract a deterministic pattern from an `OpenGraph[Measurement]` if it exists.

        Returns
        -------
        Pattern
            A deterministic pattern on the open graph.

        Raises
        ------
        OpenGraphError
            If the open graph does not have flow.

        Notes
        -----
        - The open graph instance must be of parametric type `Measurement` to allow for a pattern extraction, otherwise it does not contain information about the measurement angles.

        - This method proceeds by searching a flow on the open graph and converting it into a pattern as prescribed in Ref. [1].
          It first attempts to find a causal flow because the corresponding flow-finding algorithm has lower complexity. If it fails, it attempts to find a Pauli flow because this property is more general than a generalised flow, and the corresponding flow-finding algorithms have the same complexity in the current implementation.

        References
        ----------
        [1] Browne et al., NJP 9, 250 (2007)
        """
        # Extractors are tried in order; the first flow found is converted into a pattern.
        for extractor in (self.find_causal_flow, self.find_pauli_flow):
            flow = extractor()
            if flow is not None:
                return flow.to_corrections().to_pattern()

        raise OpenGraphError("The open graph does not have flow. It does not support a deterministic pattern.")

    def neighbors(self, nodes: Collection[int]) -> set[int]:
        """Return the set containing the neighborhood of a set of nodes in the open graph.

        Parameters
        ----------
        nodes : Collection[int]
            Set of nodes whose neighborhood is to be found

        Returns
        -------
        neighbors_set : set[int]
            Neighborhood of set `nodes`.
        """
        neighbors_set: set[int] = set()
        for node in nodes:
            neighbors_set |= set(self.graph.neighbors(node))
        return neighbors_set

    def odd_neighbors(self, nodes: Collection[int]) -> set[int]:
        """Return the set containing the odd neighborhood of a set of nodes in the open graph.

        Parameters
        ----------
        nodes : Collection[int]
            Set of nodes whose odd neighborhood is to be found

        Returns
        -------
        odd_neighbors_set : set[int]
            Odd neighborhood of set `nodes`.
        """
        odd_neighbors_set: set[int] = set()
        for node in nodes:
            # Symmetric difference keeps exactly the vertices seen an odd number of times.
            odd_neighbors_set ^= self.neighbors([node])
        return odd_neighbors_set

    def extract_causal_flow(self: OpenGraph[_PM_co]) -> CausalFlow[_PM_co]:
        """Try to extract a causal flow on the open graph.

        This method is a wrapper over :func:`OpenGraph.find_causal_flow` with a single return type.

        Returns
        -------
        CausalFlow[_PM_co]
            A causal flow object if the open graph has causal flow.

        Raises
        ------
        OpenGraphError
            If the open graph does not have a causal flow.

        See Also
        --------
        :func:`OpenGraph.find_causal_flow`
        """
        cf = self.find_causal_flow()
        if cf is None:
            raise OpenGraphError("The open graph does not have a causal flow.")
        return cf

    def extract_gflow(self: OpenGraph[_PM_co]) -> GFlow[_PM_co]:
        r"""Try to extract a maximally delayed generalised flow (gflow) on the open graph.

        This method is a wrapper over :func:`OpenGraph.find_gflow` with a single return type.

        Returns
        -------
        GFlow[_PM_co]
            A gflow object if the open graph has gflow.

        Raises
        ------
        OpenGraphError
            If the open graph does not have a gflow.

        See Also
        --------
        :func:`OpenGraph.find_gflow`
        """
        gf = self.find_gflow()
        if gf is None:
            raise OpenGraphError("The open graph does not have a gflow.")
        return gf

    def extract_pauli_flow(self: OpenGraph[_M_co]) -> PauliFlow[_M_co]:
        r"""Try to extract a maximally delayed Pauli flow on the open graph.

        This method is a wrapper over :func:`OpenGraph.find_pauli_flow` with a single return type.

        Returns
        -------
        PauliFlow[_M_co]
            A Pauli flow object if the open graph has Pauli flow.

        Raises
        ------
        OpenGraphError
            If the open graph does not have a Pauli flow.

        See Also
        --------
        :func:`OpenGraph.find_pauli_flow`
        """
        pf = self.find_pauli_flow()
        if pf is None:
            raise OpenGraphError("The open graph does not have a Pauli flow.")
        return pf

    def find_causal_flow(self: OpenGraph[_PM_co]) -> CausalFlow[_PM_co] | None:
        """Return a causal flow on the open graph if it exists.

        Returns
        -------
        CausalFlow[_PM_co] | None
            A causal flow object if the open graph has causal flow or ``None`` otherwise.

        See Also
        --------
        :func:`OpenGraph.extract_causal_flow`

        Notes
        -----
        - The open graph instance must be of parametric type `Measurement` or `Plane` since the causal flow is only defined on open graphs with :math:`XY` measurements.
        - This function implements the algorithm presented in Ref. [1] with polynomial complexity on the number of nodes, :math:`O(N^2)`.

        References
        ----------
        [1] Mhalla and Perdrix, (2008), Finding Optimal Flows Efficiently, doi.org/10.1007/978-3-540-70575-8_70
        """
        return find_cflow(self)

    def find_gflow(self: OpenGraph[_PM_co]) -> GFlow[_PM_co] | None:
        r"""Return a maximally delayed generalised flow (gflow) on the open graph if it exists.

        Returns
        -------
        GFlow[_PM_co] | None
            A gflow object if the open graph has gflow or ``None`` otherwise.

        See Also
        --------
        :func:`OpenGraph.extract_gflow`

        Notes
        -----
        - The open graph instance must be of parametric type `Measurement` or `Plane` since the gflow is only defined on open graphs with planar measurements. Measurement instances with a Pauli angle (integer multiple of :math:`\pi/2`) are interpreted as `Plane` instances, in contrast with :func:`OpenGraph.find_pauli_flow`.
        - This function implements the algorithm presented in Ref. [1] with polynomial complexity on the number of nodes, :math:`O(N^3)`.

        References
        ----------
        [1] Mitosek and Backens, 2024 (arXiv:2410.23439).
        """
        aog = PlanarAlgebraicOpenGraph(self)
        correction_matrix = compute_correction_matrix(aog)
        if correction_matrix is None:
            return None
        # The constructor returns `None` if the correction matrix is not compatible
        # with any partial order on the open graph.
        return GFlow.try_from_correction_matrix(correction_matrix)

    def find_pauli_flow(self: OpenGraph[_M_co]) -> PauliFlow[_M_co] | None:
        r"""Return a maximally delayed Pauli flow on the open graph if it exists.

        Returns
        -------
        PauliFlow[_M_co] | None
            A Pauli flow object if the open graph has Pauli flow or ``None`` otherwise.

        See Also
        --------
        :func:`OpenGraph.extract_pauli_flow`

        Notes
        -----
        - Measurement instances with a Pauli angle (integer multiple of :math:`\pi/2`) are interpreted as `Axis` instances, in contrast with :func:`OpenGraph.find_gflow`.
        - This function implements the algorithm presented in Ref. [1] with polynomial complexity on the number of nodes, :math:`O(N^3)`.

        References
        ----------
        [1] Mitosek and Backens, 2024 (arXiv:2410.23439).
        """
        aog = AlgebraicOpenGraph(self)
        correction_matrix = compute_correction_matrix(aog)
        if correction_matrix is None:
            return None
        # The constructor returns `None` if the correction matrix is not compatible
        # with any partial order on the open graph.
        return PauliFlow.try_from_correction_matrix(correction_matrix)

    # TODO: Generalise `compose` to any type of OpenGraph
    def compose(
        self: OpenGraph[Measurement], other: OpenGraph[Measurement], mapping: Mapping[int, int]
    ) -> tuple[OpenGraph[Measurement], dict[int, int]]:
        r"""Compose two open graphs by merging subsets of nodes from `self` and `other`, and relabeling the nodes of `other` that were not merged.

        Parameters
        ----------
        other : OpenGraph
            Open graph to be composed with `self`.
        mapping: Mapping[int, int]
            Partial relabelling of the nodes in `other`, with `keys` and `values` denoting the old and new node labels, respectively.

        Returns
        -------
        og: OpenGraph
            composed open graph
        mapping_complete: dict[int, int]
            Complete relabelling of the nodes in `other`, with `keys` and `values` denoting the old and new node label, respectively.

        Raises
        ------
        ValueError
            If a key of `mapping` is not a node of `other`, if the values of `mapping` contain duplicates,
            or if two merged nodes are both measured with measurements that are not close to each other.

        Notes
        -----
        Let's denote :math:`\{G(V_1, E_1), I_1, O_1\}` the open graph `self`, :math:`\{G(V_2, E_2), I_2, O_2\}` the open graph `other`, :math:`\{G(V, E), I, O\}` the resulting open graph `og` and `{v:u}` an element of `mapping`.

        We define :math:`V, U` the set of nodes in `mapping.keys()` and `mapping.values()`, and :math:`M = U \cap V_1` the set of merged nodes.

        The open graph composition requires that
        - :math:`V \subseteq V_2`.
        - If both `v` and `u` are measured, the corresponding measurements must have the same plane and angle.

        The returned open graph follows this convention:
        - :math:`I = (I_1 \cup I_2) \setminus M \cup (I_1 \cap I_2 \cap M)`,
        - :math:`O = (O_1 \cup O_2) \setminus M \cup (O_1 \cap O_2 \cap M)`,
        - If only one node of the pair `{v:u}` is measured, this measure is assigned to :math:`u \in V` in the resulting open graph.
        - Input (and, respectively, output) nodes in the returned open graph have the order of the open graph `self` followed by those of the open graph `other`. Merged nodes are removed, except when they are input (or output) nodes in both open graphs, in which case, they appear in the order they originally had in the graph `self`.
        - Nodes of `other` that are not keys of `mapping` are relabelled, in ascending order of their original label, with consecutive integers starting right after the largest label found among the nodes of `self` and the values of `mapping` (starting at 0 if there are none).
        """
        if not (mapping.keys() <= other.graph.nodes):
            raise ValueError("Keys of mapping must be correspond to nodes of other.")
        if len(mapping) != len(set(mapping.values())):
            raise ValueError("Values in mapping contain duplicates.")

        for v, u in mapping.items():
            if (
                (vm := other.measurements.get(v)) is not None
                and (um := self.measurements.get(u)) is not None
                and not vm.isclose(um)
            ):
                raise ValueError(f"Attempted to merge nodes {v}:{u} but have different measurements")

        # `max` is applied to a single iterable with a default: unpacking the labels as positional
        # arguments fails with `TypeError` when there is exactly one label (or none at all).
        used_labels = (*self.graph.nodes, *mapping.values())
        shift = max(used_labels, default=-1) + 1

        # Assign new labels to the nodes in `other` that are not specified in `mapping`.
        mapping_sequential = {
            node: i for i, node in enumerate(sorted(other.graph.nodes - mapping.keys()), start=shift)
        }
        mapping_complete = {**mapping, **mapping_sequential}

        g2_shifted = nx.relabel_nodes(other.graph, mapping_complete)
        g = nx.compose(self.graph, g2_shifted)

        merged = set(mapping_complete.values()) & self.graph.nodes

        def merge_ports(p1: Iterable[int], p2: Iterable[int]) -> list[int]:
            # Keep the ports of `self` first, then the (relabelled) ports of `other`.
            # A merged node survives only if it is a port in both open graphs.
            p2_mapped = [mapping_complete[node] for node in p2]
            p2_set = set(p2_mapped)
            part1 = [node for node in p1 if node not in merged or node in p2_set]
            part2 = [node for node in p2_mapped if node not in merged]
            return part1 + part2

        inputs = merge_ports(self.input_nodes, other.input_nodes)
        outputs = merge_ports(self.output_nodes, other.output_nodes)

        measurements_shifted = {mapping_complete[i]: meas for i, meas in other.measurements.items()}
        measurements = {**self.measurements, **measurements_shifted}

        return OpenGraph(g, inputs, outputs, measurements), mapping_complete


class OpenGraphError(Exception):
    """Exception raised for errors related to open graphs.

    Within this module it is raised when a requested flow does not exist on an open graph.
    """