Skip to content

Optimization

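This guide shows you how to optimize a parameterized squint circuit. The example is a variational quantum sensor whose classical Fisher information (CFI) is maximized by gradient ascent.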

import jax
import jax.numpy as jnp
import optax

from squint.circuit import Circuit
from squint.base import SharedGate
from squint.ops.dv import DiscreteVariableState, RXGate, RYGate, RZGate, CXGate
from squint.utils import partition_op

def variational_sensor(n_qubits, n_layers):
    """Create a hardware-efficient ansatz variational quantum sensor.

    The circuit is built in four stages:

    1. every qubit is prepared with ``DiscreteVariableState(n=(0,))``;
    2. ``n_layers`` variational layers, each made of trainable RX and RY
       rotations on every qubit followed by a linear chain of CX gates
       ``(0, 1), (1, 2), ...``;
    3. a phase-sensing layer: an RZ gate on wire 0 wrapped in ``SharedGate``
       with wires ``1 .. n_qubits - 1``;
    4. a trainable RX/RY measurement-basis layer on every qubit.

    All rotation angles start at ``0.0``.

    Args:
        n_qubits: Number of qubits (wires ``0 .. n_qubits - 1``). Must be >= 1.
        n_layers: Number of variational layers. Must be >= 0.

    Returns:
        A ``Circuit`` built with ``backend="pure"``. Gates are registered under
        the keys ``rx_{layer}_{i}``, ``ry_{layer}_{i}``, ``"phase"``,
        ``meas_x_{i}`` and ``meas_y_{i}``.

    Raises:
        ValueError: If ``n_qubits < 1`` or ``n_layers < 0``.
    """
    # The phase gate is placed on wire 0, so at least one qubit is required;
    # a negative layer count would otherwise silently build no layers.
    if n_qubits < 1:
        raise ValueError(f"n_qubits must be >= 1, got {n_qubits}")
    if n_layers < 0:
        raise ValueError(f"n_layers must be >= 0, got {n_layers}")

    circuit = Circuit(backend="pure")

    # Initialize qubits
    for i in range(n_qubits):
        circuit.add(DiscreteVariableState(wires=(i,), n=(0,)))

    # Variational layers
    for layer in range(n_layers):
        # Single-qubit rotations (trainable, keyed per layer and wire)
        for i in range(n_qubits):
            circuit.add(RXGate(wires=(i,), phi=0.0), f"rx_{layer}_{i}")
            circuit.add(RYGate(wires=(i,), phi=0.0), f"ry_{layer}_{i}")

        # Entangling layer: nearest-neighbour CX chain
        for i in range(n_qubits - 1):
            circuit.add(CXGate(wires=(i, i + 1)))

    # Phase sensing layer. Registered under "phase" so the caller can split
    # this parameter out (e.g. with `partition_op(params, "phase")`).
    # NOTE(review): SharedGate is assumed to apply the same RZ parameter on
    # the extra wires 1..n_qubits-1 -- confirm against the squint docs.
    circuit.add(
        SharedGate(op=RZGate(wires=(0,), phi=0.0), wires=tuple(range(1, n_qubits))),
        "phase",
    )

    # Measurement layer: trainable rotations choosing the measurement basis
    for i in range(n_qubits):
        circuit.add(RXGate(wires=(i,), phi=0.0), f"meas_x_{i}")
        circuit.add(RYGate(wires=(i,), phi=0.0), f"meas_y_{i}")

    return circuit

Optimization loop

n_qubits = 4
n_layers = 3
n_steps = 100
learning_rate = 1e-3
dim = 2  # local Hilbert-space dimension: every wire in this circuit is a qubit

# NOTE(review): `eqx` (equinox), `plt` (matplotlib.pyplot) and squint's
# simulator `compile` are used below but are not imported in the visible
# cells -- they must be imported before this section runs. Without an
# import, `compile` resolves to Python's builtin and the call below raises
# TypeError. TODO confirm the squint import path for `compile`.

circuit = variational_sensor(n_qubits, n_layers)

# Separate the trainable arrays from the static circuit structure, then split
# the "phase" parameter (the quantity being sensed) from the remaining
# parameters (the ones we optimize).
params, static = eqx.partition(circuit, eqx.is_inexact_array)
params_est, params_opt = partition_op(params, "phase")

# NOTE(review): argnum=0 presumably selects params_est as the argument the
# Fisher information is taken with respect to -- confirm in the squint API.
sim = compile(static, dim, params_est, params_opt, optimize="greedy", argnum=0)


def loss(params_est, params_opt):
    """Return the classical Fisher information (CFI) of the circuit.

    Despite the name, this objective is *maximized* (see the optimizer
    below). `.squeeze()` reduces the CFI matrix to a scalar, which
    `jax.value_and_grad` requires -- this assumes a single estimated
    parameter (a 1x1 matrix).
    """
    return sim.probabilities.cfim(params_est, params_opt).squeeze()


# Adam produces descent updates; chaining `scale(-1.0)` flips them into
# ascent, so the CFI is maximized rather than minimized.
optimizer = optax.chain(optax.adam(learning_rate=learning_rate), optax.scale(-1.0))
opt_state = optimizer.init(params_opt)


@jax.jit
def step(opt_state, params_est, params_opt):
    """Take one gradient-ascent step on the CFI with respect to `params_opt`.

    Returns the updated `params_opt`, the new optimizer state, and the CFI
    evaluated *before* the update.
    """
    cfi, grad = jax.value_and_grad(loss, argnums=1)(params_est, params_opt)
    updates, opt_state = optimizer.update(grad, opt_state, params_opt)
    params_opt = optax.apply_updates(params_opt, updates)
    return params_opt, opt_state, cfi


# Run optimization; `losses` holds the CFI value (as a host float) at each step.
losses = []
for _ in range(n_steps):
    params_opt, opt_state, val = step(opt_state, params_est, params_opt)
    losses.append(float(val))

# Reassemble the optimized circuit from its static and parameter parts.
circuit = eqx.combine(static, params_est, params_opt)

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(
    xlabel="Optimization step",
    ylabel="Classical Fisher Information",
)