Quantum-enhanced telescopy: Gottesman-Jennewein-Croke scheme¶
Implements the original GJC scheme to compute & reproduce the classical Fisher information of this. Here, a phase shift is encoded into a photon in the superposition of being collected by the left and right telescopes. An ancilla photon is distributed in a quantum network between the telescopes, and enables a quantum interference measurement between the two photon arriving the telescope modes.
Source: 10.1103/PhysRevLett.109.070503
import itertools
import equinox as eqx
import jax
import jax.numpy as jnp
import seaborn as sns
import ultraplot as uplt
from rich.pretty import pprint
from squint.circuit import Circuit
from squint.ops.fock import BeamSplitter, FockState, Phase
from squint.utils import partition_op, print_nonzero_entries
cut = 3 # the photon number truncation for the simulation
circuit = Circuit()
# note: `wires` is a spatial mode in this context (in other contexts this can be a information carrying unit, e.g., a qubit/qudit)
# we add in the stellar photon, which is in an even superposition of spatial modes 0 and 2 (left and right telescopes)
circuit.add(
FockState(
wires=(
0,
2,
),
n=[(1 / jnp.sqrt(2).item(), (1, 0)), (1 / jnp.sqrt(2).item(), (0, 1))],
)
)
# the stellar photon accumulates a phase shift prior to collection by the left telescope.
circuit.add(Phase(wires=(0,), phi=0.01), "phase")
# we add the resources photon, which is in an even superposition of spatial modes 1 and 3
circuit.add(
FockState(
wires=(
1,
3,
),
n=[(1 / jnp.sqrt(2).item(), (1, 0)), (1 / jnp.sqrt(2).item(), (0, 1))],
)
)
# we add the linear optical circuit at each telescope (by default this is a 50-50 beamsplitter)
circuit.add(
BeamSplitter(
wires=(
0,
1,
),
)
)
circuit.add(
BeamSplitter(
wires=(
2,
3,
),
)
)
pprint(circuit)
Circuit(
dims=None,
ops={
│ 0:
│ FockState(
│ wires=(0, 2),
│ n=[(0.7071067811865476, (1, 0)), (0.7071067811865476, (0, 1))]
│ ),
│ 'phase':
│ Phase(wires=(0,), phi=weak_f64[]),
│ 2:
│ FockState(
│ wires=(1, 3),
│ n=[(0.7071067811865476, (1, 0)), (0.7071067811865476, (0, 1))]
│ ),
│ 3:
│ BeamSplitter(wires=(0, 1), r=weak_f64[]),
│ 4:
│ BeamSplitter(wires=(2, 3), r=weak_f64[])
},
_backend=None
)
# we split out the params which can be varied (in this example, it is just the "phase" phi value), and all the static parameters (wires, etc.)
params, static = partition_op(circuit, "phase")
# next we compile the circuit description into function calls, which compute, e.g., the quantum state, probabilities, partial derivates of the quantum state, and partial derivatives of the probabilities
sim = circuit.compile(static, cut, params, optimize="greedy").jit()
pprint(circuit)
Circuit(
dims=None,
ops={
│ 0:
│ FockState(
│ wires=(0, 2),
│ n=[(0.7071067811865476, (1, 0)), (0.7071067811865476, (0, 1))]
│ ),
│ 'phase':
│ Phase(wires=(0,), phi=weak_f64[]),
│ 2:
│ FockState(
│ wires=(1, 3),
│ n=[(0.7071067811865476, (1, 0)), (0.7071067811865476, (0, 1))]
│ ),
│ 3:
│ BeamSplitter(wires=(0, 1), r=weak_f64[]),
│ 4:
│ BeamSplitter(wires=(2, 3), r=weak_f64[])
},
_backend=None
)
ket = sim.amplitudes.grad(params)
prob = sim.probabilities.forward(params)
grad = sim.probabilities.grad(params).ops["phase"].phi
print_nonzero_entries(prob)
Basis: [0 0 0 2], Value: 0.12500000000000006
Basis: [0 0 1 1], Value: 3.7057960554250835e-33
Basis: [0 0 2 0], Value: 0.12499999999999997
Basis: [0 1 0 1], Value: 0.2499937500520832
Basis: [0 1 1 0], Value: 6.249947916840279e-06
Basis: [0 2 0 0], Value: 0.12500000000000003
Basis: [1 0 0 1], Value: 6.249947916840279e-06
Basis: [1 0 1 0], Value: 0.24999375005208316
Basis: [1 1 0 0], Value: 3.705796055425082e-33
Basis: [2 0 0 0], Value: 0.12499999999999994
# we next compute the classical Fisher information
cfi = jnp.sum(grad**2 / (prob + 1e-14))
print(f"The classical Fisher information for `phi` is {cfi}")
# this can also be performed from the `sim` object
cfim = sim.probabilities.cfim(params)
print(f"The classical Fisher information is {cfim}")
The classical Fisher information for `phi` is 0.49999999920001365
The classical Fisher information is [[0.5]]
phis = jnp.linspace(-jnp.pi, jnp.pi, 100)
params = eqx.tree_at(lambda pytree: pytree.ops["phase"].phi, params, phis)
probs = jax.vmap(sim.probabilities.forward)(params)
grads = jax.vmap(sim.probabilities.grad)(params).ops["phase"].phi
qfims = jax.vmap(sim.amplitudes.qfim)(params)
cfims = jax.vmap(sim.probabilities.cfim)(params)
colors = itertools.cycle(sns.color_palette("Set2", n_colors=8))
fig, axs = uplt.subplots(nrows=3, figsize=(6, 4), sharey=False)
for i, idx in enumerate(
itertools.product(*[list(range(ell)) for ell in probs.shape[1:]])
):
if probs[:, *idx].max() < 1e-6:
continue
color = next(colors)
axs[0].plot(phis, probs[:, *idx], label=f"{idx}", color=color)
axs[1].plot(phis, grads[:, *idx], label=f"{idx}", color=color)
axs[0].legend()
axs[0].set(ylabel=r"$p(\mathbf{x} | \varphi)$")
axs[1].set(ylabel=r"$\partial_{\varphi} p(\mathbf{x} | \varphi)$")
axs[2].plot(phis, qfims.squeeze(), color=next(colors), label=r"$\mathcal{I}_\varphi^Q$")
axs[2].plot(phis, cfims.squeeze(), color=next(colors), label=r"$\mathcal{I}_\varphi^C$")
axs[2].set(
xlabel=r"Phase, $\varphi$",
ylabel=r"$\mathcal{I}_\varphi^C$",
ylim=[0, 1.05 * jnp.max(qfims)],
)
axs[2].legend();