Skip to content

One-qudit: High-dimensional phase estimation

import itertools

import equinox as eqx
import jax
import jax.numpy as jnp
import seaborn as sns
import ultraplot as uplt
from rich.pretty import pprint

from squint.circuit import Circuit
from squint.ops.dv import DiscreteVariableState, HGate, RZGate
dim = 4

circuit = Circuit()

circuit.add(DiscreteVariableState(wires=(0,), n=(0,)))
circuit.add(HGate(wires=(0,)))
circuit.add(RZGate(wires=(0,), phi=0.1 * jnp.pi), "phase")
circuit.add(HGate(wires=(0,)))

params, static = eqx.partition(circuit, eqx.is_inexact_array)

print("The full circuit object:")
pprint(circuit)

print("The parameterized leaves of the Pytree:")
pprint(params)

print("The static leaves of the Pytree:")
pprint(static)
The full circuit object:
Circuit(
  dims=None,
  ops={
│   0:
│   DiscreteVariableState(wires=(0,), n=[(1.0, (0,))]),
│   1:
│   HGate(wires=(0,)),
│   'phase':
│   RZGate(wires=(0,), phi=weak_f64[]),
│   3:
│   HGate(wires=(0,))
  },
  _backend=None
)
The parameterized leaves of the Pytree:
Circuit(
  dims=None,
  ops={
│   0:
│   DiscreteVariableState(wires=(None,), n=[(None, (None,))]),
│   1:
│   HGate(wires=(None,)),
│   'phase':
│   RZGate(wires=(None,), phi=weak_f64[]),
│   3:
│   HGate(wires=(None,))
  },
  _backend=None
)
The static leaves of the Pytree:
Circuit(
  dims=None,
  ops={
│   0:
│   DiscreteVariableState(wires=(0,), n=[(1.0, (0,))]),
│   1:
│   HGate(wires=(0,)),
│   'phase':
│   RZGate(wires=(0,), phi=None),
│   3:
│   HGate(wires=(0,))
  },
  _backend=None
)
params, static = eqx.partition(circuit, eqx.is_inexact_array)
sim = circuit.compile(static, dim, params, optimize="greedy")
phis = jnp.linspace(-jnp.pi, jnp.pi, 100)
params = eqx.tree_at(lambda pytree: pytree.ops["phase"].phi, params, phis)

probs = jax.vmap(sim.probabilities.forward)(params)
grads = jax.vmap(sim.probabilities.grad)(params).ops["phase"].phi
qfims = jax.vmap(sim.amplitudes.qfim)(params)
cfims = jax.vmap(sim.probabilities.cfim)(params)
colors = sns.color_palette("Set2", n_colors=jnp.prod(jnp.array(probs.shape[1:])))
fig, axs = uplt.subplots(nrows=3, figsize=(6, 4), sharey=False)

for i, idx in enumerate(
    itertools.product(*[list(range(ell)) for ell in probs.shape[1:]])
):
    axs[0].plot(phis, probs[:, *idx], label=f"{idx}", color=colors[i])

    axs[1].plot(phis, grads[:, *idx], label=f"{idx}", color=colors[i])
axs[0].legend()
axs[0].set(ylabel=r"$p(\mathbf{x} | \varphi)$")
axs[1].set(ylabel=r"$\partial_{\varphi} p(\mathbf{x} | \varphi)$")

axs[2].plot(phis, qfims.squeeze(), color=colors[0], label=r"$\mathcal{I}_\varphi^Q$")
axs[2].plot(phis, cfims.squeeze(), color=colors[-1], label=r"$\mathcal{I}_\varphi^C$")
axs[2].set(
    xlabel=r"Phase, $\varphi$",
    ylabel=r"$\mathcal{I}_\varphi^C$",
    ylim=[0, 1.05 * jnp.max(qfims)],
)
axs[2].legend();

img