Simulator

SimulatorQuantumAmplitudes dataclass

Simulator object which computes quantities related to the quantum probability amplitudes, including forward pass, gradient computation, and quantum Fisher information matrix calculation.

Attributes:

Name      Type      Description
forward   Callable  Function to compute quantum amplitudes.
grad      Callable  Function to compute gradients of quantum amplitudes.
qfim      Callable  Function to compute the quantum Fisher information matrix.
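
The sketch below makes the expected call pattern of a `forward`/`grad` pair concrete. It uses a hypothetical single-qubit state |ψ(θ, φ)⟩ = (cos(θ/2), e^{iφ} sin(θ/2)). The toy circuit, the tuple-of-scalars parameter PyTree, and the use of `jax.jacfwd` are all assumptions chosen for illustration; this is not how squint builds its callables. The last lines mirror how `quantum_fisher_information_matrix` flattens the gradient PyTree and stacks it into an array of shape `(n_params, *amplitudes.shape)`. That shape assumes every parameter leaf is a scalar.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy "circuit": a single-qubit pure state parameterised by (theta, phi).
def forward(params):
    theta, phi = params
    return jnp.array([jnp.cos(theta / 2), jnp.exp(1j * phi) * jnp.sin(theta / 2)])


# jacfwd accepts real inputs with complex outputs. The result is a PyTree with the
# same structure as `params`, and each leaf holds d(amplitudes)/d(parameter).
grad = jax.jacfwd(forward)

params = (jnp.array(0.7), jnp.array(0.3))
amplitudes = forward(params)                      # shape (2,), complex
leaves = jax.tree_util.tree_leaves(grad(params))  # one leaf per scalar parameter
grads = jnp.stack(leaves, axis=0)                 # shape (n_params, *amplitudes.shape)
print(amplitudes.shape, grads.shape)              # (2,) (2, 2)
```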

Source code in src/squint/simulator.py
@dataclass
class SimulatorQuantumAmplitudes:
    """
    Simulator object which computes quantities related to the quantum probability amplitudes,
    including forward pass, gradient computation,
    and quantum Fisher information matrix calculation.

    Attributes:
        forward (Callable): Function to compute quantum amplitudes.
        grad (Callable): Function to compute gradients of quantum amplitudes.
        qfim (Callable): Function to compute the quantum Fisher information matrix.
    """

    forward: Callable
    grad: Callable
    qfim: Callable

    def jit(self, device: jax.Device = None):
        """
        JIT (just-in-time) compile the simulator methods.
        Args:
            device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.
        """
        return SimulatorQuantumAmplitudes(
            forward=jax.jit(self.forward, device=device),
            grad=jax.jit(self.grad, device=device),
            qfim=jax.jit(self.qfim, device=device),
            # qfim=jax.jit(self.qfim, static_argnames=("get",), device=device),
        )

jit(device: jax.Device = None)

JIT (just-in-time) compile the simulator methods.

Args:
    device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.
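
`jit` does not modify the object in place. It returns a new simulator whose callables are wrapped with `jax.jit`. The first call to each wrapped callable traces and compiles it, and later calls with the same input shapes and dtypes reuse the compiled code. The sketch below illustrates the pattern with a hypothetical `ToyAmplitudes` dataclass standing in for `SimulatorQuantumAmplitudes`. It selects a device by committing the input with `jax.device_put`, an alternative to passing `device=` to `jax.jit`: jitted functions run on the device that holds their committed inputs.

```python
from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp


@dataclass
class ToyAmplitudes:
    """Hypothetical stand-in exposing the same kind of API as SimulatorQuantumAmplitudes."""

    forward: Callable
    grad: Callable

    def jit(self):
        # Build a *new* instance; the original (un-jitted) callables are left untouched.
        return ToyAmplitudes(forward=jax.jit(self.forward), grad=jax.jit(self.grad))


def forward(theta):
    # Toy single-qubit amplitudes (cos(theta/2), sin(theta/2)).
    return jnp.stack([jnp.cos(theta / 2), jnp.sin(theta / 2)])


sim = ToyAmplitudes(forward=forward, grad=jax.jacfwd(forward))
fast = sim.jit()

device = jax.devices()[0]                       # first available device
theta = jax.device_put(jnp.array(0.7), device)  # committed input pins the computation
out = fast.forward(theta)                       # traced and compiled on this first call
```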

SimulatorClassicalProbabilities dataclass

Simulator object which computes quantities related to the classical probabilities, including forward pass, gradient computation, and classical Fisher information matrix calculation.

Attributes:

Name      Type      Description
forward   Callable  Function to compute classical probabilities.
grad      Callable  Function to compute gradients of classical probabilities.
cfim      Callable  Function to compute the classical Fisher information matrix.
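
Consider a circuit that ends in a pure state and is measured in the same basis its amplitudes are written in. This is an assumption, and the toy state below is hypothetical. Under it, the classical probabilities are p(x) = |ψ(x)|², and their gradients follow from the amplitude gradients by the chain rule: ∂p(x) = 2 Re[ψ(x)* ∂ψ(x)]. The sketch checks that identity against `jax.jacfwd` applied directly to the probabilities.

```python
import jax
import jax.numpy as jnp


def amplitudes(params):
    # Hypothetical single-qubit pure state |psi(theta, phi)>.
    theta, phi = params
    return jnp.array([jnp.cos(theta / 2), jnp.exp(1j * phi) * jnp.sin(theta / 2)])


def probabilities(params):
    # Born rule for a measurement in the same basis: p(x) = |psi(x)|^2 (real-valued).
    psi = amplitudes(params)
    return jnp.real(jnp.conj(psi) * psi)


params = (jnp.array(0.7), jnp.array(0.3))
psi = amplitudes(params)

# Gradient of the probabilities taken directly, stacked to (n_params, n_outcomes) ...
dp_direct = jnp.stack(jax.tree_util.tree_leaves(jax.jacfwd(probabilities)(params)))
# ... and via the chain rule from the amplitude gradients.
dpsi = jnp.stack(jax.tree_util.tree_leaves(jax.jacfwd(amplitudes)(params)))
dp_chain = 2 * jnp.real(jnp.conj(psi) * dpsi)  # broadcasts over the parameter axis

print(jnp.allclose(dp_direct, dp_chain, atol=1e-6))  # True
```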

Source code in src/squint/simulator.py
@dataclass
class SimulatorClassicalProbabilities:
    """
    Simulator object which computes quantities related to the classical probabilities,
    including forward pass, gradient computation,
    and classical Fisher information matrix calculation.

    Attributes:
        forward (Callable): Function to compute classical probabilities.
        grad (Callable): Function to compute gradients of classical probabilities.
        cfim (Callable): Function to compute the classical Fisher information matrix.
    """

    forward: Callable
    grad: Callable
    cfim: Callable

    @beartype
    def jit(self, device: jax.Device = None):
        """
        JIT (just-in-time) compile the simulator methods.
        Args:
            device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.
        """
        return SimulatorClassicalProbabilities(
            forward=jax.jit(self.forward, device=device),
            grad=jax.jit(self.grad, device=device),
            cfim=jax.jit(self.cfim, device=device),
            # cfim=jax.jit(self.cfim, static_argnames=("get",), device=device),
        )

jit(device: jax.Device = None)

JIT (just-in-time) compile the simulator methods.

Args:
    device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.

Simulator dataclass

Simulator for quantum circuits, providing callable methods for computing forward, backward (gradient), and Fisher information matrix calculations on the quantum amplitudes and classical probabilities, given a set of parameter PyTrees.

Attributes:

Name           Type                             Description
amplitudes     SimulatorQuantumAmplitudes       Object for quantum amplitude computations.
probabilities  SimulatorClassicalProbabilities  Object for classical probability computations.
path           Any                              Path to the simulator, can be used for saving/loading.
info           str, optional                    Additional information about the simulator.

Source code in src/squint/simulator.py
@dataclass
class Simulator:
    """
    Simulator for quantum circuits, providing callable methods for computing
    forward, backward, and Fisher information matrix calculations on the
    quantum amplitudes and classical probabilities, given a set of parameter PyTrees.

    Attributes:
        amplitudes (SimulatorQuantumAmplitudes): Object for quantum amplitudes computations.
        probabilities (SimulatorClassicalProbabilities): Object for classical probabilities computations.
        path (Any): Path to the simulator, can be used for saving/loading.
        info (str, optional): Additional information about the simulator.
    """

    amplitudes: SimulatorQuantumAmplitudes
    probabilities: SimulatorClassicalProbabilities
    path: Any
    info: str = None

    def jit(self, device: jax.Device = None):
        """
        JIT (just-in-time) compile the simulator methods.
        Args:
            device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.
        """
        if not device:
            device = jax.devices()[0]

        return Simulator(
            amplitudes=self.amplitudes.jit(device=device),
            probabilities=self.probabilities.jit(device=device),
            path=self.path,
            info=self.info,
        )

    def sample(self, key: jr.PRNGKey, params: PyTree, shape: tuple[int, ...]):
        """
        Sample from the quantum circuit using the provided parameters and a random key.
        Args:
            key (jr.PRNGKey): Random key for sampling.
            params (PyTree): Parameters for the quantum circuit, partitioned via `eqx.partition`.
            shape (tuple[int, ...]): Shape of the output samples.
        Returns:
            samples (jnp.ndarray): Samples drawn from the quantum circuit.
        """
        pr = self.probabilities.forward(params)
        idx = jnp.nonzero(pr)
        samples = einops.rearrange(
            jr.choice(key=key, a=jnp.stack(idx), p=pr[idx], shape=shape, axis=1),
            "s ... -> ... s",
        )
        return samples

jit(device: jax.Device = None)

JIT (just-in-time) compile the simulator methods.

Args:
    device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.

sample(key: jr.PRNGKey, params: PyTree, shape: tuple[int, ...])

Sample from the quantum circuit using the provided parameters and a random key.

Args:
    key (jr.PRNGKey): Random key for sampling.
    params (PyTree): Parameters for the quantum circuit, partitioned via eqx.partition.
    shape (tuple[int, ...]): Batch shape of the samples to draw.

Returns:
    samples (jnp.ndarray): Integer array of shape (*shape, pr.ndim), where pr is the probability array returned by probabilities.forward. Each trailing vector is the index of one sampled measurement outcome in pr.
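
`sample` returns measurement outcomes as multi-indices, not flat indices. `jnp.nonzero` collects the index of every outcome with non-zero probability. `jax.random.choice` then picks among those columns, weighted by `pr[idx]`. The final rearrange moves the index axis to the end. The sketch below reproduces that logic as a hypothetical stand-alone function, `sample_outcomes`, applied to a made-up 2×2 probability table. It substitutes `jnp.moveaxis` for the `einops.rearrange(..., "s ... -> ... s")` call so that it needs only JAX.

```python
import jax.numpy as jnp
import jax.random as jr


def sample_outcomes(key, pr, shape):
    """Hypothetical stand-alone version of the sampling logic in Simulator.sample."""
    idx = jnp.nonzero(pr)        # tuple of pr.ndim arrays, each of length K
    candidates = jnp.stack(idx)  # (ndim, K): one column per possible outcome
    # Draw columns with probability proportional to pr[idx]; result is (ndim, *shape).
    draws = jr.choice(key, candidates, shape=shape, p=pr[idx], axis=1)
    return jnp.moveaxis(draws, 0, -1)  # (*shape, ndim), same as "s ... -> ... s"


# Hypothetical two-mode outcome table: only (0, 0) and (1, 1) can occur.
pr = jnp.array([[0.5, 0.0],
                [0.0, 0.5]])
samples = sample_outcomes(jr.PRNGKey(0), pr, shape=(4, 3))
print(samples.shape)  # (4, 3, 2)
```

`jnp.nonzero` has a data-dependent output size, so this code runs eagerly. Likewise, `Simulator.jit` compiles only the amplitude and probability callables, not `sample`. Tracing this code under `jax.jit` would require passing a static `size=` to `jnp.nonzero`.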

quantum_fisher_information_matrix(_forward_amplitudes: Callable, _grad_amplitudes: Callable, *params: PyTree)

Performs the forward pass to compute quantum amplitudes and their gradients, and then calculates the quantum Fisher information matrix.

Args:
    _forward_amplitudes (Callable): Function to compute quantum amplitudes.
    _grad_amplitudes (Callable): Function to compute gradients of quantum amplitudes.
    *params (list[PyTree]): Parameters for the quantum circuit, partitioned via eqx.partition. The argnums to differentiate with respect to are already fixed inside the callables.

Returns:
    qfim (jnp.ndarray): Quantum Fisher information matrix.

Source code in src/squint/simulator.py
def quantum_fisher_information_matrix(
    _forward_amplitudes: Callable,
    _grad_amplitudes: Callable,
    # get: Callable,
    *params: PyTree,
):
    """
    Performs the forward pass to compute quantum amplitudes and their gradients,
    and then calculates the quantum Fisher information matrix.
    Args:
        _forward_amplitudes (Callable): Function to compute quantum amplitudes.
        _grad_amplitudes (Callable): Function to compute gradients of quantum amplitudes.
        *params (list[PyTree]): Parameters for the quantum circuit, partitioned via `eqx.partition`.
            The argnums to differentiate with respect to are already fixed inside the callables.
    Returns:
        qfim (jnp.ndarray): Quantum Fisher information matrix."""
    amplitudes = _forward_amplitudes(*params)
    grads, _ = jax.tree.flatten(_grad_amplitudes(*params))
    grads = jnp.stack(grads, axis=0)
    return _quantum_fisher_information_matrix(amplitudes, grads)
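
The helper `_quantum_fisher_information_matrix(amplitudes, grads)` is not shown on this page. For a normalized pure state, the standard quantum Fisher information matrix is F_ij = 4 Re[⟨∂_iψ|∂_jψ⟩ − ⟨∂_iψ|ψ⟩⟨ψ|∂_jψ⟩]. The sketch below applies that formula to arrays shaped like the ones the function body builds: `amplitudes` with shape `A`, and `grads` with shape `(n_params, *A)`, which assumes scalar parameter leaves. The name `qfim_pure_state` is hypothetical, and this is not squint's implementation. For the state (cos(θ/2), e^{iφ} sin(θ/2)), the expected result is diag(1, sin²θ).

```python
import jax.numpy as jnp


def qfim_pure_state(amplitudes, grads):
    """Hypothetical sketch: F_ij = 4 Re[<d_i psi|d_j psi> - <d_i psi|psi><psi|d_j psi>]."""
    psi = amplitudes.reshape(-1)                # flatten all outcome axes
    dpsi = grads.reshape(grads.shape[0], -1)    # (n_params, n_outcomes)
    overlaps = jnp.conj(dpsi) @ dpsi.T          # <d_i psi|d_j psi>
    proj = jnp.conj(dpsi) @ psi                 # <d_i psi|psi>, shape (n_params,)
    # outer(proj, conj(proj))_ij = <d_i psi|psi><psi|d_j psi>
    return 4 * jnp.real(overlaps - jnp.outer(proj, jnp.conj(proj)))


# Analytic amplitudes and derivatives for |psi(theta, phi)> = (cos(theta/2), e^{i phi} sin(theta/2)).
theta, phi = 0.7, 0.3
c, s, e = jnp.cos(theta / 2), jnp.sin(theta / 2), jnp.exp(1j * phi)
amplitudes = jnp.array([c, e * s])
grads = jnp.array([[-s / 2, e * c / 2],   # d/dtheta
                   [0.0, 1j * e * s]])    # d/dphi
print(qfim_pure_state(amplitudes, grads))  # approximately [[1, 0], [0, sin(0.7)**2]]
```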

classical_fisher_information_matrix(_forward_prob: Callable, _grad_prob: Callable, *params: PyTree)

Performs the forward pass to compute classical probabilities and their gradients, and then calculates the classical Fisher information matrix.

Args:
    _forward_prob (Callable): Function to compute classical probabilities.
    _grad_prob (Callable): Function to compute gradients of classical probabilities.
    *params (list[PyTree]): Parameters for the quantum circuit, partitioned via eqx.partition. The argnums to differentiate with respect to are already fixed inside the callables.

Returns:
    cfim (jnp.ndarray): Classical Fisher information matrix.

Source code in src/squint/simulator.py
def classical_fisher_information_matrix(
    _forward_prob: Callable,
    _grad_prob: Callable,
    # get: Callable,
    *params: PyTree,
):
    """
    Performs the forward pass to compute classical probabilities and their gradients,
    and then calculates the classical Fisher information matrix.
    Args:
        _forward_prob (Callable): Function to compute classical probabilities.
        _grad_prob (Callable): Function to compute gradients of classical probabilities.
        *params (list[PyTree]): Parameters for the quantum circuit, partitioned via `eqx.partition`.
            The argnums to differentiate with respect to are already fixed inside the callables.
    Returns:
        cfim (jnp.ndarray): Classical Fisher information matrix.
    """
    probs = _forward_prob(*params)
    grads, _ = jax.tree.flatten(_grad_prob(*params))
    grads = jnp.stack(grads, axis=0)
    return _classical_fisher_information_matrix(probs, grads)
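
The helper `_classical_fisher_information_matrix(probs, grads)` is not shown on this page either. The standard classical Fisher information matrix is F_ij = Σ_x ∂_i p(x) ∂_j p(x) / p(x). The sketch below follows a common convention and skips outcomes with p(x) = 0. The name `cfim` is hypothetical, and this is not squint's implementation. As an example, measuring the state (cos(θ/2), e^{iφ} sin(θ/2)) in the computational basis gives probabilities cos²(θ/2) and sin²(θ/2), which do not depend on φ. The result is therefore diag(1, 0). The state's quantum Fisher information matrix is diag(1, sin²θ), so this measurement discards the phase information.

```python
import jax.numpy as jnp


def cfim(probs, grads):
    """Hypothetical sketch: F_ij = sum_x d_i p(x) d_j p(x) / p(x), skipping p(x) = 0."""
    p = probs.reshape(-1)                          # flatten all outcome axes
    dp = grads.reshape(grads.shape[0], -1)         # (n_params, n_outcomes)
    safe_p = jnp.where(p > 0, p, 1.0)              # avoid dividing by zero
    weights = jnp.where(p > 0, 1.0 / safe_p, 0.0)  # 1/p(x), or 0 for impossible outcomes
    return (dp * weights) @ dp.T


# Computational-basis probabilities of (cos(theta/2), e^{i phi} sin(theta/2)) and their derivatives.
theta = 0.7
probs = jnp.array([jnp.cos(theta / 2) ** 2, jnp.sin(theta / 2) ** 2])
grads = jnp.array([[-jnp.sin(theta) / 2, jnp.sin(theta) / 2],  # d/dtheta
                   [0.0, 0.0]])                                 # d/dphi: no dependence
print(cfim(probs, grads))  # approximately [[1, 0], [0, 0]]
```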