Skip to content

Multi-qubit: GHZ phase estimation

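In this example, we define a circuit that prepares an $n$-qubit Greenberger–Horne–Zeilinger (GHZ) state. This state, ubiquitous in quantum information processing, is highly entangled and, in the context of quantum metrology, saturates the Heisenberg limit. We then imprint a shared phase $\varphi$ on every qubit, apply a final layer of Hadamard gates, and compare the classical and quantum Fisher information of the resulting measurement as $\varphi$ is swept.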

import itertools

import equinox as eqx
import jax
import jax.numpy as jnp
import seaborn as sns
import ultraplot as uplt
from rich.pretty import pprint

from squint.circuit import Circuit
from squint.ops.base import SharedGate
from squint.ops.dv import Conditional, DiscreteVariableState, HGate, RZGate, XGate
n = 3  # number of qubits in the GHZ register
circuit = Circuit(backend="pure")

# Start every wire in the computational-basis state |0>.
for wire in range(n):
    circuit.add(DiscreteVariableState(wires=(wire,), n=(0,)))

# GHZ preparation: a Hadamard on wire 0 followed by a ladder of controlled-X
# gates, each one controlled on a wire and targeting the next wire.
circuit.add(HGate(wires=(0,)))
for control, target in itertools.pairwise(range(n)):
    circuit.add(Conditional(gate=XGate, wires=(control, target)))

# Phase encoding: a single RZ gate on wire 0 whose angle (initially 0) is
# shared with copies on wires 1..n-1. It is registered under the key "phase"
# so that the parameter can be located and swept later.
shared_phase = SharedGate(
    op=RZGate(wires=(0,), phi=0.0 * jnp.pi),
    wires=tuple(range(1, n)),
)
circuit.add(shared_phase, "phase")

# Final layer of Hadamards, one per wire.
for wire in range(n):
    circuit.add(HGate(wires=(wire,)))

pprint(circuit)
Circuit(
  dims=None,
  ops={
│   0:
│   DiscreteVariableState(wires=(0,), n=[(1.0, (0,))]),
│   1:
│   DiscreteVariableState(wires=(1,), n=[(1.0, (0,))]),
│   2:
│   DiscreteVariableState(wires=(2,), n=[(1.0, (0,))]),
│   3:
│   HGate(wires=(0,)),
│   4:
│   Conditional(wires=(0, 1), gate=XGate(wires=(1,))),
│   5:
│   Conditional(wires=(1, 2), gate=XGate(wires=(2,))),
│   'phase':
│   SharedGate(
│     wires=(0, 1, 2),
│     op=RZGate(wires=(0,), phi=weak_f64[]),
│     copies=[RZGate(wires=(1,), phi=None), RZGate(wires=(2,), phi=None)],
│     where=<function <lambda>>,
│     get=<function <lambda>>
│   ),
│   7:
│   HGate(wires=(0,)),
│   8:
│   HGate(wires=(1,)),
│   9:
│   HGate(wires=(2,))
  },
  _backend='pure'
)
def phase_leaf(tree):
    """Return the shared RZ angle stored under the circuit's "phase" key."""
    return tree.ops["phase"].op.phi


# Separate the inexact-array leaves (the trainable parameters) from the static
# structure and compile a simulator.
# NOTE(review): the positional `2` is presumably the per-wire Hilbert-space
# dimension (qubits) — confirm against `Circuit.compile`.
params, static = eqx.partition(circuit, eqx.is_inexact_array)
sim = circuit.compile(static, 2, params, optimize="greedy")

# Swap the scalar phase for a sweep of 100 values on [-pi, pi]; its leading
# axis becomes the batch axis that jax.vmap maps over below.
phis = jnp.linspace(-jnp.pi, jnp.pi, 100)
params = eqx.tree_at(phase_leaf, params, phis)

# Evaluate each quantity for every phase in the sweep.
probs = jax.vmap(sim.probabilities.forward)(params)
grads = phase_leaf(jax.vmap(sim.probabilities.grad)(params))  # d p / d phi
qfims = jax.vmap(sim.amplitudes.qfim)(params)
cfims = jax.vmap(sim.probabilities.cfim)(params)
colors = sns.color_palette("Set2", n_colors=jnp.prod(jnp.array(probs.shape[1:])))
fig, axs = uplt.subplots(nrows=3, figsize=(6, 4), sharey=False)

for i, idx in enumerate(
    itertools.product(*[list(range(ell)) for ell in probs.shape[1:]])
):
    axs[0].plot(phis, probs[:, *idx], label=f"{idx}", color=colors[i])

    axs[1].plot(phis, grads[:, *idx], label=f"{idx}", color=colors[i])
axs[0].legend()
axs[0].set(ylabel=r"$p(\mathbf{x} | \varphi)$")
axs[1].set(ylabel=r"$\partial_{\varphi} p(\mathbf{x} | \varphi)$")

axs[2].plot(phis, qfims.squeeze(), color=colors[0], label=r"$\mathcal{I}_\varphi^Q$")
axs[2].plot(phis, cfims.squeeze(), color=colors[-1], label=r"$\mathcal{I}_\varphi^C$")
axs[2].set(
    xlabel=r"Phase, $\varphi$",
    ylabel=r"$\mathcal{I}_\varphi^C$",
    ylim=[0, 1.05 * jnp.max(qfims)],
)
axs[2].legend();

img