Skip to content

Utilities

extract_paths(obj: PyTree, path='', op_type=None)

Recursively extract paths to non-None leaves in a PyTree, including their operation type.

Source code in src/squint/utils/partition.py
def extract_paths(obj: PyTree, path="", op_type=None):
    """
    Lazily walk a PyTree and yield ``(path, op_type, leaf)`` for every non-None leaf.

    ``path`` is a dotted/indexed location string (``ops.phase.phi``, ``[0].x`` ...).
    ``op_type`` is the class name of the nearest enclosing ``eqx.Module``, or the
    value passed in by the caller when no module encloses the leaf.
    """
    if obj is None:
        return

    if isinstance(obj, eqx.Module):
        # Every dataclass field of the module is labelled with the module's class
        # name (e.g. "BeamSplitter").
        module_name = type(obj).__name__
        for field in obj.__dataclass_fields__:
            child_path = field if not path else f"{path}.{field}"
            yield from extract_paths(getattr(obj, field), child_path, module_name)
        return

    if isinstance(obj, dict):
        children = (
            (f"{path}.{key}" if path else str(key), child)
            for key, child in obj.items()
        )
    elif isinstance(obj, (list, tuple)):
        children = (
            (f"{path}[{idx}]" if path else f"[{idx}]", child)
            for idx, child in enumerate(obj)
        )
    else:
        # A leaf: report it with whatever op_type was inherited from above.
        yield path, op_type, obj
        return

    for child_path, child in children:
        yield from extract_paths(child, child_path, op_type)

partition_by_branches(pytree, branches_to_param)

Partition a PyTree into parameters and static parts based on specified branch nodes.

All leaves that are descendants of any branch in branches_to_param will be treated as parameters; the rest will be considered static.

Parameters:

Name Type Description Default
pytree PyTree

The input PyTree.

required
branches_to_param list

A list of subtree objects whose leaves should be treated as parameters.

required

Returns:

Type Description

(params_pytree, static_pytree)

Source code in src/squint/utils/partition.py
def partition_by_branches(pytree, branches_to_param):
    """
    Partition a PyTree into parameters and static parts based on specified branch nodes.

    All leaves that are descendants of any branch in `branches_to_param` will be treated
    as parameters; the rest will be considered static.

    Matching is by object identity (``id``): a leaf of ``pytree`` is a parameter
    when the very same object is a leaf of one of the branches. Objects that are
    shared across the tree (e.g. one array referenced twice, or cached small ints
    in CPython) therefore match wherever they appear.

    Args:
        pytree (PyTree): The input PyTree.
        branches_to_param (Iterable): Subtree objects whose leaves should be treated
            as parameters. It is iterated exactly once, so a generator is fine.

    Returns:
        (params_pytree, static_pytree)
    """
    # Flatten each branch once up front. The previous version re-flattened every
    # branch for every leaf of ``pytree``, and it exhausted generator inputs.
    # The branches stay referenced by the caller for the whole call, so these
    # ids cannot be recycled while ``eqx.partition`` runs.
    branch_leaf_ids = {
        id(leaf)
        for branch in branches_to_param
        for leaf in jax.tree_util.tree_leaves(branch)
    }

    def is_param(leaf):
        return id(leaf) in branch_leaf_ids

    return eqx.partition(pytree, is_param)

partition_by_leaves(pytree, leaves_to_param)

Partition a PyTree into parameters and static parts based on specified leaves.

Args: pytree (PyTree): the input PyTree containing parameters and static parts. leaves_to_param (list): a list of leaves that should be treated as parameters; they are matched by object identity.

Returns: tuple: a tuple of two PyTrees, the parameters and the static parts.

Example

>>> import equinox as eqx
>>> leaves = [pytree.ops['phase'].phi, pytree.ops['phase'].epsilon]
>>> params, static = partition_by_leaves(pytree, leaves)

Source code in src/squint/utils/partition.py
def partition_by_leaves(pytree, leaves_to_param):
    """
    Split a PyTree into parameter and static halves, selecting individual leaves.

    A leaf of ``pytree`` is a parameter exactly when the same object (compared by
    ``id``) appears in ``leaves_to_param``.

    Args:
        pytree (PyTree): The input PyTree containing parameters and static parts.
        leaves_to_param (list): A list of leaves that should be treated as parameters.

    Returns:
        tuple: ``(params, static)`` as produced by ``eqx.partition``.

    Example:
        >>> import equinox as eqx
        >>> leaves = [pytree.ops['phase'].phi, pytree.ops['phase'].epsilon]
        >>> params, static = partition_by_leaves(pytree, leaves)
    """
    wanted_ids = {id(leaf) for leaf in leaves_to_param}

    def is_param(leaf):
        # Identity, not equality: equal-valued arrays elsewhere are not selected.
        return id(leaf) in wanted_ids

    return eqx.partition(pytree, is_param)

partition_op(pytree: PyTree, name: Union[str, Sequence[str]])

Partition a PyTree into parameters and static parts based on the operation name key.

Args: pytree (PyTree): the input PyTree containing operations. name (str): the operation name key to filter by.

Source code in src/squint/utils/partition.py
def partition_op(pytree: PyTree, name: Union[str, Sequence[str]]):
    """
    Partition a PyTree into parameters and static parts based on operation name key(s).

    A leaf becomes a parameter when it is an inexact (floating-point or complex)
    array located under ``pytree.ops[n]`` for one of the requested names. Every
    other leaf is static.

    Args:
        pytree (PyTree): The input PyTree containing operations. It must expose an
            ``ops`` mapping indexed by operation name.
        name (str | Sequence[str]): One operation name key, or several, to filter by.

    Returns:
        tuple: ``(params, static)`` as returned by ``eqx.partition``.
    """
    names = [name] if isinstance(name, str) else list(name)

    def select(tree: PyTree, op_name: str) -> PyTree:
        """Return ``tree`` with every leaf under ``tree.ops[op_name]`` set to ``True``."""
        marked = jax.tree_util.tree_map(lambda _: True, tree.ops[op_name])
        return eqx.tree_at(lambda t: t.ops[op_name], tree, marked)

    # Mark the leaves of every requested op. Replacing one op's subtree leaves
    # the other ops untouched, so the selections compose.
    selected = pytree
    for op_name in names:
        selected = select(selected, op_name)

    def is_param(leaf, marker) -> bool:
        # Check the real leaf for inexact-array-ness, and compare the marker by
        # identity. Leaves outside the selected ops keep their own value in
        # ``selected``, and that value may itself be a Python bool. Always
        # returning a strict bool (never None) keeps the filter spec structurally
        # identical to ``pytree``.
        return eqx.is_inexact_array(leaf) and marker is True

    filter_spec = jax.tree_util.tree_map(is_param, pytree, selected)
    return eqx.partition(pytree, filter_spec=filter_spec)

print_nonzero_entries(arr)

Print the indices and values of the non-zero entries in a JAX array.

Args: arr (jnp.ndarray): the JAX array to inspect.

Source code in src/squint/utils/__init__.py
def print_nonzero_entries(arr):
    """
    Print the indices and values of non-zero entries in a JAX array.
    Args:
        arr (jnp.ndarray): The JAX array to inspect.
    """
    nonzero_indices = jnp.array(jnp.nonzero(arr)).T
    nonzero_values = arr[tuple(nonzero_indices.T)]
    for idx, value in zip(nonzero_indices, nonzero_values, strict=True):
        print(f"Basis: {jnp.array(idx)}, Value: {value}")

hdfdict

This code is adapted from https://github.com/SiggiGue/hdfdict, which is licensed under the MIT License.

LazyHdfDict

Bases: UserDict

Helps loading data only if values from the dict are requested.

This is done by reimplementing the `__getitem__` method.

Source code in src/squint/utils/hdfdict.py
class LazyHdfDict(UserDict):
    """Dictionary that loads HDF5 datasets only when their values are requested.

    A value stored as an ``h5py.Dataset`` is unpacked with ``unpack_dataset`` the
    first time it is read through ``__getitem__``. The unpacked value then replaces
    the dataset in the dict.
    """

    def __init__(self, _h5file=None, *args, **kwargs):
        """Create the dict.

        Args:
            _h5file: Optional file handle that ``close()`` (and garbage collection)
                will close.
            *args, **kwargs: Forwarded to ``UserDict``.
        """
        super().__init__(*args, **kwargs)
        self._h5file = _h5file  # used to close the file on deletion.

    def __getitem__(self, key):
        """Returns item and loads dataset if needed."""
        item = super().__getitem__(key)
        if isinstance(item, h5py.Dataset):
            item = unpack_dataset(item)
            # Cache the unpacked value so the dataset is read only once.
            self.__setitem__(key, item)
        return item

    def unlazy(self):
        """Unpack all datasets in place, recursing into nested ``LazyHdfDict`` values.

        After this call you can use ``dict(this_instance)`` to get a real dict.
        This replaces a delegation to ``load(self, lazy=False)``. That call failed
        because ``load`` reads ``.attrs`` (which a dict lacks), and it discarded
        its result anyway.

        Returns:
            LazyHdfDict: ``self``, for convenience.
        """
        # Snapshot the keys: reading a dataset writes the unpacked value back.
        for key in list(self.keys()):
            value = self[key]  # __getitem__ unpacks and caches datasets
            if isinstance(value, LazyHdfDict):
                value.unlazy()
        return self

    def close(self):
        """Closes the h5file if provided at initialization."""
        # ``getattr`` guards against a partially initialised instance, e.g. when
        # ``UserDict.__init__`` raised before ``_h5file`` was assigned and
        # ``__del__`` then runs.
        h5file = getattr(self, "_h5file", None)
        if h5file and hasattr(h5file, "close"):
            h5file.close()

    def __del__(self):
        self.close()

    def _ipython_key_completions_(self):
        """Returns a tuple of keys.
        Special Method for ipython to get key completion
        """
        return tuple(self.keys())
__getitem__(key)

Returns item and loads dataset if needed.

Source code in src/squint/utils/hdfdict.py
def __getitem__(self, key):
    """Return the value for ``key``, unpacking an ``h5py.Dataset`` on first access.

    The unpacked value is stored back under ``key`` so later lookups skip the
    HDF5 read.
    """
    stored = super().__getitem__(key)
    if not isinstance(stored, h5py.Dataset):
        return stored
    unpacked = unpack_dataset(stored)
    self.__setitem__(key, unpacked)
    return unpacked
unlazy()

Unpacks all datasets. You can call dict(this_instance) then to get a real dict.

Source code in src/squint/utils/hdfdict.py
def unlazy(self):
    """Unpack all datasets in place, recursing into nested ``LazyHdfDict`` values.

    Each value is fetched through ``self[key]``, which unpacks and caches any
    ``h5py.Dataset``. Afterwards you can call ``dict(this_instance)`` to get a real
    dict. This replaces a delegation to ``load(self, lazy=False)``. That call failed
    on a dict (``load`` reads ``.attrs``) and discarded its result.

    Returns:
        The same mapping (``self``), for convenience.
    """
    # Snapshot the keys: reading a dataset writes the unpacked value back.
    for key in list(self.keys()):
        value = self[key]
        if isinstance(value, LazyHdfDict):
            value.unlazy()
    return self
close()

Closes the h5file if provided at initialization.

Source code in src/squint/utils/hdfdict.py
def close(self):
    """Close the file handle given at construction, if there is one.

    Nothing happens when no handle was given, or when the handle is falsy
    (e.g. an already closed h5py file) or has no ``close`` method.
    """
    handle = self._h5file
    if not handle or not hasattr(handle, "close"):
        return
    handle.close()

hdf_file(hdf, lazy=True, *args, **kwargs)

Context manager that yields an h5 file if hdf is a string, otherwise it yields hdf as is.

Source code in src/squint/utils/hdfdict.py
@contextmanager
def hdf_file(hdf, lazy=True, *args, **kwargs):
    """Yield an open h5py file for a path, or ``hdf`` itself for anything else.

    For a ``str``/``Path`` argument, the file is opened with ``h5py.File(hdf, *args,
    **kwargs)``:

    * ``lazy=False``: the file is closed when the ``with`` block exits.
    * ``lazy=True``: the file is left open, because lazily loaded datasets are
      read later on item access. Whoever keeps the handle must close it.
    """
    if not isinstance(hdf, (str, Path)):
        yield hdf
        return
    if lazy:
        yield h5py.File(hdf, *args, **kwargs)
        return
    with h5py.File(hdf, *args, **kwargs) as handle:
        yield handle

unpack_dataset(item)

Reconstruct a hdfdict dataset. Special unpacking is only applied to the datetime, yaml, list, tuple and str types.

Parameters

item : h5py.Dataset

Returns

value : the unpacked data.

Source code in src/squint/utils/hdfdict.py
def unpack_dataset(item):
    """Reconstruct a hdfdict dataset.
    Special unpacking is done for datetime, yaml, list, tuple and str types,
    selected by the dataset's ``TYPE`` attribute.

    Parameters
    ----------
    item : h5py.Dataset

    Returns
    -------
    value : Unpacked data

    """

    def _as_text(raw):
        # Depending on the h5py version and storage type, strings come back as
        # ``bytes``/``numpy.bytes_`` or as ``str``. Normalize both to ``str``.
        # Calling ``.astype``/``.decode`` directly breaks on one of the two.
        if isinstance(raw, bytes):
            return raw.decode("utf-8")
        return str(raw)

    value = item[()]
    raw_type = item.attrs.get(TYPE)
    type_id = "" if raw_type is None else _as_text(raw_type)

    if type_id == "datetime":
        if hasattr(value, "__iter__"):
            value = [datetime.fromtimestamp(ts) for ts in value]
        else:
            value = datetime.fromtimestamp(value)

    elif type_id == "yaml":
        value = yaml.safe_load(_as_text(value))

    elif type_id == "list":
        value = list(value)

    elif type_id == "tuple":
        value = tuple(value)

    elif type_id == "str":
        value = _as_text(value)

    return value

load(hdf, lazy=True, unpacker=unpack_dataset, mode='r', *args, **kwargs)

Returns a dictionary containing the groups as keys and the datasets as values from given hdf file.

Parameters

hdf : string (path to file), `h5py.File()` or `h5py.Group()` — the source to read.
lazy : bool — if True, each dataset is loaded at the moment its item is requested.
unpacker : callable — unpack function that receives a value of type `h5py.Dataset` and must return the data as it should appear in the returned dict.
mode : str — file read mode. Default: 'r'.

Returns

d : dict The dictionary containing all groupnames as keys and datasets as values, with group/file attributes under 'attrs'.

Source code in src/squint/utils/hdfdict.py
def load(hdf, lazy=True, unpacker=unpack_dataset, mode="r", *args, **kwargs):
    """Returns a dictionary containing the
    groups as keys and the datasets as values
    from given hdf file.

    Parameters
    ----------
    hdf : string (path to file) or `h5py.File()` or `h5py.Group()`
    lazy : bool
        If True, the datasets are lazy loaded at the moment an item is requested.
        When `hdf` is a path, the opened file stays open and is closed by the
        returned `LazyHdfDict`. A File/Group passed in by the caller remains
        owned (and must be closed) by the caller.
    unpacker : callable
        Unpack function gets `value` of type h5py.Dataset.
        Must return the data you would like to have it in the returned dict.
    mode : str
        File read mode. Default: 'r'.

    Returns
    -------
    d : dict
        The dictionary containing all groupnames as keys and
        datasets as values, with group/file attributes under 'attrs'.
    """
    # Only a file opened here (from a path) may be closed by the returned dict.
    # Otherwise garbage-collecting the dict would close the caller's file.
    opened_here = isinstance(hdf, (str, Path))

    def _decode_attr(v):
        # Decode bytes attributes to str; keep undecodable bytes unchanged.
        if isinstance(v, bytes):
            try:
                return v.decode("utf-8")
            except UnicodeDecodeError:
                return v
        return v

    def _recurse(hdfobject, datadict):
        attrs = {k: _decode_attr(v) for k, v in hdfobject.attrs.items()}
        if attrs:
            datadict["attrs"] = attrs

        for key, value in hdfobject.items():
            if isinstance(value, h5py.Group):
                child = LazyHdfDict() if lazy else {}
                datadict[key] = _recurse(value, child)
            elif isinstance(value, h5py.Dataset):
                # In lazy mode the raw dataset is stored; LazyHdfDict unpacks it
                # on first access.
                datadict[key] = value if lazy else unpacker(value)

        return datadict

    with hdf_file(hdf, lazy=lazy, mode=mode, *args, **kwargs) as handle:
        if lazy:
            data = LazyHdfDict(_h5file=handle if opened_here else None)
        else:
            data = {}
        return _recurse(handle, data)

pack_dataset(hdfobject, key, value)

Packs a given key value pair into a dataset in the given hdfobject.

Source code in src/squint/utils/hdfdict.py
def pack_dataset(hdfobject, key, value):
    """Packs a given key value pair into a dataset in the given hdfobject.

    A ``datetime``, or a non-empty sized sequence of them, is stored as POSIX
    timestamps. Lists, tuples and strings are tagged through the ``TYPE``
    attribute so that ``unpack_dataset`` can restore them. A value that
    ``create_dataset`` rejects with ``TypeError``/``ValueError`` is stored as a
    YAML string tagged ``"yaml"``.
    """

    def _is_datetime_sequence(obj):
        """True for a non-empty, sized, non-string iterable made only of datetimes."""
        if isinstance(obj, (str, bytes)) or not hasattr(obj, "__iter__"):
            return False
        try:
            if len(obj) == 0:
                # ``all()`` over an empty sequence is True. Without this guard,
                # empty lists and strings would be tagged as datetimes.
                return False
        except TypeError:
            # Unsized iterables (generators, 0-d numpy arrays) are left alone
            # instead of being consumed or iterated (iterating a 0-d array raises).
            return False
        return all(isinstance(item, datetime) for item in obj)

    isdt = False
    if isinstance(value, datetime):
        value = value.timestamp()
        isdt = True
    elif _is_datetime_sequence(value):
        value = [item.timestamp() for item in value]
        isdt = True

    try:
        ds = hdfobject.create_dataset(name=key, data=value)
    except (TypeError, ValueError):
        # Obviously the data was not serializable. To give it
        # a last try; serialize it to yaml
        # and save it to the hdf file:
        ds = hdfobject.create_dataset(name=key, data=str_(yaml.safe_dump(value)))
        ds.attrs.create(name=TYPE, data=str_("yaml"))
        return

    # Tagging happens outside the ``try``. An attribute error must not trigger
    # the YAML fallback, which would try to re-create the existing dataset.
    if isdt:
        attr_data = "datetime"
    elif isinstance(value, list):
        attr_data = "list"
    elif isinstance(value, tuple):
        attr_data = "tuple"
    elif isinstance(value, str):
        attr_data = "str"
    else:
        attr_data = None

    if attr_data:
        ds.attrs.create(name=TYPE, data=str_(attr_data))

dump(data, hdf, packer=pack_dataset, mode='w', *args, **kwargs)

Adds keys of given dict as groups and values as datasets to the given hdf-file (by string or object) or group object.

Parameters

data : dict — the dictionary containing only string keys and data values, or nested dicts.
hdf : string (path to file), `h5py.File()` or `h5py.Group()` — the destination.
packer : callable — called as `packer(hdfobject, key, value)`, where `hdfobject` is an `h5py.File` or `h5py.Group`, `key` is the name of the dataset, and `value` is the data to be packed (it must be accepted by h5py).
mode : str — file write mode. Default: 'w'.

Returns

None — the data is written into `hdf`; `dump` does not return the file or group object.

Source code in src/squint/utils/hdfdict.py
def dump(data, hdf, packer=pack_dataset, mode="w", *args, **kwargs):
    """Adds keys of given dict as groups and values as datasets
    to the given hdf-file (by string or object) or group object.

    A key named ``"attrs"`` in any (nested) dict is written as HDF5 attributes of
    the corresponding file/group rather than as a dataset. ``data`` itself is not
    modified.

    Parameters
    ----------
    data : dict
        The dictionary containing only string keys and
        data values or dicts again.
    hdf : string (path to file) or `h5py.File()` or `h5py.Group()`
    packer : callable
        Callable gets `hdfobject, key, value` as input.
        `hdfobject` is considered to be either a h5py.File or a h5py.Group.
        `key` is the name of the dataset.
        `value` is the dataset to be packed and accepted by h5py.
    mode : str
        File write mode. Default: 'w'

    Returns
    -------
    None
        The data is written into `hdf`. When `hdf` is a path, the file is
        closed before returning.

    Raises
    ------
    TypeError
        If an attribute value is not a str, int or float.
    """

    def _write_attrs(attrs, hdfobject):
        """Store ``attrs`` on ``hdfobject``, encoding strings as UTF-8 bytes."""
        for k, v in attrs.items():
            if isinstance(v, str):
                hdfobject.attrs[k] = v.encode("utf-8")
            elif isinstance(v, (int, float)):
                hdfobject.attrs[k] = v
            else:
                raise TypeError(
                    f"Attribute '{k}' must be str, int, or float, not {type(v)}"
                )

    def _recurse(datadict, hdfobject):
        """Write one mapping level: attributes first, then groups and datasets."""
        is_mapping = isinstance(datadict, (dict, LazyHdfDict))
        # Read "attrs" without popping it; the old ``pop`` silently stripped the
        # attributes from the caller's dict.
        attrs = datadict.get("attrs") if is_mapping else None
        if attrs is not None:
            _write_attrs(attrs, hdfobject)
        for key, value in datadict.items():
            if is_mapping and key == "attrs":
                continue  # already stored as HDF5 attributes
            if isinstance(value, (dict, LazyHdfDict)):
                _recurse(value, hdfobject.create_group(key))
            else:
                packer(hdfobject, key, value)

    # With lazy=False, hdf_file opens a path inside a context manager, so the file
    # is flushed and closed once writing finishes. The default lazy=True would
    # leave the written file open.
    with hdf_file(hdf, lazy=False, mode=mode, *args, **kwargs) as handle:
        _recurse(data, handle)

io

IO

The IO class encapsulates all saving/loading features of data, figures, etc. This provides consistent filetypes, naming conventions, etc.

Attributes:

Name Type Description
default_path Path

The default path where the data is stored.

path Path

The path where the data is stored.

verbose bool

A flag indicating whether to print out the path of each saved/loaded file.

Typical usage

io = IO(path=r"\path\to\data")
io.load_txt(filename="filename.txt")

or

io = IO(folder="subfolder", include_date=True, include_id=True)
io.save_dataframe(df, filename="dataframe")

Source code in src/squint/utils/io.py
class IO:
    r"""
    The IO class encapsulates all saving/loading features of data, figures, etc.
    This provides consistent filetypes, naming conventions, etc.

    Attributes:
        default_path (pathlib.Path | str): Root directory used when no ``path`` is
            given. It is the ``DATA_PATH`` environment variable if set, otherwise
            ``pathlib.Path(__file__).parent.parent / "data"``.
        path (pathlib.Path): The path where the data is stored.
        verbose (bool): A flag indicating whether to print out the path of each saved/loaded file.

    Typical usage:
        io = IO(path=r"\path\to\data", folder="subfolder")
        io.load_txt(filename="filename.txt")

    or
        io = IO(folder="subfolder", include_date=True, include_id=True)
        io.save_dataframe(df, filename="dataframe")
    """

    default_path = os.getenv(
        "DATA_PATH", pathlib.Path(__file__).parent.parent.joinpath("data")
    )

    def __init__(
        self,
        path=None,
        folder="",
        include_date=False,
        include_time=False,
        include_id=False,
        verbose=True,
    ):
        """
        Build the storage directory path. The directory itself is created lazily
        by the ``save_*`` methods.

        The folder name has the form ``[<date>_][<HH-MM-SS>_]<folder>[_<4 hex chars>]``.

        Args:
            path: Root directory (str or path-like). Defaults to ``IO.default_path``.
            folder (str): Sub-folder name. If empty, a warning is issued, and the
                date and a random id are included (``verbose`` is forced on).
            include_date (bool): Prefix the folder name with today's ISO date.
            include_time (bool): Prefix the folder name with the current time.
            include_id (bool): Suffix the folder name with 4 random hex characters.
            verbose (bool): Print a message for every saved/loaded file.
        """
        # Accept str (e.g. DATA_PATH from the environment) or any path-like root.
        root = pathlib.Path(self.default_path if path is None else path)

        if not folder:  # if empty string
            warnings.warn(
                "No folder entered. Saving to a folder with a unique identifier",
                stacklevel=2,
            )
            include_date, include_id, verbose = True, True, True

        name = ""
        if include_date:
            name += datetime.date.today().isoformat() + "_"
        if include_time:
            name += datetime.datetime.now().strftime("%H-%M-%S") + "_"
        name += folder
        if include_id:
            name += "_" + "".join(random.choice(string.hexdigits) for _ in range(4))

        self.path = root.joinpath(name)
        self.verbose = verbose

    def subpath(self, subfolder: str):
        """Return a deep copy of this IO whose ``path`` is extended by ``subfolder``."""
        clone = copy.deepcopy(self)
        clone.path = clone.path.joinpath(subfolder)
        return clone

    def save_json(self, variable, filename):
        """
        Save serialized python object into a json format, at filename

        Args:
            variable: The object to save.
            filename (str): Name of the file to which variable should be saved.

        Returns:
            None
        """
        full_path = self.path.joinpath(filename)
        os.makedirs(full_path.parent, exist_ok=True)
        self._save_json(variable, full_path)
        if self.verbose:
            print(f"{current_time()} | Saved to {full_path} successfully.")

    def load_json(self, filename):
        """
        Load serialized python object from json.

        Args:
            filename (str): Name of the file from which we are loading the object.

        Returns:
            The loaded object data.
        """
        full_path = self.path.joinpath(filename)
        file = self._load_json(full_path)
        if self.verbose:
            print(f"{current_time()} | Loaded from {full_path} successfully.")
        return file

    def save_txt(self, variable, filename):
        """
        Save a string into a text file, at filename.

        Args:
            variable (str): The text to save.
            filename (str): Name of the file to which variable should be saved.

        Returns:
            None
        """
        full_path = self.path.joinpath(filename)
        os.makedirs(full_path.parent, exist_ok=True)
        self._save_txt(variable, full_path)
        if self.verbose:
            print(f"{current_time()} | Saved to {full_path} successfully.")

    def load_txt(self, filename):
        """
        Load the contents of a text file as a string.

        Args:
            filename (str): Name of the file from which we are loading the object.

        Returns:
            str: The file contents.
        """
        full_path = self.path.joinpath(filename)
        file = self._load_txt(full_path)
        if self.verbose:
            print(f"{current_time()} | Loaded from {full_path} successfully.")
        return file

    def save_dataframe(self, df, filename):
        """
        Save a pandas dataframe object as a pickle file.

        The ``.pkl`` extension is appended to ``filename``.

        Args:
            df (pandas.DataFrame): Data contained in a dataframe.
            filename (str): File name (without extension) to which data should be saved.

        Returns:
            None
        """
        ext = ".pkl"
        full_path = self.path.joinpath(filename + ext)
        os.makedirs(full_path.parent, exist_ok=True)
        df.to_pickle(str(full_path))
        if self.verbose:
            print(f"{current_time()} | Saved to {full_path} successfully.")

    def load_dataframe(self, filename):
        """
        Load a pandas dataframe object from a pickle file written by ``save_dataframe``.

        The ``.pkl`` extension is appended to ``filename``.

        Args:
            filename (str): File name (without extension) from which data should be loaded.

        Returns:
            pandas.DataFrame: Dataframe data.
        """
        import pandas as pd

        ext = ".pkl"
        full_path = self.path.joinpath(filename + ext)
        # SECURITY: unpickling can execute arbitrary code. Only load files you trust.
        df = pd.read_pickle(str(full_path))
        if self.verbose:
            print(f"{current_time()} | Loaded from {full_path} successfully.")
        return df

    def save_yaml(self, data, filename):
        """
        Save a dictionary or a dataclass instance to a YAML file.

        Args:
            data: A ``dict``, or an object accepted by ``asdict`` (e.g. a dataclass).
            filename (str): Name of the file to which data should be saved.
        """
        # Convert before opening the file, so a failed conversion does not leave
        # behind an empty/truncated file. Plain dicts are written as-is.
        _data = data if isinstance(data, dict) else asdict(data)

        full_path = self.path.joinpath(filename)
        os.makedirs(full_path.parent, exist_ok=True)
        with open(full_path, "w") as fid:
            yaml.dump(_data, fid)

        if self.verbose:
            print(f"{current_time()} | Saved to {full_path} successfully.")

    def save_figure(self, fig, filename):
        """
        Save a figure (image datatype can be specified as part of filename).

        Args:
            fig (matplotlib.figure.Figure): The figure containing the figure to save.
            filename (str): The filename to which we save a figure.

        Returns:
            None
        """
        full_path = self.path.joinpath(filename)
        os.makedirs(full_path.parent, exist_ok=True)
        fig.savefig(full_path, dpi=300, bbox_inches="tight")
        if self.verbose:
            print(f"{current_time()} | Saved figure to {full_path} successfully.")

    def save_np_array(self, np_arr, filename):
        """
        Save numpy array to a text document.

        Args:
            np_arr (numpy.array): The array which we are saving.
            filename (str): Name of the text file to which we want to save the numpy array.

        Returns:
            None
        """
        import numpy as np

        full_path = self.path.joinpath(filename)
        os.makedirs(full_path.parent, exist_ok=True)
        np.savetxt(str(full_path), np_arr)
        if self.verbose:
            print(f"{current_time()} | Saved to {full_path} successfully.")

    def load_np_array(self, filename, complex_vals=False):
        """
        Loads numpy array from a text document.

        Args:
            filename (str): Name of the text file from which we want to load the numpy array.
            complex_vals (bool): True if we expect the numpy array to be complex, False otherwise.

        Returns:
            numpy.array: The loaded numpy array.
        """
        import numpy as np

        full_path = self.path.joinpath(filename)
        # Builtin complex/float. The aliases np.complex/np.float were removed
        # in NumPy 1.24.
        file = np.loadtxt(str(full_path), dtype=complex if complex_vals else float)
        if self.verbose:
            print(f"{current_time()} | Loaded from {full_path} successfully.")
        return file

    def save_csv(self, df, filename):
        """
        Save a pandas dataframe object to csv.

        The ``.csv`` extension is appended to ``filename``.

        Args:
            df (pandas.DataFrame): Data contained in a dataframe.
            filename (str): File name (without extension) to which data should be saved.

        Returns:
            None
        """
        ext = ".csv"
        full_path = self.path.joinpath(filename + ext)
        os.makedirs(full_path.parent, exist_ok=True)
        df.to_csv(str(full_path), sep=",", index=False, header=True)
        if self.verbose:
            print(f"{current_time()} | Saved to {full_path} successfully.")

    def load_csv(self, filename):
        """
        Load pandas dataframe object from CSV.

        Unlike ``save_csv``, no extension is appended: pass the full file name.

        Args:
            filename (str): Name of the file from which data should be loaded.

        Returns:
            pandas.DataFrame: Dataframe data.
        """
        import pandas as pd

        full_path = self.path.joinpath(filename)
        df = pd.read_csv(str(full_path), sep=",", header=0)
        if self.verbose:
            print(f"{current_time()} | Loaded from {full_path} successfully.")
        return df

    def save_h5(self, filename):
        """
        Initialize an H5 file to save datasets into.

        The file is opened in ``"w"`` mode (truncating any existing file). The
        caller is responsible for closing it, e.g. by using it in a ``with`` block.

        Args:
            filename (str): Name of the file to which data should be saved.

        Returns:
            h5py.File: The open H5 file.
        """
        full_path = self.path.joinpath(filename)
        os.makedirs(full_path.parent, exist_ok=True)
        hf = h5py.File(full_path, "w")
        if self.verbose:
            print(f"{current_time()} | Saving HDF5 file at {full_path}.")
        return hf

    @staticmethod
    def _save_json(variable, path):
        """
        Helper method for saving to json files
        """
        with open(path, "w+") as json_file:
            json.dump(variable, json_file, indent=4)

    @staticmethod
    def _load_json(path):
        """
        Helper method for loading from json files
        """
        with open(path) as json_file:
            data = json.load(json_file)
        return data

    @staticmethod
    def _save_txt(variable, path):
        """
        Helper method for saving to text files
        """
        with open(path, "w") as txt_file:
            txt_file.write(variable)

    @staticmethod
    def _load_txt(path):
        """
        Helper method for loading from text files
        """
        with open(path) as txt_file:
            txt_str = txt_file.read()
        return txt_str
save_json(variable, filename)

Save serialized python object into a json format, at filename

Parameters:

Name Type Description Default
variable

The object to save.

required
filename str

Name of the file to which variable should be saved.

required

Returns:

Type Description

None

Source code in src/squint/utils/io.py
def save_json(self, variable, filename):
    """
    Serialize ``variable`` as JSON into ``filename`` below ``self.path``.

    Missing parent directories are created before writing.

    Args:
        variable: The object to save.
        filename (str): Name of the file to which variable should be saved.

    Returns:
        None
    """
    target = self.path / filename
    target.parent.mkdir(parents=True, exist_ok=True)
    self._save_json(variable, target)
    if not self.verbose:
        return
    print(f"{current_time()} | Saved to {target} successfully.")
load_json(filename)

Load serialized python object from json.

Parameters:

Name Type Description Default
filename str

Name of the file from which we are loading the object.

required

Returns:

Type Description

The loaded object data.

Source code in src/squint/utils/io.py
def load_json(self, filename):
    """
    Read and return the JSON content of ``filename`` below ``self.path``.

    Args:
        filename (str): Name of the file from which we are loading the object.

    Returns:
        Whatever ``self._load_json`` decodes from the file.
    """
    source = self.path / filename
    data = self._load_json(source)
    if self.verbose:
        print(f"{current_time()} | Loaded from {source} successfully.")
    return data
save_txt(variable, filename)

Save serialized python object into a text format, at filename.

Parameters:

Name Type Description Default
variable

The object to save.

required
filename str

Name of the file to which variable should be saved.

required

Returns:

Type Description

None

Source code in src/squint/utils/io.py
def save_txt(self, variable, filename):
    """
    Write ``variable`` as text to ``filename`` below ``self.path``.

    Missing parent directories are created before writing.

    Args:
        variable: The object to save (handed to ``self._save_txt``).
        filename (str): Name of the file to which variable should be saved.

    Returns:
        None
    """
    destination = self.path / filename
    destination.parent.mkdir(parents=True, exist_ok=True)
    self._save_txt(variable, destination)
    if not self.verbose:
        return
    print(f"{current_time()} | Saved to {destination} successfully.")
load_txt(filename)

Load serialized python object from text file.

Parameters:

Name Type Description Default
filename str

Name of the file from which we are loading the object.

required

Returns:

Type Description

The loaded object data.

Source code in src/squint/utils/io.py
def load_txt(self, filename):
    """
    Return the text content of ``filename`` below ``self.path``.

    Args:
        filename (str): Name of the file from which we are loading the object.

    Returns:
        Whatever ``self._load_txt`` reads from the file.
    """
    source = self.path / filename
    text = self._load_txt(source)
    if self.verbose:
        print(f"{current_time()} | Loaded from {source} successfully.")
    return text
save_dataframe(df, filename)

Save a pandas dataframe object as a pickle file (`.pkl` is appended to the filename).

Parameters:

Name Type Description Default
df DataFrame

Data contained in a dataframe.

required
filename str

File to which data should be saved.

required

Returns:

Type Description

None

Source code in src/squint/utils/io.py
def save_dataframe(self, df, filename):
    """
    Save a pandas DataFrame to a pickle file.

    The ``.pkl`` extension is appended to ``filename``, and missing parent
    directories are created. (Use ``save_csv`` for a CSV file instead.)

    Args:
        df (pandas.DataFrame): Data contained in a dataframe.
        filename (str): File name without extension, relative to ``self.path``,
            to which data should be saved.

    Returns:
        None
    """
    ext = ".pkl"
    full_path = self.path.joinpath(filename + ext)
    os.makedirs(full_path.parent, exist_ok=True)
    df.to_pickle(str(full_path))
    if self.verbose:
        print(f"{current_time()} | Saved to {full_path} successfully.")
load_dataframe(filename)

Load a pandas DataFrame from a pickle file (`.pkl` is appended to `filename`).

Parameters:

Name Type Description Default
filename str

Name of the file from which data should be loaded.

required

Returns:

Type Description

pandas.DataFrame: Dataframe data.

Source code in src/squint/utils/io.py
def load_dataframe(self, filename):
    """
    Load a pandas DataFrame from a pickle file.

    The ``.pkl`` extension is appended to ``filename``.

    Security note: unpickling can execute arbitrary code; only load files that
    this project wrote itself (e.g. via ``save_dataframe``) or that come from a
    trusted source.

    Args:
        filename (str): File name without extension, relative to ``self.path``,
            from which data should be loaded.

    Returns:
        pandas.DataFrame: Dataframe data.
    """
    import pandas as pd

    ext = ".pkl"
    full_path = self.path.joinpath(filename + ext)
    df = pd.read_pickle(str(full_path))
    if self.verbose:
        print(f"{current_time()} | Loaded from {full_path} successfully.")
    return df
save_yaml(data, filename)

Save dictionary to YAML file.

Parameters:

Name Type Description Default
filename str

Name of the file to which data should be saved.

required
Source code in src/squint/utils/io.py
def save_yaml(self, data, filename):
    """
    Save a dictionary or dataclass instance to a YAML file.

    Args:
        data (dict | dataclass instance): The data to serialize. Dataclass
            instances are converted with ``dataclasses.asdict`` first; plain
            dicts are dumped as-is.
        filename (str): Name of the file, relative to ``self.path``, to which
            data should be saved.

    Returns:
        None

    Raises:
        TypeError: If ``data`` is neither a dict nor a dataclass instance.
    """
    from dataclasses import asdict, is_dataclass

    # Convert *before* opening the file so that invalid input raises without
    # truncating/creating the target file.
    if isinstance(data, dict):
        _data = data
    elif is_dataclass(data) and not isinstance(data, type):
        _data = asdict(data)
    else:
        raise TypeError(
            f"save_yaml expects a dict or dataclass instance, got {type(data).__name__}"
        )

    full_path = self.path.joinpath(filename)
    os.makedirs(full_path.parent, exist_ok=True)
    with open(full_path, "w") as fid:
        yaml.dump(_data, fid)

    if self.verbose:
        print(f"{current_time()} | Saved to {full_path} successfully.")
save_figure(fig, filename)

Save a figure (the image format can be specified via the extension in filename).

Parameters:

Name Type Description Default
fig Figure

The figure to save.

required
filename str

The filename to which we save a figure.

required

Returns:

Type Description

None

Source code in src/squint/utils/io.py
def save_figure(self, fig, filename):
    """
    Write a figure to disk at 300 dpi with a tight bounding box.

    The image format follows from the extension in ``filename``; missing parent
    directories are created first.

    Args:
        fig (matplotlib.figure.Figure): The figure to save.
        filename (str): Output file name, relative to ``self.path``.

    Returns:
        None
    """
    image_path = self.path.joinpath(filename)
    os.makedirs(image_path.parent, exist_ok=True)
    fig.savefig(image_path, dpi=300, bbox_inches="tight")
    if not self.verbose:
        return
    print(f"{current_time()} | Saved figure to {image_path} successfully.")
save_np_array(np_arr, filename)

Save numpy array to a text document.

Parameters:

Name Type Description Default
np_arr array

The array which we are saving.

required
filename str

Name of the text file to which we want to save the numpy array.

required

Returns:

Type Description

None

Source code in src/squint/utils/io.py
def save_np_array(self, np_arr, filename):
    """
    Write a numpy array to a plain-text file via ``numpy.savetxt``.

    Missing parent directories are created first.

    Args:
        np_arr (numpy.ndarray): The array to write.
        filename (str): Name of the text file, relative to ``self.path``.

    Returns:
        None
    """
    import numpy as np

    out_file = self.path.joinpath(filename)
    os.makedirs(out_file.parent, exist_ok=True)
    np.savetxt(str(out_file), np_arr)
    if not self.verbose:
        return
    print(f"{current_time()} | Saved to {out_file} successfully.")
load_np_array(filename, complex_vals=False)

Loads numpy array from a text document.

Parameters:

Name Type Description Default
filename str

Name of the text file from which we want to load the numpy array.

required
complex_vals bool

True if we expect the numpy array to be complex, False otherwise.

False

Returns:

Type Description

numpy.array: The loaded numpy array.

Source code in src/squint/utils/io.py
def load_np_array(self, filename, complex_vals=False):
    """
    Load a numpy array from a text document.

    Args:
        filename (str): Name of the text file, relative to ``self.path``, from
            which the numpy array is loaded.
        complex_vals (bool): True if the stored values are complex, False
            (default) to parse them as real floats.

    Returns:
        numpy.ndarray: The loaded array (``complex128`` when ``complex_vals`` is
        True, otherwise ``float64``).
    """
    import numpy as np

    full_path = self.path.joinpath(filename)
    # `np.complex` / `np.float` were mere aliases of the builtins and were
    # removed in NumPy 1.24 (AttributeError); the builtins map to
    # complex128 / float64.
    file = np.loadtxt(str(full_path), dtype=complex if complex_vals else float)
    if self.verbose:
        print(f"{current_time()} | Loaded from {full_path} successfully.")
    return file
save_csv(df, filename)

Save a pandas DataFrame to a CSV file (`.csv` is appended to `filename`).

Parameters:

Name Type Description Default
df DataFrame

Data contained in a dataframe.

required
filename str

File to which data should be saved.

required

Returns:

Type Description

None

Source code in src/squint/utils/io.py
def save_csv(self, df, filename):
    """
    Write a pandas DataFrame to a comma-separated file.

    ``.csv`` is appended to ``filename``; the header row is written and the
    index is not. Missing parent directories are created first.

    Args:
        df (pandas.DataFrame): The data to write.
        filename (str): File name without extension, relative to ``self.path``.

    Returns:
        None
    """
    csv_path = self.path.joinpath(filename + ".csv")
    os.makedirs(csv_path.parent, exist_ok=True)
    df.to_csv(str(csv_path), sep=",", index=False, header=True)
    if not self.verbose:
        return
    print(f"{current_time()} | Saved to {csv_path} successfully.")
load_csv(filename)

Load a pandas DataFrame from a CSV file.

Parameters:

Name Type Description Default
filename str

Name of the file from which data should be loaded.

required

Returns:

Type Description

pandas.DataFrame: Dataframe data.

Source code in src/squint/utils/io.py
def load_csv(self, filename):
    """
    Load a pandas DataFrame from a CSV file.

    ``save_csv`` appends ``.csv`` to the name it is given. For symmetry, this
    method accepts that same extension-less name: if ``filename`` does not exist
    as given but ``filename + ".csv"`` does, the latter is loaded. A name that
    exists as given is always used unchanged.

    Args:
        filename (str): Name of the CSV file relative to ``self.path``, with or
            without the ``.csv`` extension.

    Returns:
        pandas.DataFrame: Dataframe data (the first row is used as the header).

    Raises:
        FileNotFoundError: If neither ``filename`` nor ``filename + ".csv"``
            exists.
    """
    import pandas as pd

    full_path = self.path.joinpath(filename)
    if not full_path.exists():
        # Fall back to the name `save_csv` would have produced.
        with_ext = self.path.joinpath(str(filename) + ".csv")
        if with_ext.exists():
            full_path = with_ext
    df = pd.read_csv(str(full_path), sep=",", header=0)
    if self.verbose:
        print(f"{current_time()} | Loaded from {full_path} successfully.")
    return df
save_h5(filename)

Initialize an H5 file to save datasets into.

Parameters:

Name Type Description Default
filename str

Name of the file to which data should be saved.

required

Returns:

Type Description

h5py.File: H5 file.

Source code in src/squint/utils/io.py
def save_h5(self, filename):
    """
    Open a new HDF5 file in write mode (``"w"``) and return its handle.

    Missing parent directories are created first. The handle is returned open;
    closing it is the caller's responsibility.

    Args:
        filename (str): Name of the file, relative to ``self.path``, to which
            data should be saved.

    Returns:
        h5py.File: The open H5 file.
    """
    h5_path = self.path.joinpath(filename)
    os.makedirs(h5_path.parent, exist_ok=True)
    handle = h5py.File(h5_path, "w")
    if self.verbose:
        print(f"{current_time()} | Saving HDF5 file at {h5_path}.")
    return handle

current_time()

Returns the current date and time in a consistent format.

This function is used for monitoring long-running measurements by providing the current date and time in the "%d/%m/%Y, %H:%M:%S" format.

Returns:

Name Type Description
str

The current date and time as a string in the "%d/%m/%Y, %H:%M:%S" format.

Source code in src/squint/utils/io.py
def current_time():
    """
    Return the current local date and time as a fixed-format string.

    Intended for timestamping progress messages during long-running
    measurements.

    Returns:
        str: The current date and time formatted as "%d/%m/%Y, %H:%M:%S"
        (for example "31/12/2024, 23:59:59").
    """
    timestamp_format = "%d/%m/%Y, %H:%M:%S"
    now = datetime.datetime.now()
    return now.strftime(timestamp_format)

partition

partition_by_leaves(pytree, leaves_to_param)

Partition a PyTree into parameters and static parts based on specified leaves.

Parameters:

- pytree (PyTree): The input PyTree containing parameters and static parts.
- leaves_to_param (list): A list of leaves that should be treated as parameters.

Returns:

- tuple: A tuple containing two PyTrees: the parameters and the static parts.

Example

>>> import equinox as eqx
>>> leaves = [pytree.ops['phase'].phi, pytree.ops['phase'].epsilon]
>>> params, static = partition_by_leaves(pytree, leaves)

Source code in src/squint/utils/partition.py
def partition_by_leaves(pytree, leaves_to_param):
    """
    Split a PyTree into parameter and static parts by leaf identity.

    A leaf of ``pytree`` is a parameter exactly when it is the very same object
    as one of the entries in ``leaves_to_param``.

    Args:
        pytree (PyTree): The input PyTree containing parameters and static parts.
        leaves_to_param (list): The leaf objects that should be treated as
            parameters.

    Returns:
        tuple: ``(params, static)`` as returned by ``eqx.partition``.

    Example:
        >>> import equinox as eqx
        >>> leaves = [pytree.ops['phase'].phi, pytree.ops['phase'].epsilon]
        >>> params, static = partition_by_leaves(pytree, leaves)
    """
    # Match by object identity (`id`), not by value, so two distinct leaves that
    # happen to hold equal values are not confused with each other.
    wanted_ids = {id(leaf) for leaf in leaves_to_param}

    def is_param(leaf):
        return id(leaf) in wanted_ids

    return eqx.partition(pytree, is_param)

partition_by_branches(pytree, branches_to_param)

Partition a PyTree into parameters and static parts based on specified branch nodes.

All leaves that are descendants of any branch in branches_to_param will be treated as parameters; the rest will be considered static.

Parameters:

Name Type Description Default
pytree PyTree

The input PyTree.

required
branches_to_param list

A list of subtree objects whose leaves should be treated as parameters.

required

Returns:

Type Description

(params_pytree, static_pytree)

Source code in src/squint/utils/partition.py
def partition_by_branches(pytree, branches_to_param):
    """
    Partition a PyTree into parameters and static parts based on specified branch nodes.

    All leaves that are descendants of any branch in `branches_to_param` will be treated
    as parameters; the rest will be considered static. Membership is decided by
    object identity: a leaf of `pytree` is a parameter when it is the same object
    as a leaf of one of the branches.

    Args:
        pytree (PyTree): The input PyTree.
        branches_to_param (list): A list of subtree objects whose leaves should be treated as parameters.

    Returns:
        (params_pytree, static_pytree)
    """
    # Flatten each branch once up front instead of re-flattening every branch
    # for every leaf visited by `eqx.partition` (previously O(leaves * branch size)).
    # The ids stay valid for the whole call because `branches_to_param` keeps
    # those leaf objects alive.
    # NOTE(review): identity matching can alias Python-cached scalars (e.g. small
    # ints) that appear both inside and outside the branches; same as before.
    param_leaf_ids = {
        id(leaf)
        for branch in branches_to_param
        for leaf in jax.tree_util.tree_leaves(branch)
    }

    def is_param(leaf):
        return id(leaf) in param_leaf_ids

    return eqx.partition(pytree, is_param)

partition_op(pytree: PyTree, name: Union[str, Sequence[str]])

Partition a PyTree into parameters and static parts based on the operation name key.

Parameters:

- pytree (PyTree): The input PyTree containing operations.
- name (str): The operation name key to filter by.

Source code in src/squint/utils/partition.py
def partition_op(
    pytree: PyTree, name: Union[str, Sequence[str]]
):
    """
    Partition a PyTree into parameters and static parts based on operation name key(s).

    A leaf goes to the parameter partition when it is an inexact array
    (``eqx.is_inexact_array``) AND it sits under ``pytree.ops[n]`` for at least
    one requested name ``n``. Every other leaf goes to the static partition.

    Args:
        pytree (PyTree): The input PyTree containing operations; it must expose
            an ``ops`` container indexed by operation name.
        name (str | Sequence[str]): The operation name key to filter by, or a
            sequence of keys whose selections are combined with logical OR.

    Returns:
        tuple: ``(params, static)`` as returned by ``eqx.partition``.
    """
    # A bare string is one key, not a sequence of one-character keys.
    names = [name] if isinstance(name, str) else list(name)

    def select(pytree: PyTree, name: str):
        """Return a copy of `pytree` with every leaf under `ops[name]` set to `True`."""
        get_leaf = lambda t: t.ops[name]
        null = jax.tree_util.tree_map(lambda _: True, pytree.ops[name])
        return eqx.tree_at(get_leaf, pytree, null)

    def mask(_leaf, is_float, *selected):
        """Per-leaf filter: inexact-array marker AND selected by at least one name."""
        # Always return a real bool. Falling through to an implicit `None` would
        # be read by JAX as an empty subtree rather than a leaf, so the filter
        # spec would no longer match the structure of `pytree`.
        if not (isinstance(is_float, bool) and is_float):
            return False
        return any(isinstance(flag, bool) and flag for flag in selected)

    # Inexact arrays are replaced by the literal `True`; other leaves keep their value.
    is_float_tree = eqx.filter(pytree, eqx.is_inexact_array, inverse=True, replace=True)
    selected_trees = [select(pytree, n) for n in names]

    # NOTE(review): a leaf that is itself the Python bool `True` is
    # indistinguishable from the markers above; confirm pytrees never carry such leaves.
    filter_spec = jax.tree_util.tree_map(mask, pytree, is_float_tree, *selected_trees)

    params, static = eqx.partition(pytree, filter_spec=filter_spec)

    return params, static

extract_paths(obj: PyTree, path='', op_type=None)

Recursively extract paths to non-None leaves in a PyTree, including their operation type.

Source code in src/squint/utils/partition.py
def extract_paths(obj: PyTree, path="", op_type=None):
    """
    Walk a PyTree depth-first and yield ``(path, op_type, leaf)`` for each non-None leaf.

    ``eqx.Module`` instances, dicts, lists and tuples are descended into. Any
    other value is treated as a leaf and is skipped if it is ``None``.

    Args:
        obj: The (sub)tree to walk.
        path (str): Path of ``obj`` relative to the root (``""`` at the root).
            Mapping keys and module fields are joined with ``.``; sequence
            positions are appended as ``[i]``.
        op_type (str | None): Class name of the nearest enclosing ``eqx.Module``,
            inherited by all of its descendants.

    Yields:
        tuple: ``(path, op_type, leaf)``.
    """
    def dotted(key):
        return f"{path}.{key}" if path else str(key)

    if isinstance(obj, eqx.Module):
        # Entering a module resets the operation type (e.g. "BeamSplitter").
        module_name = type(obj).__name__
        for field_name in obj.__dataclass_fields__:
            yield from extract_paths(getattr(obj, field_name), dotted(field_name), module_name)
        return

    if isinstance(obj, dict):
        for key, value in obj.items():
            yield from extract_paths(value, dotted(key), op_type)
        return

    if isinstance(obj, (list, tuple)):
        for index, item in enumerate(obj):
            child = f"{path}[{index}]" if path else f"[{index}]"
            yield from extract_paths(item, child, op_type)
        return

    if obj is not None:
        yield path, op_type, obj