Skip to content

Math

per(mtx, column, selected, prod, output=False)

Recursive row expansion for the permanent of matrix mtx. The counter column is the index of the current column, selected is the list of row indices already chosen for earlier columns, and prod accumulates the running product.

Source code in src/squint/ops/math.py
def per(mtx, column, selected, prod, output=False):
    """
    Row expansion for the permanent of matrix mtx.

    Recursively sums, over every way of assigning a distinct, not yet selected
    row to each remaining column, the product of the chosen entries.

    :param mtx: Matrix indexed as ``mtx[row, column]``. Only ``mtx.shape[0]``
        (rows) and ``mtx.shape[1]`` (columns) drive the recursion, so any
        trailing axes are multiplied element-wise (this is how
        ``compute_transition_amplitudes`` batches many permanents at once).
    :param column: Index of the column currently being expanded.
    :param selected: Row indices already assigned to earlier columns.
    :param prod: Product accumulated so far along the current assignment.
    :param output: If True, print ``selected`` and ``prod`` for every complete
        assignment. The flag is forwarded through the recursion.
    :returns: Sum of the products over all completions of the current
        assignment (``prod`` itself once every column is assigned).
    """
    if column == mtx.shape[1]:
        if output:
            print(selected, prod)
        return prod

    # Start from the scalar 0 and use `result + ...` rather than `+=` so array
    # summands are never updated in place (dtype promotion stays safe).
    result = 0
    for row in range(mtx.shape[0]):
        if row not in selected:
            result = result + per(
                mtx,
                column + 1,
                selected + [row],
                prod * mtx[row, column],
                output,  # previously dropped, so output=True printed nothing
            )
    return result

permanent(mat)

Returns the permanent of the matrix mat.

Source code in src/squint/ops/math.py
def permanent(mat):
    """
    Compute the permanent of the matrix ``mat``.

    Thin wrapper around ``per``: the expansion starts at the first column with
    no rows chosen yet and a running product equal to the multiplicative
    identity.
    """
    first_column = 0
    chosen_rows = []
    running_product = 1
    return per(mat, first_column, chosen_rows, running_product)

get_fixed_sum_tuples(length, total)

Generate all tuples of a given length that sum to a specified total.

Source code in src/squint/ops/math.py
def get_fixed_sum_tuples(length, total):
    """Generate all tuples of a given length that sum to a specified total.

    Entries are non-negative integers. Tuples are yielded in lexicographic
    order, e.g. ``get_fixed_sum_tuples(2, 2)`` yields ``(0, 2), (1, 1), (2, 0)``.

    :param length: Number of entries in each tuple (must be >= 0).
    :param total: Required sum of the entries.
    :returns: A generator of tuples.
    :raises ValueError: If ``length`` is negative (raised when iteration starts,
        since this is a generator).
    """
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")
    if total < 0:
        # No tuple of non-negative entries has a negative sum.
        return
    if length == 0:
        # The empty tuple is the only length-0 tuple; its sum is 0.
        # Without this base case the recursion never terminates.
        if total == 0:
            yield ()
        return
    if length == 1:
        yield (total,)
        return

    for first in range(total + 1):
        for rest in get_fixed_sum_tuples(length - 1, total - first):
            yield (first,) + rest

compile_Aij_indices(i_s: jnp.array, j_s: jnp.array, m: int, n: int)

Compile all indices for generating the \(A_{ij}\) matrices for all i and j combinations.

Source code in src/squint/ops/math.py
def compile_Aij_indices(i_s: jnp.array, j_s: jnp.array, m: int, n: int):
    """Compile all indices for generating the $A_{ij}$ matrices for all i and j combinations.

    For each input basis ``i_basis`` in ``i_s`` and each output basis
    ``j_basis`` in ``j_s`` (per-mode occupation numbers; the first ``m``
    entries are used), builds a ``(2, n, n)`` block of gather indices into an
    ``m x m`` matrix. Mode ``r`` appears ``j_basis[r]`` times as a row index,
    and mode ``c`` appears ``i_basis[c]`` times as a column index.

    :param i_s: Input bases, one per entry; each must contain ``n`` photons.
    :param j_s: Output bases, one per entry; each must contain ``n`` photons.
    :param m: Number of modes.
    :param n: Total photon number of every basis.
    :returns: Integer array of shape ``(len(i_s), len(j_s), 2, n, n)``.
        ``[..., 0, :, :]`` holds row indices and ``[..., 1, :, :]`` holds
        column indices.
    :raises ValueError: If any input or output basis does not sum to ``n``.
    """
    modes = jnp.arange(m)

    def _counts(basis):
        # Only the first m entries are the occupation numbers of the m modes.
        return jnp.asarray(basis)[:m]

    def _check_photon_number(bases, kind):
        # A basis with the wrong photon number would otherwise silently yield a
        # non-square (or ragged) index block.
        for basis in bases:
            if int(jnp.sum(_counts(basis))) != n:
                raise ValueError(f"Some {kind} bases do not have n={n} photons.")

    _check_photon_number(i_s, "input")
    _check_photon_number(j_s, "output")

    def repeated_indices(i_basis, j_basis):
        rows = jnp.repeat(modes, _counts(j_basis))
        cols = jnp.repeat(modes, _counts(i_basis))
        # Entry [0, r, c] is the row index rows[r]; entry [1, r, c] is the
        # column index cols[c].
        return jnp.stack(jnp.meshgrid(rows, cols, indexing="ij"))

    transition_inds = jnp.array(
        [[repeated_indices(i_basis, j_basis) for j_basis in j_s] for i_basis in i_s]
    )
    return transition_inds

compute_transition_amplitudes(unitary: jnp.array, transition_inds: jnp.array)

Calculates all \(i \to j\) transition amplitudes in a jit-compatible manner.

Source code in src/squint/ops/math.py
@jax.jit
def compute_transition_amplitudes(unitary: jnp.array, transition_inds: jnp.array):
    """Calculates all i -> j transition amplitudes in a jit-able manner.

    ``transition_inds`` is indexed as a rank-5 array whose third axis selects
    row indices (0) or column indices (1) into ``unitary``. Presumably it comes
    from ``compile_Aij_indices`` with shape ``(I, O, 2, n, n)``; confirm with
    callers. The result holds one permanent per ``(i, o)`` pair.
    """
    row_idx = transition_inds[:, :, 0]
    col_idx = transition_inds[:, :, 1]
    submatrices = unitary[row_idx, col_idx]

    # The recursive permanent expands over the two leading axes, so move the
    # per-transition matrix axes to the front: (i, o, a, b) -> (a, b, i, o).
    # NOTE: per the original author, the recursive permanent was the fastest
    # of three approaches tried once jitted.
    stacked = jnp.transpose(submatrices, (2, 3, 0, 1))
    return permanent(stacked)

The code for the gellmann function is adapted from the PySME project, which is licensed under the MIT license.

Source: https://pysme.readthedocs.io/en/latest/_modules/gellmann.html

Module: gellmann.py — Generate generalized Gell-Mann matrices.
Module author: Jonathan Gross <jarthurgross@gmail.com>

Functions to generate the generalized Pauli matrices (i.e., the Gell-Mann matrices).

gellmann(j, k, d)

Returns a generalized Gell-Mann matrix of dimension \(d\). Following the convention in *Bloch Vectors for Qudits* by Bertlmann and Krammer (2008), it returns \(\Lambda^j\) for \(1 \leq j = k \leq d-1\), \(\Lambda^{kj}_s\) for \(1 \leq k < j \leq d\), \(\Lambda^{jk}_a\) for \(1 \leq j < k \leq d\), and \(I\) for \(j = k = d\).

Parameters:

- j (positive integer): First index of the generalized Gell-Mann matrix.
- k (positive integer): Second index of the generalized Gell-Mann matrix.
- d (positive integer): Dimension of the generalized Gell-Mann matrix.

Returns: a generalized Gell-Mann matrix (a complex64 jax.numpy array).

Source code in src/squint/ops/gellmann.py
def gellmann(j, k, d):
    r"""Returns a generalized Gell-Mann matrix of dimension d. According to the
    convention in *Bloch Vectors for Qudits* by Bertlmann and Krammer (2008),
    returns :math:`\Lambda^j` for :math:`1\leq j=k\leq d-1`,
    :math:`\Lambda^{kj}_s` for :math:`1\leq k<j\leq d`,
    :math:`\Lambda^{jk}_a` for :math:`1\leq j<k\leq d`, and
    :math:`I` for :math:`j=k=d`.

    :param j: First index for generalized Gell-Mann matrix
    :type j:  positive integer
    :param k: Second index for generalized Gell-Mann matrix
    :type k:  positive integer
    :param d: Dimension of the generalized Gell-Mann matrix
    :type d:  positive integer
    :returns: A generalized Gell-Mann matrix.
    :rtype:   jax.Array of dtype complex64
    :raises ValueError: If ``j`` or ``k`` lies outside ``1..d``. Without this
        check, JAX ``.at[...].set`` silently skips out-of-range updates and
        index ``0`` wraps around to the last row/column, so a wrong matrix
        would be returned.

    """
    if not (1 <= j <= d and 1 <= k <= d):
        raise ValueError(
            f"gellmann indices must satisfy 1 <= j, k <= d; got j={j}, k={k}, d={d}"
        )

    if j > k:
        # Symmetric matrix: ones at positions (j, k) and (k, j) (1-based).
        gjkd = jnp.zeros((d, d), dtype=jnp.complex64)
        gjkd = gjkd.at[j - 1, k - 1].set(1.0)
        gjkd = gjkd.at[k - 1, j - 1].set(1.0)
    elif k > j:
        # Antisymmetric matrix: -i at (j, k), +i at (k, j) (1-based).
        gjkd = jnp.zeros((d, d), dtype=jnp.complex64)
        gjkd = gjkd.at[j - 1, k - 1].set(-1.0j)
        gjkd = gjkd.at[k - 1, j - 1].set(1.0j)
    elif j < d:
        # Diagonal matrix: j ones, then -j, then zeros,
        # scaled by sqrt(2 / (j (j + 1))).
        diagonal = [1.0] * j + [-float(j)] + [0.0] * (d - j - 1)
        gjkd = jnp.sqrt(2 / (j * (j + 1))) * jnp.diag(
            jnp.array(diagonal, dtype=jnp.complex64)
        )
    else:
        # j == k == d: the identity matrix.
        gjkd = jnp.eye(d, dtype=jnp.complex64)

    return gjkd