Simulator
SimulatorQuantumAmplitudes
dataclass
Simulator object which computes quantities related to the quantum probability amplitudes, including forward pass, gradient computation, and quantum Fisher information matrix calculation.
Attributes:
Name | Type | Description |
---|---|---|
forward | Callable | Function to compute quantum amplitudes. |
grad | Callable | Function to compute gradients of quantum amplitudes. |
qfim | Callable | Function to compute the quantum Fisher information matrix. |
Source code in src/squint/simulator.py
jit(device: jax.Device = None)
JIT (just-in-time) compile the simulator methods.
Args:
device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.
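To illustrate what JIT compilation with a device selection looks like in plain JAX, here is a minimal sketch. The `forward` function is a hypothetical stand-in for a simulator callable, and placing the input with `jax.device_put` is an assumption for illustration; this page does not show how squint passes `device` to JAX internally.

```python
import jax
import jax.numpy as jnp


def forward(theta):
    # Hypothetical stand-in for a simulator's forward function:
    # a two-outcome probability distribution parameterized by one angle.
    p1 = jnp.sin(theta / 2) ** 2
    return jnp.stack([1.0 - p1, p1])


device = jax.devices()[0]        # the first available device
forward_jit = jax.jit(forward)   # traced and compiled on the first call, then cached

# Committing the input to a device makes the compiled function run there.
theta = jax.device_put(jnp.array(0.3), device)
probs = forward_jit(theta)
```

The first call pays the compilation cost; later calls with inputs of the same shape and dtype reuse the compiled function.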
SimulatorClassicalProbabilities
dataclass
Simulator object which computes quantities related to the classical probabilities, including forward pass, gradient computation, and classical Fisher information matrix calculation.
Attributes:
Name | Type | Description |
---|---|---|
forward | Callable | Function to compute classical probabilities. |
grad | Callable | Function to compute gradients of classical probabilities. |
cfim | Callable | Function to compute the classical Fisher information matrix. |
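As a rough illustration of how a `forward` callable and a matching `grad` callable can relate, the sketch below uses plain JAX with parameters split into two arguments. The split imitates a trainable/static partition using ordinary dicts; the function body, the parameter names `phi` and `offset`, and the choice of `jax.jacrev` with `argnums=0` are all assumptions for illustration, not squint's actual implementation.

```python
import jax
import jax.numpy as jnp


def forward(trainable, static):
    # Hypothetical two-outcome probability model.
    p1 = jnp.sin((trainable["phi"] + static["offset"]) / 2) ** 2
    return jnp.stack([1.0 - p1, p1])


# Differentiate with respect to the first argument only, so callers
# pass the same (trainable, static) pair to both callables.
grad = jax.jacrev(forward, argnums=0)

params = ({"phi": jnp.array(0.4)}, {"offset": jnp.array(0.1)})
p = forward(*params)    # probabilities, shape (2,)
dp = grad(*params)      # PyTree shaped like `trainable`: {"phi": array of shape (2,)}
```

Because the probabilities sum to one, the entries of `dp["phi"]` sum to zero, which is a quick sanity check on any gradient of a normalized distribution.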
jit(device: jax.Device = None)
JIT (just-in-time) compile the simulator methods.
Args:
device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.
Simulator
dataclass
Simulator for quantum circuits, providing callables for the forward pass, the backward (gradient) pass, and Fisher information matrix calculations on both the quantum amplitudes and the classical probabilities, given a set of parameter PyTrees.
Attributes:
Name | Type | Description |
---|---|---|
amplitudes | SimulatorQuantumAmplitudes | Object for quantum amplitudes computations. |
probabilities | SimulatorClassicalProbabilities | Object for classical probabilities computations. |
path | Any | Path to the simulator, can be used for saving/loading. |
info | str | Additional information about the simulator. |
jit(device: jax.Device = None)
JIT (just-in-time) compile the simulator methods.
Args:
device (jax.Device, optional): Device to compile the methods on. Defaults to None, which uses the first available device.
sample(key: jr.PRNGKey, params: PyTree, shape: tuple[int, ...])
Sample from the quantum circuit using the provided parameters and a random key.
Args:
key (jr.PRNGKey): Random key for sampling.
params (PyTree): Parameters for the quantum circuit, partitioned via eqx.partition.
shape (tuple[int, ...]): Shape of the output samples.
Returns:
samples (jnp.ndarray): Samples drawn from the quantum circuit.
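The general idea of drawing measurement outcomes from a probability array with a random key can be sketched in plain JAX as follows. The joint distribution `probs` is a hypothetical example, and the way outcomes are flattened, sampled, and unravelled is an assumption for illustration; this page does not specify how `sample` is implemented or how its output is laid out.

```python
import jax.numpy as jnp
import jax.random as jr

# Hypothetical joint distribution over the outcomes of two subsystems:
# only the correlated outcomes (0, 0) and (1, 1) have nonzero probability.
probs = jnp.array([[0.5, 0.0], [0.0, 0.5]])

key = jr.PRNGKey(0)
shape = (1000,)

# jr.categorical expects log-probabilities over a flat set of categories;
# zero-probability outcomes map to -inf and are never drawn.
flat_idx = jr.categorical(key, jnp.log(probs.ravel()), shape=shape)

# Convert flat indices back into one outcome per subsystem.
samples = jnp.stack(jnp.unravel_index(flat_idx, probs.shape), axis=-1)
```

Here `samples` has shape `(1000, 2)`, one row per draw and one column per subsystem.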
quantum_fisher_information_matrix(_forward_amplitudes: Callable, _grad_amplitudes: Callable, *params: PyTree)
Performs the forward pass to compute quantum amplitudes and their gradients,
and then calculates the quantum Fisher information matrix.
Args:
_forward_amplitudes (Callable): Function to compute quantum amplitudes.
_grad_amplitudes (Callable): Function to compute gradients of quantum amplitudes.
*params (list[PyTree]): Parameters for the quantum circuit, partitioned via eqx.partition. The argnum is already defined in the callables.
Returns:
qfim (jnp.ndarray): Quantum Fisher information matrix.
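For orientation, the standard pure-state expression for the quantum Fisher information matrix is F_ij = 4 Re[⟨∂_i ψ|∂_j ψ⟩ − ⟨∂_i ψ|ψ⟩⟨ψ|∂_j ψ⟩]. The sketch below evaluates it with plain JAX for a toy single-qubit state. The `state` function and the helper `quantum_fim` are hypothetical and are not squint's implementation, which may be organized differently.

```python
import jax
import jax.numpy as jnp


def state(theta):
    # Hypothetical toy state (|0> + exp(i*theta)|1>) / sqrt(2).
    phase = jnp.exp(1j * theta[0])
    return jnp.stack([jnp.ones_like(phase), phase]) / jnp.sqrt(2.0)


def quantum_fim(state_fn, theta):
    psi = state_fn(theta)  # shape (dim,)

    # Differentiate real and imaginary parts separately, then recombine.
    def real_imag(t):
        amp = state_fn(t)
        return jnp.stack([jnp.real(amp), jnp.imag(amp)])

    jac = jax.jacfwd(real_imag)(theta)   # shape (2, dim, n_params)
    dpsi = jac[0] + 1j * jac[1]          # shape (dim, n_params)

    overlap = jnp.einsum("xi,xj->ij", dpsi.conj(), dpsi)  # <d_i psi | d_j psi>
    proj = jnp.einsum("x,xi->i", psi.conj(), dpsi)        # <psi | d_i psi>
    return 4.0 * jnp.real(overlap - jnp.outer(proj.conj(), proj))


theta = jnp.array([0.3])
qfim = quantum_fim(state, theta)
```

For this state the result is a 1×1 matrix equal to 1 for every value of `theta`, since the phase is encoded by a generator with unit variance.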
classical_fisher_information_matrix(_forward_prob: Callable, _grad_prob: Callable, *params: PyTree)
Performs the forward pass to compute classical probabilities and their gradients,
and then calculates the classical Fisher information matrix.
Args:
_forward_prob (Callable): Function to compute classical probabilities.
_grad_prob (Callable): Function to compute gradients of classical probabilities.
*params (list[PyTree]): Parameters for the quantum circuit, partitioned via eqx.partition. The argnum is already defined in the callables.
Returns:
cfim (jnp.ndarray): Classical Fisher information matrix.
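For reference, the classical Fisher information matrix of a discrete distribution p(x|θ) is F_ij = Σ_x ∂_i p(x) ∂_j p(x) / p(x). The following sketch computes it with plain JAX for a toy two-outcome distribution. The `probs` function and the helper `classical_fim` are hypothetical, and the handling of zero-probability outcomes is an assumption; squint's own implementation may differ.

```python
import jax
import jax.numpy as jnp


def probs(theta):
    # Hypothetical toy model: two outcomes with probabilities
    # cos^2(theta/2) and sin^2(theta/2).
    return jnp.stack([jnp.cos(theta[0] / 2) ** 2, jnp.sin(theta[0] / 2) ** 2])


def classical_fim(forward_prob, theta):
    p = forward_prob(theta)                 # shape (n_outcomes,)
    dp = jax.jacrev(forward_prob)(theta)    # shape (n_outcomes, n_params)

    # Outcomes with zero probability contribute nothing; avoid dividing by zero.
    safe_p = jnp.where(p > 0, p, 1.0)
    weights = jnp.where(p > 0, 1.0 / safe_p, 0.0)

    # Sum over outcomes x of dp[x, i] * dp[x, j] / p[x].
    return jnp.einsum("xi,xj,x->ij", dp, dp, weights)


theta = jnp.array([0.7])
cfim = classical_fim(probs, theta)
```

For this model the matrix is 1×1 with value 1, matching the quantum Fisher information of the corresponding phase-encoded qubit state, so the measurement is optimal in this toy case.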