Skip to content

Ops

Fock

FockState

Bases: AbstractPureState

Fock state.

Source code in src/squint/ops/fock.py
class FockState(AbstractPureState):
    r"""
    Fock state, optionally a superposition of Fock basis states.

    $\ket{\psi} = \sum_{(a_i, \textbf{i}) \in n} a_i \ket{\textbf{i}}$, where each term
    pairs an amplitude $a_i$ with a tuple $\textbf{i}$ of per-wire photon numbers.

    Args:
        wires: Wires the state is defined on.
        n: One of
            - ``None`` (default): the vacuum $\ket{0, 0, \dots}$;
            - a sequence of ints: that single Fock basis state, with amplitude 1;
            - a sequence of ``(amplitude, photon_numbers)`` pairs: a superposition,
              rescaled so that $\sum_i |a_i|^2 = 1$.

    Raises:
        ValueError: If a superposition is given whose amplitudes are all zero.
    """

    n: Sequence[tuple[complex, Sequence[int]]]  # (amplitude, photon numbers per wire) terms

    @beartype
    def __init__(
        self,
        wires: Sequence[WiresTypes],
        n: Sequence[int] | Sequence[tuple[complex | float, Sequence[int]]] | None = None,
    ):
        super().__init__(wires=wires)
        if n is None:
            n = [(1.0, (0,) * len(wires))]  # initialize to |vac> = |0, 0, ...> state
        elif is_bearable(n, Sequence[int]):
            n = [(1.0, n)]
        elif is_bearable(n, Sequence[tuple[complex | float, Sequence[int]]]):
            # Normalize by the L2 norm sqrt(sum |a_i|^2).
            norm = jnp.sum(jnp.abs(jnp.array([amp for amp, _ in n])) ** 2)
            if norm == 0:
                raise ValueError("FockState amplitudes must not all be zero.")
            scale = jnp.sqrt(norm).item()
            n = [(amp / scale, basis) for amp, basis in n]
        self.n = paramax.non_trainable(n)

    def __call__(self, dim: int):
        """
        Return the state vector as a dense array of shape ``(dim,) * len(wires)``.

        The dtype is promoted from the amplitudes, so complex amplitudes produce a
        complex array. Scattering them into a real zeros buffer would silently drop
        the imaginary parts. Terms sharing the same basis index are summed.
        """
        amplitudes = [term[0] for term in self.n]
        state = jnp.zeros(
            shape=(dim,) * len(self.wires), dtype=jnp.result_type(*amplitudes)
        )
        for amp, basis in self.n:
            state = state.at[tuple(basis)].add(amp)
        return state

FixedEnergyFockState

Bases: AbstractPureState

Fixed energy Fock superposition.

Source code in src/squint/ops/fock.py
class FixedEnergyFockState(AbstractPureState):
    r"""
    Fixed energy Fock superposition.
    """

    weights: ArrayLike
    phases: ArrayLike
    n: int
    bases: Sequence[tuple[complex | float, Sequence[int]]]

    @beartype
    def __init__(
        self,
        wires: Sequence[WiresTypes],
        n: int = 1,
        weights: Optional[ArrayLike] = None,
        phases: Optional[ArrayLike] = None,
        key: Optional[jaxtyping.PRNGKeyArray] = None,
    ):
        super().__init__(wires=wires)

        def fixed_energy_states(length, energy):
            if length == 1:
                yield (energy,)
            else:
                for value in range(energy + 1):
                    for permutation in fixed_energy_states(length - 1, energy - value):
                        yield (value,) + permutation

        self.n = n
        self.bases = list(fixed_energy_states(len(wires), n))
        if not weights:
            weights = jnp.ones(shape=(len(self.bases),))
            # weights = jnp.linspace(1.0, 2.0, len(self.bases))
        if not phases:
            phases = jnp.zeros(shape=(len(self.bases),))
            # phases = jnp.linspace(1.0, 2.0, len(self.bases))

        if key is not None:
            subkeys = jr.split(key, 2)
            weights = jr.normal(subkeys[0], shape=weights.shape)
            phases = jr.normal(subkeys[1], shape=phases.shape)

        self.weights = weights
        self.phases = phases
        return

    def __call__(self, dim: int):
        return jnp.einsum(
            "i, i... -> ...",
            jnp.exp(1j * self.phases) * jnp.sqrt(jax.nn.softmax(self.weights)),
            jnp.array(
                [
                    jnp.zeros(shape=(dim,) * len(self.wires)).at[*basis].set(1.0)
                    for basis in self.bases
                ]
            ),
        )

TwoModeWeakThermalState

Bases: AbstractMixedState

Two-mode state of a weak thermal source.

Source code in src/squint/ops/fock.py
class TwoModeWeakThermalState(AbstractMixedState):
    r"""
    Two-mode state of a weak thermal source, truncated at a single photon.

    $\rho = (1 - \epsilon)\ket{00}\bra{00}
    + \frac{\epsilon}{2}\left(\ket{01}\bra{01} + \ket{10}\bra{10}
    + g e^{i\phi}\ket{01}\bra{10} + g e^{-i\phi}\ket{10}\bra{01}\right)$

    The tensor is returned with shape ``(dim, dim, dim, dim)``. The formula assumes
    the first two axes index the ket and the last two the bra.

    Args:
        wires: The two wires the state lives on (checked when evaluated).
        epsilon: Probability that a single photon is present across the two modes.
        g: Magnitude of the coherence between the two single-photon terms.
        phi: Phase of that coherence.

    Raises:
        ValueError: If evaluated with ``dim < 2``.
    """

    g: ArrayLike
    phi: ArrayLike
    epsilon: ArrayLike

    @beartype
    def __init__(
        self,
        wires: Sequence[WiresTypes],
        epsilon: float,
        g: float,
        phi: float,
    ):
        super().__init__(wires=wires)
        self.epsilon = jnp.array(epsilon)
        self.g = jnp.array(g)
        self.phi = jnp.array(phi)

    def __call__(self, dim: int):
        assert len(self.wires) == 2, "not correct wires"
        if dim < 2:
            # JAX silently drops out-of-bounds scatter updates, so dim < 2 would
            # return a wrong (vacuum-only) state instead of failing.
            raise ValueError(
                f"TwoModeWeakThermalState needs dim >= 2 to hold a photon; got dim={dim}."
            )
        rho = jnp.zeros(shape=(dim, dim, dim, dim), dtype=jnp.complex128)
        rho = rho.at[0, 0, 0, 0].set(1 - self.epsilon)
        rho = rho.at[0, 1, 0, 1].set(self.epsilon / 2)
        rho = rho.at[1, 0, 1, 0].set(self.epsilon / 2)
        rho = rho.at[0, 1, 1, 0].set(self.g * jnp.exp(1j * self.phi) * self.epsilon / 2)
        rho = rho.at[1, 0, 0, 1].set(
            self.g * jnp.exp(-1j * self.phi) * self.epsilon / 2
        )
        return rho

TwoModeSqueezingGate

Bases: AbstractGate

Two-mode squeezing gate.

Source code in src/squint/ops/fock.py
class TwoModeSqueezingGate(AbstractGate):
    r"""
    Two-mode squeezing gate.

    $S_2(r, \phi) = \exp\left[r\left(e^{-i\phi}\, a b - e^{i\phi}\, a^\dagger b^\dagger\right)\right]$,

    This is the standard two-mode squeezing operator $\exp(\xi^* a b - \xi a^\dagger b^\dagger)$
    with $\xi = r e^{i\phi}$, built on the photon-number cutoff `dim` and returned as a
    ``(dim, dim, dim, dim)`` tensor.

    The previous generator, ``1j * tanh(r) * (conj(phi) a†b† - phi ab)``, is Hermitian,
    so its exponential was not unitary for any nonzero parameters.

    Args:
        wires: The two wires (modes $a$ and $b$) the gate acts on.
        r: Squeezing magnitude.
        phi: Squeezing phase (assumed real).
    """

    r: ArrayLike
    phi: ArrayLike

    @beartype
    def __init__(self, wires: tuple[WiresTypes, WiresTypes], r, phi):
        super().__init__(wires=wires)
        self.r = jnp.asarray(r)
        self.phi = jnp.asarray(phi)

    def __call__(self, dim: int):
        pair_create = jnp.kron(create(dim), create(dim))  # a† b†
        pair_destroy = jnp.kron(destroy(dim), destroy(dim))  # a b
        # Anti-Hermitian generator (G† = -G, assuming create(dim) is the adjoint of
        # destroy(dim)), so expm(G) is unitary on the truncated space.
        generator = self.r * (
            jnp.exp(-1j * self.phi) * pair_destroy - jnp.exp(1j * self.phi) * pair_create
        )
        return jax.scipy.linalg.expm(generator).reshape(4 * (dim,))

BeamSplitter

Bases: AbstractGate

Beam splitter

Source code in src/squint/ops/fock.py
class BeamSplitter(AbstractGate):
    r"""
    Beam splitter acting on two bosonic modes.

    $U(r) = \exp\left[i r \left(a^\dagger b + a b^\dagger\right)\right]$, returned as a
    ``(dim, dim, dim, dim)`` tensor. The default $r = \pi/4$ gives a balanced
    (50:50) splitter.
    """

    r: ArrayLike

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes, WiresTypes],
        r: float | ArrayLike = jnp.pi / 4,
    ):
        super().__init__(wires=wires)
        self.r = jnp.array(r)

    def __call__(self, dim: int):
        # Hermitian hopping term a†b + ab†; the tensor layout works for both backends.
        hopping = jnp.kron(create(dim), destroy(dim)) + jnp.kron(destroy(dim), create(dim))
        unitary = jax.scipy.linalg.expm(1j * self.r * hopping)
        return unitary.reshape((dim, dim, dim, dim))

LinearOpticalUnitaryGate

Bases: AbstractGate

Source code in src/squint/ops/fock.py
class LinearOpticalUnitaryGate(AbstractGate):
    r"""
    Passive linear-optical network specified by its mode unitary.

    An $m \times m$ unitary acting on the optical modes is lifted to the Fock space
    truncated at `dim` levels per mode. Each element between two number bases with
    the same total photon number is a permanent of a submatrix $A_{ij}$ of the mode
    unitary, divided by $\sqrt{\prod_k i_k! \prod_k j_k!}$. Elements between different
    total photon numbers are zero.

    Only total photon numbers ``0 <= n < dim`` are filled. Basis states whose total
    exceeds ``dim - 1`` (possible when $m > 1$) get zero rows and columns.

    Args:
        wires: The $m$ wires (modes) the network acts on.
        unitary_modes: The $m \times m$ unitary acting on the modes.
    """

    unitary_modes: ArrayLike  # unitary which acts on the optical modes

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes, ...],
        unitary_modes: ArrayLike,
    ):
        super().__init__(wires=wires)
        assert unitary_modes.shape == (len(wires), len(wires)), (
            "Number of wires does not match mode unitary shape."
        )
        assert jnp.allclose(
            unitary_modes @ unitary_modes.T.conj(), jnp.eye(len(wires), len(wires))
        ), "`unitary_modes` arg is not unitary."
        self.unitary_modes = unitary_modes

    def _init_static_arrays(self, dim: int):
        """
        Precompute the index and normalization arrays used by ``__call__``.

        For each total photon number ``n`` in ``range(dim)`` (including ``n = 0``),
        this does three things:

        - enumerate all number bases on the ``m`` modes with that total;
        - form every pair of such bases;
        - compile the indices needed to build the A_ij matrices whose permanents
          give the transition amplitudes.

        It also builds the factorial normalization ``1 / sqrt(prod_k i_k! * prod_k j_k!)``
        over the full ``(dim,) * 2m`` index grid.

        Returns:
            ``(transition_inds, pairs, factorial_weight)``. The first two are lists
            indexed by ``n``.
        """
        m = len(self.wires)

        # idx[k, i_1, ..., i_m] = i_k: photon number of mode k at each grid point.
        # JAX arrays are immutable, so one index grid serves both bases.
        idx = jnp.indices((dim,) * m)
        inv_sqrt_fac = 1 / jnp.sqrt(
            jnp.prod(jax.scipy.special.factorial(idx), axis=0)
        ).reshape(dim**m)

        # Outer product over (i, j) bases, reshaped to (dim,) * 2m.
        factorial_weight = jnp.einsum("i,j->ij", inv_sqrt_fac, inv_sqrt_fac).reshape(
            (dim,) * m * 2
        )

        # Group bases by total photon number n. The interferometer is linear and
        # number-conserving, so only same-n pairs contribute, and A_ij is square
        # only when both bases share the same n.
        inds_n = [list(get_fixed_sum_tuples(m, n)) for n in range(dim)]

        def pairwise_combinations(A, B):
            # All (a, b) row pairs of A and B, shape (len(A) * len(B), 2, A.shape[1]).
            return jnp.stack(
                [
                    A[:, None, :].repeat(B.shape[0], axis=1),
                    B[None, :, :].repeat(A.shape[0], axis=0),
                ],
                axis=2,
            ).reshape(-1, 2, A.shape[1])

        pairs, transition_inds = [], []
        for n in range(dim):
            bases_n = jnp.array(inds_n[n])
            p = pairwise_combinations(bases_n, bases_n)
            # Flatten each pair into a 2m index vector: i-basis modes, then j-basis modes.
            pairs.append(p.reshape(p.shape[0], -1))
            transition_inds.append(compile_Aij_indices(inds_n[n], inds_n[n], m, n))

        return transition_inds, pairs, factorial_weight

    def __call__(self, dim: int):
        """Return the Fock-space tensor of shape ``(dim,) * 2m`` for this network."""
        transition_inds, pairs, factorial_weight = self._init_static_arrays(dim)

        def map_unitary(unitary_modes):
            # Use complex128, consistent with the rest of the module. The `jnp.complex_`
            # alias mirrors `np.complex_`, which NumPy 2.0 removed.
            unitary_number = jnp.zeros(
                (dim,) * 2 * len(self.wires), dtype=jnp.complex128
            )
            for n in range(dim):
                coefficients = compute_transition_amplitudes(
                    unitary_modes, transition_inds[n]
                )
                unitary_number = unitary_number.at[tuple(pairs[n].T)].set(
                    coefficients.flatten()
                )
            return unitary_number * factorial_weight

        return map_unitary(self.unitary_modes)

Phase

Bases: AbstractGate

Phase gate.

Source code in src/squint/ops/fock.py
class Phase(AbstractGate):
    r"""
    Single-mode phase-shift gate.

    Returns the diagonal matrix with entries $e^{i \phi k}$ for each $k$ in
    ``bases(dim)`` (presumably the photon numbers ``0 .. dim - 1``).
    """

    phi: ArrayLike

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes] = (0,),
        phi: float | int | ArrayLike = 0.0,
    ):
        super().__init__(wires=wires)
        self.phi = jnp.array(phi)

    def __call__(self, dim: int):
        diagonal = jnp.exp(1j * bases(dim) * self.phi)
        return jnp.diag(diagonal)

Discrete variable

DiscreteVariableState

Bases: AbstractPureState

A pure quantum state for a discrete variable system.

\(\ket{\psi} = \sum_{(a_i, \textbf{i}) \in n} a_i \ket{\textbf{i}}\)

Source code in src/squint/ops/dv.py
class DiscreteVariableState(AbstractPureState):
    r"""
    A pure quantum state for a discrete variable system.

    $\ket{\psi} = \sum_{(a_i, \textbf{i}) \in n} a_i \ket{\textbf{i}}$

    Args:
        wires: Wires the state is defined on.
        n: One of
            - ``None`` (default): the state $\ket{0, 0, \dots}$;
            - a sequence of ints: that single basis state, with amplitude 1;
            - a sequence of ``(amplitude, basis)`` pairs: a superposition, rescaled
              so that $\sum_i |a_i|^2 = 1$.

    Raises:
        ValueError: If a superposition is given whose amplitudes are all zero.
    """

    n: Sequence[tuple[complex, Sequence[int]]]  # (amplitude, basis index per wire) terms

    @beartype
    def __init__(
        self,
        wires: Sequence[WiresTypes],
        n: Sequence[int] | Sequence[tuple[complex | float, Sequence[int]]] | None = None,
    ):
        super().__init__(wires=wires)
        if n is None:
            n = [(1.0, (0,) * len(wires))]  # initialize to |0, 0, ...> state
        elif is_bearable(n, Sequence[int]):
            n = [(1.0, n)]
        elif is_bearable(n, Sequence[tuple[complex | float, Sequence[int]]]):
            # Normalize by the L2 norm sqrt(sum |a_i|^2). The previous sqrt(sum a_i)
            # was wrong for unequal, signed or complex amplitudes.
            norm = jnp.sum(jnp.abs(jnp.array([amp for amp, _ in n])) ** 2)
            if norm == 0:
                raise ValueError("DiscreteVariableState amplitudes must not all be zero.")
            scale = jnp.sqrt(norm).item()
            n = [(amp / scale, basis) for amp, basis in n]
        self.n = paramax.non_trainable(n)

    def __call__(self, dim: int):
        """
        Return the state vector as a dense array of shape ``(dim,) * len(wires)``.

        The dtype is promoted from the amplitudes, so complex amplitudes keep their
        imaginary parts. Terms sharing the same basis index are summed.
        """
        amplitudes = [term[0] for term in self.n]
        state = jnp.zeros(
            shape=(dim,) * len(self.wires), dtype=jnp.result_type(*amplitudes)
        )
        for amp, basis in self.n:
            state = state.at[tuple(basis)].add(amp)
        return state

MaximallyMixedState

Bases: AbstractMixedState

Source code in src/squint/ops/dv.py
class MaximallyMixedState(AbstractMixedState):
    r"""
    Maximally mixed state $\rho = I / d$ on `wires`, with $d = \mathrm{dim}^m$ for
    $m$ wires.

    It is returned as a complex128 tensor of shape ``(dim,) * 2m``.
    """

    @beartype
    def __init__(
        self,
        wires: Sequence[WiresTypes],
    ):
        super().__init__(wires=wires)

    def __call__(self, dim: int):
        n_wires = len(self.wires)
        total_dim = dim**n_wires
        rho = jnp.eye(total_dim, dtype=jnp.complex128) / total_dim
        return rho.reshape([dim] * n_wires * 2)

XGate

Bases: AbstractGate

The generalized shift operator, which when dim = 2 corresponds to the standard \(X\) gate.

\(U = \sum_{k=0}^{d-1} \ket{(k+1) \bmod d} \bra{k}\), i.e. \(U \ket{k} = \ket{(k+1) \bmod d}\)

Source code in src/squint/ops/dv.py
class XGate(AbstractGate):
    r"""
    The generalized shift operator, which when `dim = 2` corresponds to the standard $X$ gate.

    $U = \sum_{k=0}^{d-1} \ket{(k+1) \bmod d} \bra{k}$, i.e. $U \ket{k} = \ket{(k+1) \bmod d}$.
    """

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes] = (0,),
    ):
        super().__init__(wires=wires)

    def __call__(self, dim: int):
        # Rolling the identity's rows down by one puts the ones at [(k + 1) % dim, k].
        return jnp.roll(jnp.eye(dim), shift=1, axis=0)

ZGate

Bases: AbstractGate

The generalized phase operator, which when dim = 2 corresponds to the standard \(Z\) gate.

\(U = \sum_{k=0}^{d-1} e^{2 i \pi k / d} \ket{k}\bra{k}\)

Source code in src/squint/ops/dv.py
class ZGate(AbstractGate):
    r"""
    The generalized phase operator, which when `dim = 2` corresponds to the standard $Z$ gate.

    $U = \sum_{k=0}^{d-1} e^{2 i \pi k / d} \ket{k}\bra{k}$
    """

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes] = (0,),
    ):
        super().__init__(wires=wires)

    def __call__(self, dim: int):
        # Diagonal of d-th roots of unity, indexed by k = 0 .. dim - 1.
        return jnp.diag(jnp.exp(1j * 2 * jnp.pi * jnp.arange(dim) / dim))

HGate

Bases: AbstractGate

The generalized discrete Fourier operator, which when dim = 2 corresponds to the standard \(H\) gate.

\(U = \frac{1}{\sqrt{d}} \sum_{j,k=0}^{d-1} e^{2 i \pi j k / d} \ket{j}\bra{k}\)

Source code in src/squint/ops/dv.py
class HGate(AbstractGate):
    r"""
    The generalized discrete Fourier operator, which when `dim = 2` corresponds to the standard $H$ gate.

    $U = \frac{1}{\sqrt{d}} \sum_{j,k=0}^{d-1} e^{2 i \pi j k / d} \ket{j}\bra{k}$
    """

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes] = (0,),
    ):
        super().__init__(wires=wires)

    def __call__(self, dim: int):
        k = jnp.arange(dim)
        # jk exponent grid via outer product; divide by sqrt(d) for unitarity.
        return jnp.exp(1j * 2 * jnp.pi / dim * jnp.outer(k, k)) / jnp.sqrt(dim)

Conditional

Bases: AbstractGate

The generalized conditional (controlled) operator: the gate \(U\) is applied to the target wire \(k\) times, where \(k\) is the basis state of the control wire, \(C_U = \sum_{k=0}^{d-1} \ket{k} \bra{k} \otimes U^k\)

Source code in src/squint/ops/dv.py
class Conditional(AbstractGate):
    r"""
    The generalized conditional (controlled) operator.

    The gate $U$ (an `XGate` or `ZGate`) is applied to the second wire $k$ times,
    where $k$ is the basis state of the first (control) wire:
    $C_U = \sum_{k=0}^{d-1} \ket{k} \bra{k} \otimes U^k$.
    For `dim = 2` this is CNOT with `XGate` and CZ with `ZGate`.
    """

    gate: Union[XGate, ZGate]  # type: ignore

    @beartype
    def __init__(
        self,
        gate: Union[Type[XGate], Type[ZGate]],
        wires: tuple[WiresTypes, WiresTypes] = (0, 1),
    ):
        super().__init__(wires=wires)
        self.gate = gate(wires=(wires[1],))

    def __call__(self, dim: int):
        # "ac,bd -> abcd": projector |i><i| on the control axes (a, c) and U^i on
        # the target axes (b, d), summed over control levels i.
        u = sum(
            [
                jnp.einsum(
                    "ac,bd -> abcd",
                    jnp.zeros(shape=(dim, dim)).at[i, i].set(1.0),
                    jnp.linalg.matrix_power(self.gate(dim=dim), i),
                )
                for i in range(dim)
            ]
        )

        return u

RXGate

Bases: AbstractGate

The qubit \(R_X\) rotation gate.

Source code in src/squint/ops/dv.py
class RXGate(AbstractGate):
    r"""
    The qubit $R_X$ rotation gate, $R_X(\phi) = \cos(\phi/2) I - i \sin(\phi/2) X$.

    `phi` also accepts array-likes (matching `Phase`), so the angle can come from a
    JAX array.
    """

    phi: ArrayLike

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes] = (0,),
        phi: float | int | ArrayLike = 0.0,
    ):
        super().__init__(wires=wires)
        self.phi = jnp.array(phi)

    def __call__(self, dim: int):
        assert dim == 2, "RXGate only for dim=2"
        ops = basis_operators(dim=2)
        return (
            jnp.cos(self.phi / 2) * ops[3]  # identity
            - 1j * jnp.sin(self.phi / 2) * ops[2]  # X
        )

RYGate

Bases: AbstractGate

The qubit \(R_Y\) rotation gate.

Source code in src/squint/ops/dv.py
class RYGate(AbstractGate):
    r"""
    The qubit $R_Y$ rotation gate, $R_Y(\phi) = \cos(\phi/2) I - i \sin(\phi/2) Y$.

    `phi` also accepts array-likes (matching `Phase`), so the angle can come from a
    JAX array.
    """

    phi: ArrayLike

    @beartype
    def __init__(
        self,
        wires: tuple[WiresTypes] = (0,),
        phi: float | int | ArrayLike = 0.0,
    ):
        super().__init__(wires=wires)
        self.phi = jnp.array(phi)

    def __call__(self, dim: int):
        assert dim == 2, "RYGate only for dim=2"
        ops = basis_operators(dim=2)
        return (
            jnp.cos(self.phi / 2) * ops[3]  # identity
            - 1j * jnp.sin(self.phi / 2) * ops[1]  # Y
        )

Noise

ErasureChannel

Bases: AbstractErasureChannel

Erasure channel/photon loss.

Source code in src/squint/ops/noise.py
class ErasureChannel(AbstractErasureChannel):
    r"""
    Erasure channel/photon loss.

    Returns the product of one `dim x dim` identity per wire, as a tensor with
    axes ``(out_0, in_0, out_1, in_1, ...)``. How the tensor is applied is
    presumably defined by `AbstractErasureChannel`.
    """

    @beartype
    def __init__(self, wires: tuple[WiresTypes]):
        super().__init__(wires=wires)

    def __call__(self, dim: int):
        n_wires = len(self.wires)
        axis_pairs = [get_symbol(2 * k) + get_symbol(2 * k + 1) for k in range(n_wires)]
        inputs = ",".join(axis_pairs)
        output = "".join(axis_pairs)
        identities = [jnp.identity(dim)] * n_wires
        return jnp.einsum(f"{inputs} -> {output}", *identities)

BitFlipChannel

Bases: AbstractKrausChannel

Qubit bit flip channel.

Source code in src/squint/ops/noise.py
class BitFlipChannel(AbstractKrausChannel):
    r"""
    Qubit bit flip channel with flip probability `p`.

    Kraus operators: $K_0 = \sqrt{1 - p}\, I$, $K_1 = \sqrt{p}\, X$.
    """

    p: ArrayLike

    @beartype
    def __init__(self, wires: tuple[WiresTypes], p: float):
        super().__init__(wires=wires)
        self.p = jnp.array(p)

    def __call__(self, dim: int):
        assert dim == 2
        ops = basis_operators(dim=2)
        no_flip = jnp.sqrt(1 - self.p) * ops[3]  # identity
        flip = jnp.sqrt(self.p) * ops[2]  # X
        return jnp.array([no_flip, flip])

PhaseFlipChannel

Bases: AbstractKrausChannel

Qubit phase flip channel.

Source code in src/squint/ops/noise.py
class PhaseFlipChannel(AbstractKrausChannel):
    r"""
    Qubit phase flip channel with flip probability `p`.

    Kraus operators: $K_0 = \sqrt{1 - p}\, I$, $K_1 = \sqrt{p}\, Z$.
    """

    p: ArrayLike

    @beartype
    def __init__(self, wires: tuple[WiresTypes], p: float):
        super().__init__(wires=wires)
        self.p = jnp.array(p)

    def __call__(self, dim: int):
        assert dim == 2
        ops = basis_operators(dim=2)
        no_flip = jnp.sqrt(1 - self.p) * ops[3]  # identity
        flip = jnp.sqrt(self.p) * ops[0]  # Z
        return jnp.array([no_flip, flip])

DepolarizingChannel

Bases: AbstractKrausChannel

Qubit depolarizing channel.

Source code in src/squint/ops/noise.py
class DepolarizingChannel(AbstractKrausChannel):
    r"""
    Qubit depolarizing channel with parameter `p`.

    Kraus operators: $\sqrt{1 - 3p/4}\, I$ and $\sqrt{p/4}$ times each of $Z$, $Y$, $X$
    (in that order).
    """

    p: ArrayLike

    @beartype
    def __init__(self, wires: tuple[WiresTypes], p: float):
        super().__init__(wires=wires)
        self.p = jnp.array(p)

    def __call__(self, dim: int):
        assert dim == 2
        ops = basis_operators(dim=2)  # indices: 0 Z, 1 Y, 2 X, 3 identity
        pauli_weight = jnp.sqrt(self.p / 4)
        kraus = [jnp.sqrt(1 - 3 * self.p / 4) * ops[3]]
        kraus += [pauli_weight * ops[k] for k in (0, 1, 2)]  # Z, Y, X
        return jnp.array(kraus)

Blocks

brickwork(wires: Sequence[WiresTypes], depth: int, LocalGates: Union[Type[AbstractGate], Sequence[Type[AbstractGate]]], CouplingGate: Type[AbstractGate], periodic: bool = False) -> Block

Create a brickwork block with the specified local and coupling gates.

Parameters:

Name Type Description Default
wires Sequence[WiresTypes]

The wires to apply the gates to.

required
depth int

The depth of the brickwork block.

required
LocalGates Union[Type[AbstractGate], Sequence[Type[AbstractGate]]]

The local gates to apply to each wire.

required
CouplingGate Type[AbstractGate]

The coupling gate to apply to pairs of wires.

required
periodic bool

Whether to use periodic boundary conditions.

False

Returns: Block: A block containing the specified brickwork structure.

Source code in src/squint/blocks.py
@beartype
def brickwork(
    wires: Sequence[WiresTypes],
    depth: int,
    LocalGates: Union[Type[AbstractGate], Sequence[Type[AbstractGate]]],
    CouplingGate: Type[AbstractGate],
    periodic: bool = False,
) -> Block:
    """
    Create a brickwork block with the specified local and coupling gates.

    Each of the `depth` layers works in two steps. First, every local gate type is
    applied to every wire, in wire-major order. Then the coupling gate is applied to
    each pair in the two pair groups returned by `_chunk_pairs` (presumably the two
    offset neighbour pairings).

    Args:
        wires (Sequence[WiresTypes]): The wires to apply the gates to.
        depth (int): The depth of the brickwork block.
        LocalGates (Union[Type[AbstractGate], Sequence[Type[AbstractGate]]]): The local gates to apply to each wire.
        CouplingGate (Type[AbstractGate]): The coupling gate to apply to pairs of wires.
        periodic (bool): Whether to use periodic boundary conditions.
    Returns:
        Block: A block containing the specified brickwork structure.
    """
    block = Block()
    first_pairs, second_pairs = _chunk_pairs(tuple(wires), periodic=periodic)

    # Accept a single gate class as shorthand for a one-element sequence.
    if is_bearable(LocalGates, Sequence[Type[AbstractGate]]):
        local_gates = LocalGates
    else:
        local_gates = (LocalGates,)

    for _ in range(depth):
        for wire in wires:
            for Gate in local_gates:
                block.add(Gate(wires=(wire,)))
        for pair in (*first_pairs, *second_pairs):
            block.add(CouplingGate(wires=pair))

    return block

brickwork_type(wires: Sequence[WiresTypes], depth: int, ansatz: Literal['hea', 'rxx', 'rzz'], periodic: bool = False)

Create a brickwork block with the specified ansatz type. Ansatz can be one of 'hea', 'rxx', or 'rzz'. All three use RX and RY gates for the one-qubit gates and differ only in the two-qubit gate: - 'hea' uses CZ. - 'rxx' uses RXX. - 'rzz' uses RZZ.

Parameters:

Name Type Description Default
wires Sequence[WiresTypes]

The wires to apply the gates to.

required
depth int

The depth of the brickwork block.

required
ansatz Literal['hea', 'rxx', 'rzz']

The type of ansatz to use.

required
periodic bool

Whether to use periodic boundary conditions.

False

Returns:

Name Type Description
Block

A block containing the specified brickwork ansatz.

Source code in src/squint/blocks.py
@beartype
def brickwork_type(
    wires: Sequence[WiresTypes],
    depth: int,
    ansatz: Literal["hea", "rxx", "rzz"],
    periodic: bool = False,
):
    """
    Create a brickwork block with the specified ansatz type.
    Ansatz can be one of 'hea', 'rxx', or 'rzz'.
    - 'hea' uses RX and RY gates for one-qubit gates and CZ for two-qubit gates.
    - 'rxx' uses RX and RY gates for one-qubit gates and RXX for two-qubit gates.
    - 'rzz' uses RZ gates for one-qubit gates and RZZ for two-qubit gates.

    Args:
        wires (Sequence[WiresTypes]): The wires to apply the gates to.
        depth (int): The depth of the brickwork block.
        ansatz (Literal['hea', 'rxx', 'rzz']): The type of ansatz to use.
        periodic (bool): Whether to use periodic boundary conditions.

    Returns:
        Block: A block containing the specified brickwork ansatz.
    """
    match ansatz:
        case "hea":
            return brickwork(
                wires=wires,
                depth=depth,
                one_qubit_gates=[dv.RXGate, dv.RYGate],
                two_qubit_gates=dv.CZGate,
                periodic=periodic,
            )
        case "rxx":
            return brickwork(
                wires=wires,
                depth=depth,
                LocalGates=(dv.RXGate, dv.RYGate),
                CouplingGate=dv.RXXGate,
                periodic=periodic,
            )
        case "rzz":
            return brickwork(
                wires=wires,
                depth=depth,
                LocalGates=(dv.RXGate, dv.RYGate),
                CouplingGate=dv.RZZGate,
                periodic=periodic,
            )