# Source code for graphix.pattern

"""MBQC pattern according to Measurement Calculus.

ref: V. Danos, E. Kashefi and P. Panangaden. J. ACM 54.2 8 (2007)
"""

from __future__ import annotations

import copy
import dataclasses
import enum
import itertools
import warnings
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Literal, SupportsFloat, overload

import networkx as nx
from typing_extensions import assert_never

from graphix import command, optimization, parameter
from graphix.clifford import Clifford
from graphix.command import Command, CommandKind
from graphix.flow.exceptions import FlowError
from graphix.fundamentals import Axis, Plane, Sign
from graphix.graphsim import GraphState
from graphix.measurements import Measurement, Outcome, PauliMeasurement, toggle_outcome
from graphix.opengraph import OpenGraph
from graphix.pretty_print import OutputFormat, pattern_to_str
from graphix.qasm3_exporter import pattern_to_qasm3_lines
from graphix.sim import DensityMatrix, MBQCTensorNet, Statevec
from graphix.simulator import PatternSimulator
from graphix.states import BasicStates
from graphix.visualization import GraphVisualizer

if TYPE_CHECKING:
    from collections.abc import Container, Iterator, Mapping
    from collections.abc import Set as AbstractSet
    from typing import Any

    from numpy.random import Generator

    from graphix.flow.core import CausalFlow, GFlow, XZCorrections
    from graphix.parameter import ExpressionOrSupportsComplex, ExpressionOrSupportsFloat, Parameter
    from graphix.sim import Backend, Data, DensityMatrixBackend, StatevectorBackend
    from graphix.sim.base_backend import _StateT_co
    from graphix.sim.tensornet import TensorNetworkBackend
    from graphix.simulator import _BackendLiteral
    from graphix.states import State

# Type alias: union of the simulator state classes imported from ``graphix.sim``
# (density matrix, state vector and MBQC tensor network).
_BuiltinBackendState = (
    DensityMatrix
    | Statevec
    | MBQCTensorNet
)


class Pattern:
    """MBQC pattern class.

    A pattern holds a sequence of commands that operate the MBQC, and provides
    transformations that improve its structure and simulation efficiency
    according to the measurement calculus.

    ref: V. Danos, E. Kashefi and P. Panangaden. J. ACM 54.2 8 (2007)

    Commands are instances of :class:`graphix.command.Command` and are applied
    in sequence order; iterate over the pattern (``for cmd in pattern``) or
    index it (``pattern[i]``) to access them. Each command has a ``kind``
    (:class:`graphix.command.CommandKind`) and kind-specific attributes:

    - ``N``: prepares ``node`` (in state ``state``).
    - ``E``: entangles the pair ``nodes = (i, j)``.
    - ``M``: measures ``node`` with ``plane``, ``angle``, ``s_domain`` and ``t_domain``.
    - ``X`` / ``Z``: byproduct correction on ``node`` controlled by ``domain``.
    - ``S``: signal command on ``node`` with ``domain`` (see :meth:`shift_signals`).
    - ``C``: applies the Clifford operator ``clifford`` to ``node``.

    Other kinds may appear (e.g. ``T``, which carries no node).

    Attributes
    ----------
    results : dict[int, Outcome]
        Measurement outcomes already known for some nodes (for instance nodes
        whose Pauli measurements were preprocessed).
    n_node : int
        Read-only count of nodes that are either inputs or prepared by ``N``
        commands.
    """

    results: dict[int, Outcome]
    __seq: list[Command]
def __init__(
    self,
    input_nodes: Iterable[int] | None = None,
    cmds: Iterable[Command] | None = None,
    output_nodes: Iterable[int] | None = None,
) -> None:
    """Construct a pattern.

    Parameters
    ----------
    input_nodes : Iterable[int] | None
        Optional. Input qubits; the pattern stores its own list copy.
    cmds : Iterable[Command] | None
        Optional. Initial commands, appended in order through :meth:`extend`.
    output_nodes : Iterable[int] | None
        Optional. Desired order of the output qubits, validated and applied
        by :meth:`reorder_output_nodes` after ``cmds`` have been added.
    """
    # Outcomes recorded for already-measured nodes.
    self.results = {}
    # Own copy, so later changes to the caller's iterable cannot leak in.
    self.__input_nodes = [] if input_nodes is None else list(input_nodes)
    # Every input counts as a node of the resource state.
    self.__n_node = len(self.__input_nodes)
    self.__seq = []
    # Nothing is measured yet, so the outputs start as a copy of the inputs.
    self.__output_nodes = self.__input_nodes.copy()
    if cmds is not None:
        self.extend(cmds)
    if output_nodes is not None:
        self.reorder_output_nodes(output_nodes)
def add(self, cmd: Command) -> None:
    """Append a single command to the end of the pattern.

    A measurement (``M``) removes its node from the outputs when present; a
    preparation (``N``) registers a new node and appends it to the outputs.

    Parameters
    ----------
    cmd : :class:`graphix.command.Command`
        MBQC command to append.
    """
    if cmd.kind == CommandKind.M:
        if cmd.node in self.__output_nodes:
            self.__output_nodes.remove(cmd.node)
    elif cmd.kind == CommandKind.N:
        self.__n_node += 1
        self.__output_nodes.append(cmd.node)
    self.__seq.append(cmd)
def extend(self, *cmds: Command | Iterable[Command]) -> None:
    """Append commands to the pattern.

    Each positional argument is either a single command or an iterable of
    commands. Iterables are flattened one level, and every command is handed
    to :meth:`add` in order.

    :param cmds: commands or sequences of commands
    """
    for group in cmds:
        batch = group if isinstance(group, Iterable) else (group,)
        for cmd in batch:
            self.add(cmd)
def clear(self) -> None:
    """Remove every command, leaving only the input nodes.

    The node count and the outputs are recomputed from the input nodes;
    ``results`` is not modified.
    """
    inputs = self.__input_nodes
    self.__seq = []
    self.__output_nodes = list(inputs)
    self.__n_node = len(inputs)
def replace(self, cmds: list[Command], input_nodes: list[int] | None = None) -> None:
    """Replace the whole command sequence of the pattern.

    :param cmds: list of commands forming the new sequence
    :param input_nodes: optional, list of input qubits (by default, keep the same input nodes as before)
    """
    new_inputs = self.__input_nodes if input_nodes is None else list(input_nodes)
    self.__input_nodes = new_inputs
    # Reset node count / outputs from the (possibly new) inputs, then replay the commands.
    self.clear()
    self.extend(cmds)
def compose(
    self, other: Pattern, mapping: Mapping[int, int], preserve_mapping: bool = False
) -> tuple[Pattern, dict[int, int]]:
    r"""Compose two patterns by merging subsets of outputs from `self` and a subset of inputs of `other`, and relabeling the nodes of `other` that were not merged.

    Parameters
    ----------
    other : Pattern
        Pattern to be composed with `self`.
    mapping: Mapping[int, int]
        Partial relabelling of the nodes in `other`, with `keys` and `values` denoting the old and new node labels, respectively.
    preserve_mapping: bool
        Boolean flag controlling the ordering of the output nodes in the returned pattern.

    Returns
    -------
    p: Pattern
        composed pattern
    mapping_complete: dict[int, int]
        Complete relabelling of the nodes in `other`, with `keys` and `values` denoting the old and new node label, respectively.

    Raises
    ------
    ValueError
        If a key of `mapping` is not a node of `other`, if the values of `mapping` contain
        duplicates, if a value of `mapping` is a measured (non-output) node of `self`, or if a
        value that is an output of `self` is the image of a node that is not an input of `other`.

    Notes
    -----
    Let's denote :math:`(I_j, O_j, V_j, S_j)` the ordered set of inputs and outputs, the computational space and the sequence of commands of pattern :math:`P_j`, respectively, with :math:`j = 1` for pattern `self` and :math:`j = 2` for pattern `other`.
    Let's denote :math:`P` the resulting pattern with :math:`(I, O, V, S)`.
    Let's denote :math:`K, U` the sets of `keys` and `values` of `mapping`, :math:`M_1 = O_1 \cap U` the set of merged outputs, and :math:`M_2 = \{k \in I_2 \cap K | k \rightarrow v, v \in M_1 \}` the set of merged inputs.

    The pattern composition requires that

    - :math:`K \subseteq V_2`.
    - For a pair :math:`(k, v) \in (K, U)`

        - :math:`U \cap V_1 \setminus O_1 = \emptyset`. If :math:`v \in O_1`, then :math:`k \in I_2`, otherwise an error is raised.
        - :math:`v` can always satisfy :math:`v \notin V_1`, thereby allowing a custom relabelling.

    The returned pattern follows this convention:

    - Nodes of pattern `other` not specified in `mapping` (i.e., :math:`V_2 \cap K^c`) are relabelled in ascending order.
    - The sequence of the resulting pattern is :math:`S = S_2 S_1`, where nodes in :math:`S_2` are relabelled according to `mapping`.
    - :math:`I = I_1 \cup (I_2 \setminus M_2)`.
    - :math:`O = (O_1 \setminus M_1) \cup O_2`.
    - Input (and, respectively, output) nodes in the returned pattern have the order of the pattern `self` followed by those of the pattern `other`. Merged nodes are removed.
    - If `preserve_mapping = True` and :math:`|M_1| = |I_2| = |O_2|`, then the outputs of the returned pattern are the outputs of pattern `self`, where the nth merged output is replaced by the output of pattern `other` corresponding to its nth input instead.
    """
    nodes_p1 = self.extract_nodes() | self.results.keys()  # Results contain preprocessed Pauli nodes
    nodes_p2 = other.extract_nodes() | other.results.keys()

    if not mapping.keys() <= nodes_p2:
        raise ValueError("Keys of `mapping` must correspond to the nodes of `other`.")

    # Cast to set for improved performance in membership test
    mapping_values_set = set(mapping.values())
    o1_set = set(self.__output_nodes)
    i2_set = set(other.input_nodes)

    if len(mapping) != len(mapping_values_set):
        raise ValueError("Values of `mapping` contain duplicates.")

    # `-` binds tighter than `&`: values must avoid the measured (non-output) nodes of `self`.
    if mapping_values_set & nodes_p1 - o1_set:
        raise ValueError("Values of `mapping` must not contain measured nodes of pattern `self`.")

    for k, v in mapping.items():
        if v in o1_set and k not in i2_set:
            raise ValueError(
                f"Mapping {k} -> {v} is not valid. {v} is an output of pattern `self` but {k} is not an input of pattern `other`."
            )

    # Check if resulting pattern will have C commands before E commands
    if any(cmd.kind == CommandKind.C for cmd in self.__seq) and any(cmd.kind == CommandKind.E for cmd in other):
        warnings.warn(
            r"Pattern `self` contains Clifford commands and pattern `other` contains E commands. Standardization might not be possible for the resulting composed pattern.",
            stacklevel=2,
        )

    # Fresh labels for the unmapped nodes of `other` start right after every label
    # already used by `self` or requested by `mapping`. Passing one iterable to `max`
    # (instead of unpacking the labels as separate arguments) avoids a TypeError when
    # exactly one label is involved, and `default=-1` covers the case of no label at all.
    shift = max(itertools.chain(nodes_p1, mapping.values()), default=-1) + 1
    mapping_sequential = {
        node: i for i, node in enumerate(sorted(nodes_p2 - mapping.keys()), start=shift)
    }  # assigns new labels to nodes in other not specified in mapping

    mapping_complete = {**mapping, **mapping_sequential}

    mapped_inputs = [mapping_complete[n] for n in other.input_nodes]
    mapped_outputs = [mapping_complete[n] for n in other.output_nodes]
    mapped_results: dict[int, Outcome] = {mapping_complete[n]: m for n, m in other.results.items()}

    merged = mapping_values_set.intersection(self.__output_nodes)

    inputs = self.__input_nodes + [n for n in mapped_inputs if n not in merged]

    if preserve_mapping and not (len(merged) == len(other.input_nodes) == len(other.output_nodes)):
        warnings.warn(
            "`preserve_mapping = True` ignored because the number of merged nodes, inputs, and outputs of pattern `other` are different.",
            stacklevel=2,
        )
        preserve_mapping = False

    if preserve_mapping:
        # Here every input of `other` is mapped onto a distinct output of `self`.
        io_mapping = {
            mapping[i]: mapping_complete[o] for i, o in zip(other.input_nodes, other.output_nodes, strict=True)
        }
        outputs = [io_mapping[n] if n in merged else n for n in self.__output_nodes]
    else:
        outputs = [n for n in self.__output_nodes if n not in merged] + mapped_outputs

    def update_command(cmd: Command) -> Command:
        """Return a copy of ``cmd`` with every node label relabelled through ``mapping_complete``."""
        # Shallow copy is enough since the mutable attributes of cmd_new susceptible to change are reassigned
        cmd_new = copy.copy(cmd)
        if cmd_new.kind is CommandKind.E:
            i, j = cmd_new.nodes
            cmd_new.nodes = (mapping_complete[i], mapping_complete[j])
        elif cmd_new.kind is not CommandKind.T:
            cmd_new.node = mapping_complete[cmd_new.node]
            if cmd_new.kind is CommandKind.M:
                cmd_new.s_domain = {mapping_complete[i] for i in cmd_new.s_domain}
                cmd_new.t_domain = {mapping_complete[i] for i in cmd_new.t_domain}
            # Use of `==` here for mypy
            elif cmd_new.kind == CommandKind.X or cmd_new.kind == CommandKind.Z or cmd_new.kind == CommandKind.S:  # noqa: PLR1714
                cmd_new.domain = {mapping_complete[i] for i in cmd_new.domain}
        return cmd_new

    seq = self.__seq + [update_command(c) for c in other]

    results: dict[int, Outcome] = {**self.results, **mapped_results}

    p = Pattern(input_nodes=inputs, output_nodes=outputs, cmds=seq)
    p.results = results

    return p, mapping_complete

@property
def input_nodes(self) -> list[int]:
    """List input nodes."""
    return list(self.__input_nodes)  # copy for preventing modification

@property
def output_nodes(self) -> list[int]:
    """List all nodes that are either `input_nodes` or prepared with `N` commands and that have not been measured with an `M` command."""
    return list(self.__output_nodes)  # copy for preventing modification

def __len__(self) -> int:
    """Return the length of command sequence."""
    return len(self.__seq)

def __iter__(self) -> Iterator[Command]:
    """Iterate over commands."""
    return iter(self.__seq)

def __getitem__(self, index: int) -> Command:
    """Get the command at a given index."""
    return self.__seq[index]

@property
def n_node(self) -> int:
    """Count of nodes that are either `input_nodes` or prepared with `N` commands."""
    return self.__n_node
def reorder_output_nodes(self, output_nodes: Iterable[int]) -> None:
    """Set a new ordering for the output nodes.

    Parameters
    ----------
    output_nodes: iterable of int
        Output nodes in the order chosen by the user; position ``i``
        corresponds to logical qubit ``i``. The new order is validated with
        ``assert_permutation`` against the current outputs before it is stored.
    """
    new_order = [*output_nodes]  # own copy; also accepts one-shot iterators
    assert_permutation(self.__output_nodes, new_order)
    self.__output_nodes = new_order
def reorder_input_nodes(self, input_nodes: Iterable[int]) -> None:
    """Set a new ordering for the input nodes.

    Parameters
    ----------
    input_nodes: iterable of int
        Input nodes in the order chosen by the user; position ``i``
        corresponds to logical qubit ``i``. The new order is validated with
        ``assert_permutation`` against the current inputs before it is stored.
    """
    new_order = [*input_nodes]  # own copy; also accepts one-shot iterators
    assert_permutation(self.__input_nodes, new_order)
    self.__input_nodes = new_order
def __repr__(self) -> str:
    """Return a constructor-like representation; empty fields are omitted."""
    fields = (
        ("input_nodes", self.__input_nodes),
        ("cmds", self.__seq),
        ("output_nodes", self.__output_nodes),
    )
    rendered = ", ".join(f"{name}={value}" for name, value in fields if value)
    return f"Pattern({rendered})"

def __str__(self) -> str:
    """Return a human-readable string of the pattern (its ASCII rendering)."""
    return self.to_ascii()

def __eq__(self, other: object) -> bool:
    """Return ``True`` if both patterns share commands, inputs, outputs and results."""
    if not isinstance(other, Pattern):
        return NotImplemented
    # Compared in this order; the first mismatch short-circuits.
    return (
        self.__seq == other.__seq
        and self.__input_nodes == other.__input_nodes
        and self.__output_nodes == other.__output_nodes
        and self.results == other.results
    )
def to_ascii(
    self, left_to_right: bool = False, limit: int = 40, target: Container[command.CommandKind] | None = None
) -> str:
    """Return the ASCII string representation of the pattern.

    All arguments are forwarded to :func:`graphix.pretty_print.pattern_to_str`;
    ``limit`` bounds the number of commands shown and ``target`` restricts the
    output to the given command kinds.
    """
    return pattern_to_str(
        self,
        OutputFormat.ASCII,
        left_to_right=left_to_right,
        limit=limit,
        target=target,
    )
def to_latex(
    self, left_to_right: bool = False, limit: int = 40, target: Container[command.CommandKind] | None = None
) -> str:
    """Return a string containing the LaTeX representation of the pattern.

    All arguments are forwarded to :func:`graphix.pretty_print.pattern_to_str`;
    ``limit`` bounds the number of commands shown and ``target`` restricts the
    output to the given command kinds.
    """
    return pattern_to_str(
        self,
        OutputFormat.LaTeX,
        left_to_right=left_to_right,
        limit=limit,
        target=target,
    )
def to_unicode(
    self, left_to_right: bool = False, limit: int = 40, target: Container[command.CommandKind] | None = None
) -> str:
    """Return the Unicode string representation of the pattern.

    All arguments are forwarded to :func:`graphix.pretty_print.pattern_to_str`;
    ``limit`` bounds the number of commands shown and ``target`` restricts the
    output to the given command kinds.
    """
    return pattern_to_str(
        self,
        OutputFormat.Unicode,
        left_to_right=left_to_right,
        limit=limit,
        target=target,
    )
def print_pattern(self, lim: int = 40, target: Container[CommandKind] | None = None) -> None:
    """Print the pattern sequence (Pattern.seq).

    This method is deprecated.
    See :meth:`to_ascii`, :meth:`to_latex`, :meth:`to_unicode` and :func:`graphix.pretty_print.pattern_to_str`.

    Parameters
    ----------
    lim: int, optional
        maximum number of commands to show
    target : Container[CommandKind], optional
        show only specified commands, e.g. [CommandKind.M, CommandKind.X, CommandKind.Z]
    """
    warnings.warn(
        "Method `print_pattern` is deprecated. Use one of the methods `to_ascii`, `to_latex`, `to_unicode`, or the function `graphix.pretty_print.pattern_to_str`.",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the caller of `print_pattern`; with
        # stacklevel=1 it pointed at this library line, where the default warning
        # filters hide DeprecationWarning, so users never saw it.
        stacklevel=2,
    )
    print(pattern_to_str(self, OutputFormat.ASCII, left_to_right=True, limit=lim, target=target))
def standardize(self) -> None:
    """Reorder the command sequence into standard form, in place.

    A 'standard' pattern lists 'N' commands first, then 'E', then 'M', then the
    byproduct commands ('X' and 'Z') and finally the Clifford commands ('C').
    The reordering itself is delegated to :func:`graphix.optimization.standardize`;
    only the resulting command sequence is adopted.
    """
    standardized = optimization.standardize(self)
    self.__seq = standardized.__seq
def is_standard(self, strict: bool = False) -> bool:
    """Determine whether the command sequence is standard.

    The sequence is standard when it reads: 'N' commands, then 'E', then 'M',
    followed by the byproduct ('X', 'Z') and Clifford ('C') commands. Any other
    ordering or command kind (e.g. 'S') makes it non-standard. An empty
    pattern is standard.

    Parameters
    ----------
    strict : bool, optional
        If True, ensures that C commands are the last ones (all 'X'/'Z' before
        all 'C'); otherwise 'X', 'Z' and 'C' may be interleaved.

    Returns
    -------
    is_standard : bool
        True if the pattern is standard
    """
    phases: list[set[CommandKind]] = [{CommandKind.N}, {CommandKind.E}, {CommandKind.M}]
    if strict:
        phases += [{CommandKind.X, CommandKind.Z}, {CommandKind.C}]
    else:
        phases.append({CommandKind.X, CommandKind.Z, CommandKind.C})
    commands = iter(self)
    try:
        kind = next(commands).kind
        # Consume each phase in turn; running out of commands means every one fit.
        for allowed in phases:
            while kind in allowed:
                kind = next(commands).kind
    except StopIteration:
        return True
    # A command remained that fits none of the remaining phases.
    return False
def shift_signals(self, method: str = "direct") -> dict[int, set[int]]:
    """Perform signal shifting procedure.

    Extract the t-dependence of the measurement into 'S' commands and commute
    them to the end of the command sequence where they can be removed. This
    procedure simplifies the dependence structure of the pattern.

    Ref for the original 'mc' method:
    V. Danos, E. Kashefi and P. Panangaden. J. ACM 54.2 8 (2007)

    Parameters
    ----------
    method : str, optional
        'direct' shift_signals is executed on a conventional Pattern sequence.
        'mc' shift_signals is done using the original algorithm on the measurement calculus paper.

    Returns
    -------
    signal_dict : dict[int, set[int]]
        For each node, the signals that have been shifted.

    Raises
    ------
    ValueError
        If ``method`` is neither ``"direct"`` nor ``"mc"`` (raised after the
        runnability check).
    """
    # Shifting signals could turn non-runnable patterns into runnable ones
    # (e.g. {1}[M(0)] N(0) would become M(0) N(0)), so runnability is checked
    # first to avoid hiding code-logic errors.
    self.check_runnability()
    if method == "direct":
        return self.shift_signals_direct()
    if method != "mc":
        raise ValueError("Invalid method")

    signal_dict = self.extract_signals()
    # How the S command at `target` is commuted past the command that follows it.
    commuters = {
        CommandKind.X: self._commute_xs,
        CommandKind.Z: self._commute_zs,
        CommandKind.M: self._commute_ms,
        CommandKind.S: self._commute_ss,
    }
    target = self._find_op_to_be_moved(CommandKind.S, rev=True)
    while target is not None:
        if target == len(self.__seq) - 1:
            # The S command reached the end of the sequence: drop it and take the next one.
            self.__seq.pop(target)
            target = self._find_op_to_be_moved(CommandKind.S, rev=True)
            continue
        following_kind = self.__seq[target + 1].kind
        commuters.get(following_kind, self._commute_with_following)(target)
        target += 1
    return signal_dict
def shift_signals_direct(self) -> dict[int, set[int]]: """Perform signal shifting procedure.""" signal_dict: dict[int, set[int]] = {} def expand_domain(domain: set[command.Node]) -> None: """Expand ``domain`` with previously shifted signals. Parameters ---------- domain : set[int] Set of nodes representing the current domain. This set is modified in place by XORing any previously shifted domains. """ for node in domain & signal_dict.keys(): domain ^= signal_dict[node] for i, cmd in enumerate(self): if cmd.kind == CommandKind.M: s_domain = set(cmd.s_domain) t_domain = set(cmd.t_domain) expand_domain(s_domain) expand_domain(t_domain) plane = cmd.plane if plane == Plane.XY: # M^{XY,α} X^s Z^t = M^{XY,(-1)^s·α+tπ} # = S^t M^{XY,(-1)^s·α} # = S^t M^{XY,α} X^s if t_domain: signal_dict[cmd.node] = t_domain t_domain = set() elif plane == Plane.XZ: # M^{XZ,α} X^s Z^t = M^{XZ,(-1)^t((-1)^s·α+sπ)} # = M^{XZ,(-1)^{s+t}·α+(-1)^t·sπ} # = M^{XZ,(-1)^{s+t}·α+sπ} (since (-1)^t·π ≡ π (mod 2π)) # = S^s M^{XZ,(-1)^{s+t}·α} # = S^s M^{XZ,α} Z^{s+t} if s_domain: signal_dict[cmd.node] = s_domain t_domain ^= s_domain s_domain = set() elif plane == Plane.YZ and s_domain: # M^{YZ,α} X^s Z^t = M^{YZ,(-1)^t·α+sπ)} # = S^s M^{YZ,(-1)^t·α} # = S^s M^{YZ,α} Z^t signal_dict[cmd.node] = s_domain s_domain = set() if s_domain != cmd.s_domain or t_domain != cmd.t_domain: self.__seq[i] = dataclasses.replace(cmd, s_domain=s_domain, t_domain=t_domain) # Use of `==` here for mypy elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 domain = set(cmd.domain) expand_domain(domain) if domain != cmd.domain: self.__seq[i] = dataclasses.replace(cmd, domain=domain) return signal_dict def _find_op_to_be_moved(self, op: CommandKind, rev: bool = False, skipnum: int = 0) -> int | None: """Find a command. 
Parameters ---------- op : CommandKind, N, E, M, X, Z, S command types to be searched rev : bool search from the end (true) or start (false) of seq skipnum : int skip the detected command by specified times """ if not rev: # Search from the start start_index, end_index, step = 0, len(self.__seq), 1 else: # Search from the end start_index, end_index, step = len(self.__seq) - 1, -1, -1 num_ops = 0 for index in range(start_index, end_index, step): if self.__seq[index].kind == op: num_ops += 1 if num_ops == skipnum + 1: return index # If no target found return None def _commute_ex(self, target: int) -> bool: """Perform the commutation of E and X. Parameters ---------- target : int target command index. this must point to a X command followed by E command """ x = self.__seq[target] e = self.__seq[target + 1] assert x.kind == CommandKind.X assert e.kind == CommandKind.E if e.nodes[0] == x.node: z = command.Z(node=e.nodes[1], domain=x.domain) self.__seq.pop(target + 1) # del E self.__seq.insert(target, z) # add Z in front of X self.__seq.insert(target, e) # add E in front of Z return True if e.nodes[1] == x.node: z = command.Z(node=e.nodes[0], domain=x.domain) self.__seq.pop(target + 1) # del E self.__seq.insert(target, z) # add Z in front of X self.__seq.insert(target, e) # add E in front of Z return True self._commute_with_following(target) return False def _commute_mx(self, target: int) -> bool: """Perform the commutation of M and X. Parameters ---------- target : int target command index. this must point to a X command followed by M command """ x = self.__seq[target] m = self.__seq[target + 1] assert x.kind == CommandKind.X assert m.kind == CommandKind.M if x.node == m.node: m.s_domain ^= x.domain self.__seq.pop(target) # del X return True self._commute_with_following(target) return False def _commute_mz(self, target: int) -> bool: """Perform the commutation of M and Z. Parameters ---------- target : int target command index. 
this must point to a Z command followed by M command """ z = self.__seq[target] m = self.__seq[target + 1] assert z.kind == CommandKind.Z assert m.kind == CommandKind.M if z.node == m.node: m.t_domain ^= z.domain self.__seq.pop(target) # del Z return True self._commute_with_following(target) return False def _commute_xs(self, target: int) -> None: """Perform the commutation of X and S. Parameters ---------- target : int target command index. this must point to a S command followed by X command """ s = self.__seq[target] x = self.__seq[target + 1] assert s.kind == CommandKind.S assert x.kind == CommandKind.X if s.node in x.domain: x.domain ^= s.domain self._commute_with_following(target) def _commute_zs(self, target: int) -> None: """Perform the commutation of Z and S. Parameters ---------- target : int target command index. this must point to a S command followed by Z command """ s = self.__seq[target] z = self.__seq[target + 1] assert s.kind == CommandKind.S assert z.kind == CommandKind.Z if s.node in z.domain: z.domain ^= s.domain self._commute_with_following(target) def _commute_ms(self, target: int) -> None: """Perform the commutation of M and S. Parameters ---------- target : int target command index. this must point to a S command followed by M command """ s = self.__seq[target] m = self.__seq[target + 1] assert s.kind == CommandKind.S assert m.kind == CommandKind.M if s.node in m.s_domain: m.s_domain ^= s.domain if s.node in m.t_domain: m.t_domain ^= s.domain self._commute_with_following(target) def _commute_ss(self, target: int) -> None: """Perform the commutation of two S commands. Parameters ---------- target : int target command index. 
this must point to a S command followed by S command """ s1 = self.__seq[target] s2 = self.__seq[target + 1] assert s1.kind == CommandKind.S assert s2.kind == CommandKind.S if s1.node in s2.domain: s2.domain ^= s1.domain self._commute_with_following(target) def _commute_with_following(self, target: int) -> None: """Perform the commutation of two consecutive commands that commutes. commutes the target command with the following command. Parameters ---------- target : int target command index """ a = self.__seq[target + 1] self.__seq.pop(target + 1) self.__seq.insert(target, a) def _commute_with_preceding(self, target: int) -> None: """Perform the commutation of two consecutive commands that commutes. commutes the target command with the preceding command. Parameters ---------- target : int target command index """ a = self.__seq[target - 1] self.__seq.pop(target - 1) self.__seq.insert(target, a) def _move_n_to_left(self) -> None: """Move all 'N' commands to the start of the sequence. N can be moved to the start of sequence without the need of considering commutation relations. 
""" new_seq = [] n_list = [] for cmd in self.__seq: if cmd.kind == CommandKind.N: n_list.append(cmd) else: new_seq.append(cmd) n_list.sort(key=lambda n_cmd: n_cmd.node) self.__seq = n_list + new_seq def _move_byproduct_to_right(self) -> None: """Move the byproduct commands to the end of sequence, using the commutation relations implemented in graphix.Pattern class.""" # First, we move all X commands to the end of sequence index = len(self.__seq) - 1 x_limit = len(self.__seq) - 1 while index > 0: if self.__seq[index].kind == CommandKind.X: index_x = index while index_x < x_limit: cmd = self.__seq[index_x + 1] kind = cmd.kind if kind == CommandKind.E: move = self._commute_ex(index_x) if move: x_limit += 1 # addition of extra Z means target must be increased index_x += 1 elif kind == CommandKind.M: search = self._commute_mx(index_x) if search: x_limit -= 1 # XM commutation rule removes X command break else: self._commute_with_following(index_x) index_x += 1 else: x_limit -= 1 index -= 1 # then, move Z to the end of sequence in front of X index = x_limit z_limit = x_limit while index > 0: if self.__seq[index].kind == CommandKind.Z: index_z = index while index_z < z_limit: cmd = self.__seq[index_z + 1] if cmd.kind == CommandKind.M: search = self._commute_mz(index_z) if search: z_limit -= 1 # ZM commutation rule removes Z command break else: self._commute_with_following(index_z) index_z += 1 index -= 1 def _move_e_after_n(self) -> None: """Move all E commands to the start of sequence, before all N commands. 
assumes that _move_n_to_left() method was called.""" moved_e = 0 target = self._find_op_to_be_moved(CommandKind.E, skipnum=moved_e) while target is not None: if (target == 0) or ( self.__seq[target - 1].kind == CommandKind.N or self.__seq[target - 1].kind == CommandKind.E ): moved_e += 1 target = self._find_op_to_be_moved(CommandKind.E, skipnum=moved_e) continue self._commute_with_preceding(target) target -= 1 def extract_signals(self) -> dict[int, set[int]]: """Extract 't' domain of measurement commands, turn them into signal 'S' commands and add to the command sequence. This is used for shift_signals() method. """ signal_dict = {} pos = 0 while pos < len(self.__seq): cmd = self.__seq[pos] if cmd.kind == CommandKind.M: extracted_signal = extract_signal(cmd.plane, cmd.s_domain, cmd.t_domain) if extracted_signal.signal: self.__seq.insert(pos + 1, command.S(node=cmd.node, domain=extracted_signal.signal)) cmd.s_domain = extracted_signal.s_domain cmd.t_domain = extracted_signal.t_domain pos += 1 signal_dict[cmd.node] = extracted_signal.signal pos += 1 return signal_dict def _extract_dependency(self) -> dict[int, set[int]]: """Get dependency (byproduct correction & dependent measurement) structure of nodes in the graph (resource) state, according to the pattern. This is used to determine the optimum measurement order. Returns ------- dependency : dict of set index is node number. all nodes in the each set must be measured before measuring """ nodes = self.extract_nodes() dependency: dict[int, set[int]] = {i: set() for i in nodes} for cmd in self.__seq: if cmd.kind == CommandKind.M: dependency[cmd.node] |= cmd.s_domain | cmd.t_domain # Use of `==` here for mypy elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z: # noqa: PLR1714 dependency[cmd.node] |= cmd.domain return dependency @staticmethod def update_dependency(measured: AbstractSet[int], dependency: dict[int, set[int]]) -> None: """Remove measured nodes from the 'dependency'. 
Parameters ---------- measured: set of int measured nodes. dependency: dict of set which is produced by `_extract_dependency` Returns ------- dependency: dict of set updated dependency information """ for i in dependency: dependency[i] -= measured def extract_partial_order_layers(self) -> tuple[frozenset[int], ...]: """Extract the measurement order of the pattern in the form of layers. This method standardizes the pattern, builds a directed acyclical graph (DAG) from measurement and correction domains, and then performs a topological sort. Returns ------- tuple[frozenset[int], ...] Measurement partial order between the pattern's nodes in a layer form. Raises ------ RunnabilityError If the pattern is not runnable. Notes ----- - This function wraps :func:`optimization.StandardizedPattern.extract_partial_order_layers`, and the returned object is described in the notes of this method. - See :func:`optimization.StandardizedPattern.extract_causal_flow` for additional information on why it is required to standardized the pattern to extract the partial order layering. """ return optimization.StandardizedPattern.from_pattern(self).extract_partial_order_layers() def extract_causal_flow(self) -> CausalFlow[Measurement]: """Extract the causal flow structure from the current measurement pattern. This method does not call the flow-extraction routine on the underlying open graph, but constructs the flow from the pattern corrections instead. Returns ------- CausalFlow[Measurement] The causal flow associated with the current pattern. Raises ------ FlowError If the pattern: - Contains measurements in forbidden planes (XZ or YZ), - Is empty, or - Induces a correction function and a partial order which fail the well-formedness checks for a valid causal flow. ValueError If `N` commands in the pattern do not represent a |+⟩ state or if the pattern corrections form closed loops. 
Notes ----- - See :func:`optimization.StandardizedPattern.extract_causal_flow` for additional information on why it is required to standardized the pattern to extract a causal flow. - Applying the chain ``Pattern.extract_causal_flow().to_corrections().to_pattern()`` to a strongly deterministic pattern returns a new pattern implementing the same unitary transformation. This equivalence holds as long as the original pattern contains no Clifford commands, since those are discarded during open-graph extraction. """ return optimization.StandardizedPattern.from_pattern(self).extract_causal_flow() def extract_gflow(self) -> GFlow[Measurement]: """Extract the generalized flow (gflow) structure from the current measurement pattern. This method does not call the flow-extraction routine on the underlying open graph, but constructs the gflow from the pattern corrections instead. Returns ------- GFlow[Measurement] The gflow associated with the current pattern. Raises ------ FlowError If the pattern is empty or if the extracted structure does not satisfy the well-formedness conditions required for a valid gflow. ValueError If `N` commands in the pattern do not represent a |+⟩ state or if the pattern corrections form closed loops. Notes ----- The notes provided in :func:`self.extract_causal_flow` apply here as well. """ return optimization.StandardizedPattern.from_pattern(self).extract_gflow() def extract_xzcorrections(self) -> XZCorrections[Measurement]: """Extract the XZ-corrections from the current measurement pattern. Returns ------- XZCorrections[Measurement] The XZ-corrections associated with the current pattern. Raises ------ XZCorrectionsError If the extracted correction dictionaries are not well formed. ValueError If `N` commands in the pattern do not represent a |+⟩ state or if the pattern corrections form closed loops. 
Notes ----- To ensure that applying the chain ``Pattern.extract_xzcorrections().to_pattern()`` to a strongly deterministic pattern returns a new pattern implementing the same unitary transformation, XZ-corrections must be extracted from a standardized pattern. This requirement arises for the same reason that flow extraction also operates correctly on standardized patterns only. This equivalence holds as long as the original pattern contains no Clifford commands, since those are discarded during open-graph extraction. See docstring in :func:`optimization.StandardizedPattern.extract_gflow` for additional information. """ return optimization.StandardizedPattern.from_pattern(self).extract_xzcorrections() def _measurement_order_depth(self) -> list[int]: """Obtain a measurement order which reduces the depth of a pattern. Returns ------- list[int] optimal measurement order for parallel computing """ partial_order_layers = self.extract_partial_order_layers() return list(itertools.chain(*reversed(partial_order_layers[1:]))) @staticmethod def connected_edges(node: int, edges: set[tuple[int, int]]) -> set[tuple[int, int]]: """Search not activated edges connected to the specified node. Returns ------- connected: set of tuple set of connected edges """ connected = set() for edge in edges: if edge[0] == node or edge[1] == node: connected |= {edge} return connected def _measurement_order_space(self) -> list[int]: """Determine measurement order that heuristically optimises the max_space of a pattern. 
Returns ------- meas_order: list of int sub-optimal measurement order for classical simulation """ graph = self.extract_graph() nodes = set(graph.nodes) edges = set(graph.edges) not_measured = nodes - set(self.output_nodes) dependency = self._extract_dependency() self.update_dependency(self.results.keys(), dependency) meas_order = [] removable_edges = set() while not_measured: min_edges = len(nodes) + 1 next_node = -1 for i in not_measured: if not dependency[i]: connected_edges = self.connected_edges(i, edges) if min_edges > len(connected_edges): min_edges = len(connected_edges) next_node = i removable_edges = connected_edges if not (next_node > -1): print(next_node) assert next_node > -1 meas_order.append(next_node) self.update_dependency({next_node}, dependency) not_measured -= {next_node} edges -= removable_edges return meas_order def sort_measurement_commands(self, meas_order: list[int]) -> list[command.M]: """Convert measurement order to sequence of measurement commands. Parameters ---------- meas_order: list of int optimal measurement order. Returns ------- meas_cmds: list of command sorted measurement commands """ meas_dict = self.extract_measurement_commands() return [meas_dict[i] for i in meas_order] def extract_measurement_commands(self) -> dict[int, command.M]: """Return a dictionary mapping nodes to measurement commands. Returns ------- meas_dict : dict[int, command.M] measurement commands indexed by nodes """ return {cmd.node: cmd for cmd in self if cmd.kind == CommandKind.M}
def compute_max_degree(self) -> int:
    """Return the largest vertex degree of the pattern's graph state.

    Returns
    -------
    max_degree : int
        max degree of a pattern, or 0 when the graph has no node
    """
    degree_view = self.extract_graph().degree()
    # Narrows the view's type for static checkers.
    assert isinstance(degree_view, nx.classes.reportviews.DiDegreeView)
    degrees = [deg for _, deg in degree_view]
    return int(max(degrees)) if degrees else 0
def extract_graph(self) -> nx.Graph[int]:
    """Build the graph state described by the 'N' and 'E' commands.

    Input nodes and every prepared node become vertices. Each 'E' command
    toggles the edge between its two nodes, so an even number of 'E' commands
    on the same pair leaves no edge.

    Returns
    -------
    graph_state: nx.Graph[int]
    """
    graph: nx.Graph[int] = nx.Graph()
    graph.add_nodes_from(self.input_nodes)
    for cmd in self.__seq:
        if cmd.kind == CommandKind.E:
            u, v = cmd.nodes
            toggle = graph.remove_edge if graph.has_edge(u, v) else graph.add_edge
            toggle(u, v)
        elif cmd.kind == CommandKind.N:
            graph.add_node(cmd.node)
    return graph
def extract_nodes(self) -> set[int]:
    """Return the set of nodes of the pattern: its inputs plus every node prepared by an 'N' command."""
    prepared = {cmd.node for cmd in self.__seq if cmd.kind == CommandKind.N}
    return set(self.input_nodes) | prepared
def extract_isolated_nodes(self) -> set[int]:
    """Return the nodes without any incident edge in the graph of the pattern.

    The graph is the one returned by :meth:`extract_graph`.

    Returns
    -------
    isolated_nodes : set[int]
        set of the isolated nodes
    """
    degree_view = self.extract_graph().degree
    return {vertex for vertex, deg in degree_view if deg == 0}

def extract_opengraph(self) -> OpenGraph[Measurement]:
    """Extract the underlying resource-state open graph from the pattern.

    Nodes are the input nodes plus the nodes prepared by ``N`` commands. Edges come from
    ``E`` commands: an edge entangled an even number of times is absent. Every ``M``
    command contributes the measurement of its node.

    Returns
    -------
    OpenGraph[Measurement]

    Raises
    ------
    ValueError
        If `N` commands in the pattern do not represent a |+⟩ state.

    Notes
    -----
    This operation loses all the information on the Clifford commands.
    """
    vertices = set(self.input_nodes)
    edge_set: set[tuple[int, int]] = set()
    measurements: dict[int, Measurement] = {}
    for cmd in self.__seq:
        if cmd.kind == CommandKind.N:
            if cmd.state != BasicStates.PLUS:
                raise ValueError(
                    f"Open graph extraction requires N commands to represent a |+⟩ state. Error found in {cmd}."
                )
            vertices.add(cmd.node)
        elif cmd.kind == CommandKind.E:
            # Normalize the pair so that (u, v) and (v, u) denote the same edge, then toggle it.
            edge_set ^= {(min(cmd.nodes), max(cmd.nodes))}
        elif cmd.kind == CommandKind.M:
            measurements[cmd.node] = Measurement(cmd.angle, cmd.plane)
    graph = nx.Graph(edge_set)
    graph.add_nodes_from(vertices)
    # Inputs and outputs are cast to `tuple` to replicate the behavior of
    # :func:`graphix.optimization.StandardizedPattern.extract_opengraph`.
    return OpenGraph(graph, tuple(self.__input_nodes), tuple(self.__output_nodes), measurements)
def extract_clifford(self) -> dict[int, Clifford]:
    """Return the Clifford commands of the pattern, indexed by node.

    Returns
    -------
    vops : dict[int, Clifford]
        for each node targeted by a ``C`` command, its Clifford. If a node is targeted by
        several ``C`` commands, the last one in the sequence wins.
    """
    vops: dict[int, Clifford] = {}
    for cmd in self.__seq:
        if cmd.kind == CommandKind.C:
            vops[cmd.node] = cmd.clifford
    return vops
def connected_nodes(self, node: int, prepared: set[int] | None = None) -> list[int]:
    """Find nodes that are connected to a specified node.

    These nodes must be in the statevector when the specified node is measured, to ensure correct computation.
    If connected nodes already exist in the statevector (prepared),
    then they will be ignored as they do not need to be prepared again.

    The pattern is standardized in place first if it is not already in standard form.

    Parameters
    ----------
    node : int
        node index
    prepared : set[int] | None
        node indices to be ignored (treated as an empty set when None)

    Returns
    -------
    node_list : list[int]
        nodes that are entangled with the specified node, in the order in which their
        ``E`` commands appear in the sequence
    """
    if not self.is_standard():
        self.standardize()
    if prepared is None:
        prepared = set()
    node_list: list[int] = []
    # NOTE(review): presumably the index of the first E command, or None when there is none.
    ind = self._find_op_to_be_moved(CommandKind.E)
    if ind is None:  # no E command -> 'node' is isolated
        return node_list
    # Scan the run of E commands starting at `ind`. Iterating (instead of indexing
    # `ind + 1` unconditionally) avoids an IndexError when that run reaches the end of
    # the sequence, e.g. in a standard pattern without M, X, Z or C commands.
    for cmd in itertools.islice(self.__seq, ind, None):
        if cmd.kind != CommandKind.E:
            break
        u, v = cmd.nodes
        if u == node:
            if v not in prepared:
                node_list.append(v)
        elif v == node and u not in prepared:
            node_list.append(u)
    return node_list
def correction_commands(self) -> list[command.X | command.Z]:
    """Return the byproduct correction commands (``X`` and ``Z``) of the pattern, in sequence order.

    The pattern must already be in standard form; this is checked with an ``assert``.
    """
    assert self.is_standard()
    corrections: list[command.X | command.Z] = []
    for cmd in self.__seq:
        # Use of `==` here for mypy
        if cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z:  # noqa: PLR1714
            corrections.append(cmd)
    return corrections
def parallelize_pattern(self) -> None:
    """Optimize the pattern to reduce the depth of the computation by gathering measurement commands that can be performed simultaneously.

    This optimized pattern runs efficiently on GPUs and quantum hardwares with
    depth (e.g. coherence time) limitations.

    The pattern is standardized first if needed. The measurement commands are then
    reordered following :meth:`_measurement_order_depth`.
    """
    if not self.is_standard():
        self.standardize()
    depth_order = self._measurement_order_depth()
    ordered_measurements = self.sort_measurement_commands(depth_order)
    self._reorder_pattern(ordered_measurements)
def minimize_space(self) -> None:
    """Optimize the pattern to minimize the max_space property of the pattern.

    The optimized pattern has significantly reduced space requirement
    (memory space for classical simulation, and maximum simultaneously prepared qubits for quantum hardwares).

    When a causal flow can be extracted, the measurement order is read from its partial
    order layers (skipping layer 0 and going from the last layer to the first). Otherwise
    the greedy order of :meth:`_measurement_order_space` is used.
    """
    if not self.is_standard():
        self.standardize()
    order: list[int] | None
    try:
        flow = self.extract_causal_flow()
    except FlowError:
        order = None
    else:
        order = list(itertools.chain.from_iterable(reversed(flow.partial_order_layers[1:])))
    if order is None:
        order = self._measurement_order_space()
    self._reorder_pattern(self.sort_measurement_commands(order))
def _reorder_pattern(self, meas_commands: list[command.M]) -> None:
    """Replace the command sequence with one whose measurements follow `meas_commands`.

    The standardized form of the pattern gets `meas_commands` as its measurement list and is
    converted back with ``to_space_optimal_pattern``. Only the command sequence of ``self``
    is replaced.

    Parameters
    ----------
    meas_commands : list of command
        list of measurement ('M') commands
    """
    standardized = optimization.StandardizedPattern.from_pattern(self)
    reordered = dataclasses.replace(standardized, m_list=tuple(meas_commands))
    self.__seq = reordered.to_space_optimal_pattern().__seq
def max_space(self) -> int:
    """Compute the maximum number of nodes that must be present in the graph (graph space) during the execution of the pattern.

    For statevector simulation, this is equivalent to the maximum memory
    needed for classical simulation.

    Input nodes are present from the start. Each ``N`` command adds one node and each
    ``M`` command removes one.

    Returns
    -------
    n_nodes : int
        max number of nodes present in the graph during pattern execution.
    """
    steps = (1 if cmd.kind == CommandKind.N else -1 if cmd.kind == CommandKind.M else 0 for cmd in self.__seq)
    # `initial` makes the starting population (the inputs) part of the running values.
    return max(itertools.accumulate(steps, initial=len(self.input_nodes)))
def space_list(self) -> list[int]:
    """Return the list of the number of nodes present in the graph (space) during each step of execution of the pattern (for N and M commands).

    Input nodes are present from the start, so the count starts at
    ``len(self.input_nodes)``, consistently with :meth:`max_space`. Starting from 0
    would under-count by the number of inputs, and could even give negative values.

    Returns
    -------
    N_list : list
        time evolution of 'space' at each 'N' and 'M' commands of pattern.
    """
    nodes = len(self.input_nodes)
    n_list: list[int] = []
    for cmd in self.__seq:
        if cmd.kind == CommandKind.N:
            nodes += 1
            n_list.append(nodes)
        elif cmd.kind == CommandKind.M:
            nodes -= 1
            n_list.append(nodes)
    return n_list

@overload
def simulate_pattern(
    self,
    backend: StatevectorBackend | Literal["statevector"] = "statevector",
    input_state: State
    | Statevec
    | Iterable[State]
    | Iterable[ExpressionOrSupportsComplex]
    | Iterable[Iterable[ExpressionOrSupportsComplex]] = ...,
    rng: Generator | None = ...,
    **kwargs: Any,
) -> Statevec: ...

@overload
def simulate_pattern(
    self,
    backend: DensityMatrixBackend | Literal["densitymatrix"],
    input_state: State
    | DensityMatrix
    | Iterable[State]
    | Iterable[ExpressionOrSupportsComplex]
    | Iterable[Iterable[ExpressionOrSupportsComplex]] = ...,
    rng: Generator | None = ...,
    **kwargs: Any,
) -> DensityMatrix: ...

@overload
def simulate_pattern(
    self,
    backend: TensorNetworkBackend | Literal["tensornetwork", "mps"],
    input_state: State
    | Iterable[State]
    | Iterable[ExpressionOrSupportsComplex]
    | Iterable[Iterable[ExpressionOrSupportsComplex]] = ...,
    rng: Generator | None = ...,
    **kwargs: Any,
) -> MBQCTensorNet: ...

@overload
def simulate_pattern(
    self,
    backend: Backend[_StateT_co],
    input_state: Data = ...,
    rng: Generator | None = ...,
    **kwargs: Any,
) -> _StateT_co: ...
def simulate_pattern(
    self,
    backend: Backend[_StateT_co] | _BackendLiteral = "statevector",
    input_state: Data = BasicStates.PLUS,
    rng: Generator | None = None,
    **kwargs: Any,
) -> _StateT_co | _BuiltinBackendState:
    """Simulate the execution of the pattern by using :class:`graphix.simulator.PatternSimulator`.

    Available backend: ['statevector', 'densitymatrix', 'tensornetwork']

    Parameters
    ----------
    backend : str
        optional parameter to select simulator backend.
    input_state : Data
        initial state of the input nodes, forwarded to ``PatternSimulator.run``
        (``BasicStates.PLUS`` by default).
    rng: Generator, optional
        Random-number generator for measurements.
        This generator is used only in case of random branch selection
        (see :class:`RandomBranchSelector`).
    kwargs: keyword args for specified backend.

    Returns
    -------
    state :
        quantum state representation for the selected backend.

    .. seealso:: :class:`graphix.simulator.PatternSimulator`
    """
    simulator: PatternSimulator[_StateT_co] = PatternSimulator(self, backend=backend, **kwargs)
    simulator.run(input_state, rng=rng)
    final_state = simulator.backend.state
    return final_state
def remove_input_nodes(self) -> None:
    """Remove the input nodes from the pattern and replace them with N commands.

    One ``N`` command per former input node is inserted at the very beginning of the
    command sequence, and the list of input nodes becomes empty. This removes the
    possibility of choosing the input state, fixing the input state to the plus state.

    .. seealso:: :class:`graphix.command.N`
    """
    preparations = [command.N(node=node) for node in self.input_nodes]
    self.__seq[:0] = preparations
    no_inputs: list[int] = []
    self.__input_nodes = no_inputs
def perform_pauli_measurements(self, ignore_pauli_with_deps: bool = False) -> None:
    """Perform Pauli measurements in the pattern using efficient stabilizer simulator.

    The pattern is updated in place: its attributes are replaced by those of the pattern
    returned by :func:`measure_pauli`.

    Parameters
    ----------
    ignore_pauli_with_deps : bool
        Optional (*False* by default).
        If *True*, Pauli measurements with domains depending on other measures are preserved as-is in the pattern.
        If *False*, all Pauli measurements are preprocessed. Formally, measurements are swapped so that all Pauli measurements are applied first, and domains are updated accordingly.

    Raises
    ------
    ValueError
        If the pattern still has input nodes.

    .. seealso:: :func:`measure_pauli`
    """
    if self.input_nodes:
        raise ValueError("Remove inputs with `self.remove_input_nodes()` before performing Pauli presimulation.")
    presimulated = measure_pauli(self, ignore_pauli_with_deps=ignore_pauli_with_deps)
    self.__dict__.update(presimulated.__dict__)
def draw_graph(
    self,
    flow_from_pattern: bool = True,
    show_pauli_measurement: bool = True,
    show_local_clifford: bool = False,
    show_measurement_planes: bool = False,
    show_loop: bool = True,
    node_distance: tuple[float, float] = (1, 1),
    figsize: tuple[int, int] | None = None,
    filename: Path | None = None,
) -> None:
    """Visualize the underlying graph of the pattern with flow or gflow structure.

    Parameters
    ----------
    flow_from_pattern : bool
        If True, the command sequence of the pattern is used to derive flow or gflow structure. If False, only the underlying graph is used.
    show_pauli_measurement : bool
        If True, the nodes with Pauli measurement angles are colored light blue.
    show_local_clifford : bool
        If True, indexes of the local Clifford operator are displayed adjacent to the nodes.
    show_measurement_planes : bool
        If True, measurement planes are displayed adjacent to the nodes.
    show_loop : bool
        whether or not to show loops for graphs with gflow. defaulted to True.
    node_distance : tuple
        Distance multiplication factor between nodes for x and y directions.
    figsize : tuple
        Figure size of the plot.
    filename : Path | None
        If not None, filename of the png file to save the plot. If None, the plot is not saved.
        Defaults to None.
    """
    meas_cmds = self.extract_measurement_commands()
    visualizer = GraphVisualizer(
        self.extract_graph(),
        self.input_nodes,
        self.output_nodes,
        {node: meas.plane for node, meas in meas_cmds.items()},
        {node: meas.angle for node, meas in meas_cmds.items()},
        self.extract_clifford(),
    )
    if not flow_from_pattern:
        visualizer.visualize(
            show_pauli_measurement=show_pauli_measurement,
            show_local_clifford=show_local_clifford,
            show_measurement_planes=show_measurement_planes,
            show_loop=show_loop,
            node_distance=node_distance,
            figsize=figsize,
            filename=filename,
        )
        return
    # The visualizer receives a copy, so the pattern itself is left untouched.
    visualizer.visualize_from_pattern(
        pattern=self.copy(),
        show_pauli_measurement=show_pauli_measurement,
        show_local_clifford=show_local_clifford,
        show_measurement_planes=show_measurement_planes,
        show_loop=show_loop,
        node_distance=node_distance,
        figsize=figsize,
        filename=filename,
    )
def to_qasm3(self, filename: Path | str, input_state: dict[int, State] | State = BasicStates.PLUS) -> None:
    """Export measurement pattern to OpenQASM 3.0 file.

    See :func:`graphix.qasm3_exporter.pattern_to_qasm3`.

    Parameters
    ----------
    filename : Path | str
        File name to export to. Any existing suffix is replaced by ``.qasm``.
        Example: ``"filename.qasm"``.
    input_state : dict[int, State] | State, default BasicStates.PLUS
        The initial state for each input node. Only ``|0⟩`` or ``|+⟩`` states are supported.
    """
    target = Path(filename).with_suffix(".qasm")
    with target.open("w", encoding="utf-8") as qasm_file:
        qasm_file.writelines(pattern_to_qasm3_lines(self, input_state=input_state))
def is_parameterized(self) -> bool:
    """
    Return `True` if there is at least one measurement angle that is not just an instance of `SupportsFloat`.

    A parameterized pattern is a pattern where at least one measurement angle is an expression
    that is not a number, typically an instance of `sympy.Expr` (but we don't force to choose
    `sympy` here).

    Only the angles of ``M`` commands are inspected.
    """
    return any(not isinstance(cmd.angle, SupportsFloat) for cmd in self if cmd.kind == command.CommandKind.M)

def subs(self, variable: Parameter, substitute: ExpressionOrSupportsFloat) -> Pattern:
    """Return a copy of the pattern where all occurrences of the given variable in measurement angles are substituted by the given value.

    The substitution reassigns the ``angle`` attribute of the commands of ``self.copy()``.
    Those commands are distinct objects from the original ones, so ``self`` is unchanged.
    """
    result = self.copy()
    for cmd in result:
        if cmd.kind == command.CommandKind.M:
            cmd.angle = parameter.subs(cmd.angle, variable, substitute)
    return result

def xreplace(self, assignment: Mapping[Parameter, ExpressionOrSupportsFloat]) -> Pattern:
    """Return a copy of the pattern where all occurrences of the given keys in measurement angles are substituted by the given values in parallel.

    As with :meth:`subs`, only the ``angle`` attribute of the copied commands is reassigned,
    so ``self`` is unchanged.
    """
    result = self.copy()
    for cmd in result:
        if cmd.kind == command.CommandKind.M:
            cmd.angle = parameter.xreplace(cmd.angle, assignment)
    return result

def copy(self) -> Pattern:
    """Return a copy of the pattern.

    The command list, the input and output node lists and the ``results`` dictionary are new
    containers. Each command is copied with ``copy.copy`` (a shallow copy): reassigning an
    attribute of a copied command does not affect the original, but mutable attributes of
    commands (e.g. domain sets) remain shared between both patterns.
    """
    # `__new__` bypasses `__init__`; every attribute is then set explicitly below.
    result = self.__new__(self.__class__)
    result.__seq = [copy.copy(cmd) for cmd in self.__seq]
    result.__input_nodes = self.__input_nodes.copy()
    result.__output_nodes = self.__output_nodes.copy()
    result.__n_node = self.__n_node
    result.results = self.results.copy()
    return result

def check_runnability(self) -> None:
    """Check whether the pattern is runnable.

    Raises `RunnabilityError` exception if it is not.

    Commands are replayed in order. The method tracks the nodes that are currently active
    (initially the input nodes) and the nodes already measured (initially the keys of
    ``self.results``). The first violation found is reported.

    Notes
    -----
    The runnability check can only guarantee the runnability of MBQC+LC patterns.
    Patterns that make use of custom `BaseN` and `BaseM` commands can have additional
    runnability constraints that are not checked by this method.
    For instance, in the Veriphix implementation of VBQC, blind measurements have hidden domains that cannot be checked.
    """
    active = set(self.input_nodes)
    measured = set(self.results)

    def check_active(cmd: Command, node: int) -> None:
        # The measured test comes first: a measured node is no longer active either,
        # and reporting AlreadyMeasured is the more precise diagnosis.
        if node in measured:
            raise RunnabilityError(cmd, node, RunnabilityErrorReason.AlreadyMeasured)
        if node not in active:
            raise RunnabilityError(cmd, node, RunnabilityErrorReason.NotYetActive)

    def check_measured(cmd: Command, node: int) -> None:
        # Nodes appearing in a domain must have been measured earlier.
        if node not in measured:
            raise RunnabilityError(cmd, node, RunnabilityErrorReason.NotYetMeasured)

    for cmd in self:
        if cmd.kind == CommandKind.N:
            # Preparing an already active node (including an input node) or a measured node is invalid.
            if cmd.node in active:
                raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.AlreadyActive)
            if cmd.node in measured:
                raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.AlreadyMeasured)
            active.add(cmd.node)
        elif cmd.kind == CommandKind.E:
            n0, n1 = cmd.nodes
            check_active(cmd, n0)
            check_active(cmd, n1)
        elif cmd.kind == CommandKind.M:
            check_active(cmd, cmd.node)
            if isinstance(cmd, command.M):
                # `cmd.s_domain` and `cmd.t_domain` are only
                # defined if the command is an actual `M` command,
                # which may not be the case if the method is
                # called with a pattern constructed with another
                # implementation of `BaseM` (for instance, a blind
                # pattern from Veriphix).
                for domain in cmd.s_domain, cmd.t_domain:
                    if cmd.node in domain:
                        raise RunnabilityError(cmd, cmd.node, RunnabilityErrorReason.DomainSelfLoop)
                    for node in domain:
                        check_measured(cmd, node)
            # The node leaves the active set; any later command on it raises AlreadyMeasured.
            active.remove(cmd.node)
            measured.add(cmd.node)
        # Use of `==` here for mypy
        elif cmd.kind == CommandKind.X or cmd.kind == CommandKind.Z:  # noqa: PLR1714
            check_active(cmd, cmd.node)
            for node in cmd.domain:
                check_measured(cmd, node)
        elif cmd.kind == CommandKind.C:
            check_active(cmd, cmd.node)
class RunnabilityErrorReason(Enum):
    """Describe the reason for a pattern not being runnable."""

    AlreadyActive = enum.auto()
    """A node is prepared whereas it has already been prepared or it is an input node."""

    AlreadyMeasured = enum.auto()
    """A node is prepared, entangled, measured, corrected or targeted by a Clifford command whereas it has already been measured."""

    NotYetActive = enum.auto()
    """A node is entangled, measured, corrected or targeted by a Clifford command whereas it has not been prepared yet and it is not an input node."""

    NotYetMeasured = enum.auto()
    """A node appears in the domain of a measurement or of a correction whereas it has not been measured yet."""

    DomainSelfLoop = enum.auto()
    """A node appears in the domain of its own measurement.

    This is a particular case of `NotYetMeasured`, introduced to make the error message clearer."""


@dataclass
class RunnabilityError(Exception):
    """Error raised by :meth:`Pattern.check_runnability`.

    Attributes
    ----------
    cmd : Command
        the command for which the check failed.
    node : int
        the node on which the check failed.
    reason : RunnabilityErrorReason
        why the pattern is not runnable.
    """

    cmd: Command
    node: int
    reason: RunnabilityErrorReason

    def __post_init__(self) -> None:
        """Forward the fields to ``Exception.__init__``.

        The dataclass-generated ``__init__`` does not call ``Exception.__init__``. Without
        this hook, ``self.args`` would stay empty. Exceptions are re-created from
        ``self.args`` when unpickled, so unpickling (e.g. when the error crosses a process
        boundary) would fail with a missing-argument ``TypeError``.
        """
        super().__init__(self.cmd, self.node, self.reason)

    def __str__(self) -> str:
        """Explain the error."""
        if self.reason == RunnabilityErrorReason.AlreadyActive:
            return f"{self.cmd}: node {self.node} is already active."
        if self.reason == RunnabilityErrorReason.AlreadyMeasured:
            return f"{self.cmd}: node {self.node} is already measured."
        if self.reason == RunnabilityErrorReason.NotYetActive:
            return f"{self.cmd}: node {self.node} is not yet active."
        if self.reason == RunnabilityErrorReason.NotYetMeasured:
            return f"{self.cmd}: node {self.node} is not yet measured."
        if self.reason == RunnabilityErrorReason.DomainSelfLoop:
            return f"{self.cmd}: node {self.node} appears in the domain of its own measurement command."
        assert_never(self.reason)
def measure_pauli(pattern: Pattern, *, ignore_pauli_with_deps: bool = False) -> Pattern:
    """Perform Pauli measurement of a pattern by fast graph state simulator.

    Uses the decorated-graph method implemented in graphix.graphsim to perform the measurements in Pauli bases, and then sort remaining nodes back into
    pattern together with Clifford commands.

    Users are required to ensure there are no input nodes with :meth:`graphix.pattern.Pattern.remove_input_nodes` before using this function.

    TODO: non-XY plane measurements in original pattern

    Parameters
    ----------
    pattern : graphix.pattern.Pattern object
        The pattern to presimulate. Its ``results`` dictionary is copied, not updated in place.
    ignore_pauli_with_deps : bool
        Optional (*False* by default).
        If *True*, Pauli measurements with domains depending on other measures are preserved as-is in the pattern.
        If *False*, all Pauli measurements are preprocessed. Formally, measurements are swapped so that all Pauli measurements are applied first, and domains are updated accordingly.

    Returns
    -------
    new_pattern : graphix.Pattern object
        pattern with Pauli measurement removed. Its ``results`` hold the outcomes of
        ``pattern.results``, the outcomes of the Pauli measurements performed here, and
        outcome 0 for isolated non-output nodes with non-Pauli measurements, which are removed.
        If no Pauli measurement can be performed, ``pattern`` itself is returned.

    .. seealso:: :meth:`graphix.pattern.Pattern.remove_input_nodes`
    .. seealso:: :class:`graphix.graphsim.GraphState`
    """
    pat = Pattern()
    standardized_pattern = optimization.StandardizedPattern.from_pattern(pattern)
    if not ignore_pauli_with_deps:
        standardized_pattern = standardized_pattern.perform_pauli_pushing()
    output_nodes = set(pattern.output_nodes)
    graph = standardized_pattern.extract_graph()
    graph_state = GraphState(nodes=graph.nodes, edges=graph.edges, vops=standardized_pattern.c_dict)
    # Work on a copy. Writing into `pattern.results` would make the input pattern report
    # outcomes for nodes that its own command sequence still prepares and measures, and
    # would alias the results of both patterns.
    results: dict[int, Outcome] = dict(pattern.results)
    to_measure, non_pauli_meas = pauli_nodes(standardized_pattern)
    if not to_measure:
        return pattern
    for pattern_cmd, measurement_basis in to_measure:
        # extract signals for adaptive angle.
        s_signal = 0
        t_signal = 0
        if measurement_basis.axis == Axis.X:
            # X measurement is not affected by s_signal
            t_signal = sum(results[j] for j in pattern_cmd.t_domain)
        elif measurement_basis.axis == Axis.Y:
            s_signal = sum(results[j] for j in pattern_cmd.s_domain)
            t_signal = sum(results[j] for j in pattern_cmd.t_domain)
        elif measurement_basis.axis == Axis.Z:
            # Z measurement is not affected by t_signal
            s_signal = sum(results[j] for j in pattern_cmd.s_domain)
        else:
            assert_never(measurement_basis.axis)
        if s_signal % 2 == 1:  # equivalent to X byproduct
            graph_state.h(pattern_cmd.node)
            graph_state.z(pattern_cmd.node)
            graph_state.h(pattern_cmd.node)
        if t_signal % 2 == 1:  # equivalent to Z byproduct
            graph_state.z(pattern_cmd.node)
        if measurement_basis.axis == Axis.X:
            measure = graph_state.measure_x
        elif measurement_basis.axis == Axis.Y:
            measure = graph_state.measure_y
        elif measurement_basis.axis == Axis.Z:
            measure = graph_state.measure_z
        else:
            assert_never(measurement_basis.axis)
        if measurement_basis.sign == Sign.PLUS:
            results[pattern_cmd.node] = measure(pattern_cmd.node, choice=0)
        else:
            # A minus-sign basis flips the recorded outcome.
            results[pattern_cmd.node] = 0 if measure(pattern_cmd.node, choice=1) else 1

    # measure (remove) isolated nodes. if they aren't Pauli measurements,
    # measuring one of the results with probability of 1 should not occur as was possible above for Pauli measurements,
    # which means we can just choose s=0. We should not remove output nodes even if isolated.
    isolates = graph_state.isolated_nodes()
    for node in non_pauli_meas:
        if (node in isolates) and (node not in output_nodes):
            graph_state.remove_node(node)
            results[node] = 0

    # update command sequence
    vops = graph_state.extract_vops()
    new_seq: list[Command] = []
    new_seq.extend(command.N(node=index) for index in set(graph_state.nodes))
    new_seq.extend(command.E(nodes=edge) for edge in graph_state.edges)
    new_seq.extend(
        cmd.clifford(Clifford(vops[cmd.node])) for cmd in standardized_pattern.m_list if cmd.node in graph_state.nodes
    )
    new_seq.extend(
        command.C(node=index, clifford=Clifford(vops[index])) for index in pattern.output_nodes if vops[index] != Clifford.I
    )
    new_seq.extend(command.Z(node=node, domain=set(domain)) for node, domain in standardized_pattern.z_dict.items())
    new_seq.extend(command.X(node=node, domain=set(domain)) for node, domain in standardized_pattern.x_dict.items())
    pat.replace(new_seq, input_nodes=[])
    pat.reorder_output_nodes(standardized_pattern.output_nodes)
    assert pat.n_node == len(graph_state.nodes)
    pat.results = results
    return pat
def pauli_nodes(pattern: optimization.StandardizedPattern) -> tuple[list[tuple[command.M, PauliMeasurement]], set[int]]:
    """Return the list of measurement commands that are in Pauli bases and that are not dependent on any non-Pauli measurements.

    Measurement commands are scanned in the order of ``pattern.m_list``. For a Pauli
    measurement, only the domain(s) relevant to its axis are examined: ``t_domain`` for X,
    both domains for Y, and ``s_domain`` for Z. If such a domain meets the nodes already
    classified as non-Pauli, the measurement is classified as non-Pauli too.

    Parameters
    ----------
    pattern : optimization.StandardizedPattern

    Returns
    -------
    pauli_node : list[tuple[command.M, PauliMeasurement]]
        measurement commands that can be performed as Pauli measurements, each paired with
        its Pauli basis, in ``m_list`` order.
    non_pauli_nodes : set[int]
        nodes with a non-Pauli measurement, or with a Pauli measurement that depends on one.
    """
    pauli_node: list[tuple[command.M, PauliMeasurement]] = []
    # Nodes that are non-Pauli measured, or Pauli measured but depending on a non-Pauli measurement
    non_pauli_node: set[int] = set()
    for cmd in pattern.m_list:
        pm = PauliMeasurement.try_from(cmd.plane, cmd.angle)  # None returned if the measurement is not in Pauli basis
        if pm is not None:  # Pauli measurement to be removed
            if pm.axis == Axis.X:
                if cmd.t_domain & non_pauli_node:  # cmd depend on non-Pauli measurement
                    non_pauli_node.add(cmd.node)
                else:
                    pauli_node.append((cmd, pm))
            elif pm.axis == Axis.Y:
                if (cmd.s_domain | cmd.t_domain) & non_pauli_node:  # cmd depend on non-Pauli measurement
                    non_pauli_node.add(cmd.node)
                else:
                    pauli_node.append((cmd, pm))
            elif pm.axis == Axis.Z:
                if cmd.s_domain & non_pauli_node:  # cmd depend on non-Pauli measurement
                    non_pauli_node.add(cmd.node)
                else:
                    pauli_node.append((cmd, pm))
            else:
                raise ValueError("Unknown Pauli measurement basis")
        else:
            non_pauli_node.add(cmd.node)
    return pauli_node, non_pauli_node


def assert_permutation(original: list[int], user: list[int]) -> None:
    """Check that the provided `user` node list is a permutation from `original`.

    Raises
    ------
    ValueError
        If `user` and `original` do not contain the same set of nodes, or if a node
        appears more than once in `user`.
    """
    node_set = set(user)
    if node_set != set(original):
        raise ValueError(f"{node_set} != {set(original)}")
    for node in user:
        if node in node_set:
            node_set.remove(node)
        else:
            raise ValueError(f"{node} appears twice")


@dataclass
class ExtractedSignal:
    """Return data structure for `extract_signal`."""

    s_domain: set[int]
    "New `s_domain` for the measure command."

    t_domain: set[int]
    "New `t_domain` for the measure command."

    signal: set[int]
    "Domain for the shift command."


def extract_signal(plane: Plane, s_domain: set[int], t_domain: set[int]) -> ExtractedSignal:
    """Split the domains of a measurement in `plane` into the domains kept on the measurement and the signal to shift.

    - XY: ``s_domain`` is kept and ``t_domain`` becomes the signal.
    - XZ: ``s_domain`` becomes the signal, and the new ``t_domain`` is ``s_domain ^ t_domain``.
    - YZ: ``t_domain`` is kept and ``s_domain`` becomes the signal.
    """
    if plane == Plane.XY:
        return ExtractedSignal(s_domain=s_domain, t_domain=set(), signal=t_domain)
    if plane == Plane.XZ:
        return ExtractedSignal(s_domain=set(), t_domain=s_domain ^ t_domain, signal=s_domain)
    if plane == Plane.YZ:
        return ExtractedSignal(s_domain=set(), t_domain=t_domain, signal=s_domain)
    assert_never(plane)


def shift_outcomes(outcomes: dict[int, Outcome], signal_dict: dict[int, set[int]]) -> dict[int, Outcome]:
    """Update outcomes with shifted signals.

    Shifted signals (as returned by the method
    :func:`Pattern.shift_signals`) affect classical outputs
    (measurements) while leaving the quantum state invariant.

    This method updates the given `outcomes` by swapping the
    measurements affected by signals. This can be used either to
    transform the value of :data:`Pattern.results` into
    measurements observed in the unshifted pattern, or vice versa.

    Parameters
    ----------
    outcomes : dict[int, int]
        Classical outputs. Every node appearing in a signal must have an outcome here.
    signal_dict : dict[int, set[int]]
        For each node, the signal that has been shifted (as returned by
        :func:`Pattern.shift_signals`). Nodes absent from this mapping are left unchanged.

    Returns
    -------
    shifted_outcomes : dict[int, int]
        Classical outputs updated with shifted signals: an outcome is toggled when the
        parity of the outcomes in its signal is odd.
    """
    return {
        node: toggle_outcome(outcome) if sum(outcomes[i] for i in signal_dict.get(node, [])) % 2 == 1 else outcome
        for node, outcome in outcomes.items()
    }