Skip to content
Snippets Groups Projects
TraceWin.py 58.7 KiB
Newer Older
Yngve Levinsen's avatar
Yngve Levinsen committed
            self.Nrun += o.Nrun
    def savetohdf(self, filename="Density.h5", group="TraceWin", force=False):
        """
        Saves data to HDF5

        :param filename: HDF5 file to write to (opened in append mode)
        :param group: name of the HDF5 group the data is stored in
        :param force: if the group already exists, delete and rewrite it.
            Without force an existing group is left untouched and nothing
            is written (a message is printed when Python runs with -d).
        """
        import h5py
        import sys

        # the with-block guarantees the file is closed on every exit path,
        # including the early return below
        with h5py.File(filename, "a") as fout:
            if group in fout:
                if force:
                    del fout[group]
                else:
                    if sys.flags.debug:
                        print("Group {} already exist in {}".format(group, filename))
                    return
            grp = fout.create_group(group)

            # header attributes..
            grp.attrs["version"] = self.version
            grp.attrs["year"] = self.year
            grp.attrs["Nrun"] = self.Nrun
            grp.attrs["vlong"] = self.vlong

            length = len(self.z)
            # True when particles were tracked; gates the PARTRAN-only datasets
            partran = sum(self.Np) > 0

            # one number per location
            arrays = ["z", "nelp", "ib", "Np", "Xouv", "Youv"]
            array_units = ["m", "", "mA", "", "m", "m"]
            if self.version >= 8:
                arrays += ["energy_accept", "phase_ouv_pos", "phase_ouv_neg"]
                array_units += ["eV", "deg", "deg"]
            if partran:
                arrays += [
                    "lost2",
                    "Minlost",
                    "Maxlost",
                    "powlost2",
                    "Minpowlost",
                    "Maxpowlost",
                ]
                array_units += ["", "", "", "W*W", "W", "W"]

            # 7 numbers per location..
            coordinates = ["moy", "moy2", "_max", "_min"]
            coordinate_units = ["m", "m*m", "m", "m"]
            if self.version >= 5 and partran:
                coordinates += ["rms_size", "rms_size2"]
                coordinate_units += ["m", "m*m"]
            if self.version >= 6 and partran:
                coordinates += ["min_pos_moy", "max_pos_moy"]
                coordinate_units += ["m", "m"]

            def write(name, shape, unit):
                # copy attribute `name` into a float dataset; empty unit -> no attribute
                data_set = grp.create_dataset(name, shape, dtype="f")
                data_set[...] = getattr(self, name)
                if unit:
                    data_set.attrs["unit"] = unit

            for val, unit in zip(arrays, array_units):
                write(val, (length,), unit)

            for val, unit in zip(coordinates, coordinate_units):
                write(val, (length, 7), unit)

            if self.version >= 7 and partran:
                # 3 numbers per location..
                for val, unit in zip(["rms_emit", "rms_emit2"], ["m*rad", "m*m*rad*rad"]):
                    write(val, (length, 3), unit)

            if partran:
                # 1 number per location and per run..
                for val, unit in zip(["lost", "powlost"], ["", "W"]):
                    write(val, (length, self.Nrun), unit)
class density_file(density):
    """
    Deprecated alias of ``density``, kept for backwards compatibility.

    Prints a deprecation notice and otherwise behaves exactly like ``density``.
    """

    def __init__(self, filename, envelope=None):
        notice = "Deprecated, use TraceWin.density() instead"
        print(notice)
        super(density_file, self).__init__(filename, envelope)


class remote_data_merger:
    """
    Merge the error-study results of several partran output files.

    Files are registered with add_file(); generate_partran_out() then builds
    one summary whose first data line holds, per column, the mean over all
    files with the standard deviation in parenthesis. The remaining header
    lines come from the first file, followed by two lines from each file.
    """

    def __init__(self, base="."):
        # directory used to resolve relative paths that do not exist as given
        self._base = base
        # registered files, in insertion order and without duplicates
        self._files = []

    def add_file(self, filepath):
        """
        Register a file to be merged.

        The path is used as given if it exists, otherwise it is looked up
        relative to the base directory passed to the constructor.

        :param filepath: path to the file
        :raises ValueError: if the file cannot be found in either location
        """
        import os

        if os.path.exists(filepath):
            fname = filepath
        else:
            fullpath = os.path.join(self._base, filepath)
            if not os.path.exists(fullpath):
                raise ValueError("Could not find file " + filepath)
            fname = fullpath
        if fname not in self._files:
            self._files.append(fname)

    @staticmethod
    def _format_mean_std(mean, std):
        """Format one column as "mean(std)", or just the mean when std is negligible."""
        scale = abs(mean)
        if scale:
            relative = std / scale
        else:
            # zero mean: any spread at all is significant
            relative = float("inf") if std else 0.0
        # relative spreads below 1e-8 are treated as numerical noise
        if relative > 1e-8:
            return "%f(%f)" % (mean, std)
        return str(mean)

    def generate_partran_out(self, filename=None):
        """
        Create the merged partran output as a list of lines.

        :param filename: if given, the lines are also written to this file,
            joined by newlines
        :return: list of output lines (empty list if no files were added)
        :raises ValueError: if a file does not have the expected layout
        """
        import numpy as np

        if not self._files:
            return []

        h1 = []
        h2 = []
        d1 = []
        d2 = []
        d3 = []

        for fname in self._files:
            with open(fname, "r") as fin:
                split = fin.read().split("$$$")
            # section 9 must be the "Data_Error" tag, section 10 holds the table
            if len(split) < 11 or split[9] != "Data_Error":
                raise ValueError("Magic problem, please complain to Yngve")

            thisdata = split[10].strip().split("\n")

            if not h1:
                h1 = [thisdata[0] + " (std in paranthesis)"]
                h2 = thisdata[2:10]
            d1.append([float(v) for v in thisdata[1].split()])
            d2.append(thisdata[10])
            d3.append(thisdata[11])

        values = np.array(d1)
        means = values.mean(axis=0)
        stds = values.std(axis=0)
        merged = " ".join(self._format_mean_std(m, s) for m, s in zip(means, stds))

        # create data:
        data = h1 + [merged] + h2 + d2 + d3

        if filename:
            with open(filename, "w") as fout:
                fout.write("\n".join(data))
        return data

class envDiag:
    """
    Read ENV_diag1.dat file

    This contains e.g. the absolute phase at each diag

    For now we do not read in all info from the file,
    so feel free to request or add anything else you would like.
    """

    def __init__(self, filename):
        self.filename = filename
        # element number -> dict of the parameters read for that DIAG
        self.elements = {}
        # Needed to get an ordered dictionary:
        self._elementList = []
        self._readAsciiFile()
        self.units = {}
        self._setUnits()

    def _readAsciiFile(self):
        """
        Parse the file into self.elements / self._elementList.

        Every "DIAG" line starts a new element; the keyword lines that
        follow add parameters to it. Blank lines and lines before the
        first DIAG are ignored.
        """
        current = None
        with open(self.filename, "r") as fin:
            for line in fin:
                lsp = line.split()
                if not lsp:
                    continue
                if lsp[0] == "DIAG":
                    num = int(lsp[2])
                    current = {"loc": float(lsp[4])}
                    self.elements[num] = current
                    self._elementList.append(num)
                elif current is None:
                    # parameter lines only make sense after a DIAG line
                    continue
                elif lsp[0] == "Ibeam:":
                    current["current"] = float(lsp[1])
                elif lsp[0] == "Positions":
                    current["phase"] = float(lsp[5])
                    current["energy"] = float(lsp[6])
                elif lsp[0] == "RMS":
                    # scaled by 0.01 to match the "m" unit set in _setUnits
                    # (file value presumably in cm -- TODO confirm)
                    current["x_rms"] = float(lsp[4]) * 0.01
                    current["y_rms"] = float(lsp[5]) * 0.01
                    current["phase_rms"] = float(lsp[6])
                    current["energy_rms"] = float(lsp[7])
                elif lsp[0] == "Emittances":
                    current["emit_x"] = float(lsp[3])
                    current["emit_y"] = float(lsp[4])
                    current["emit_z"] = float(lsp[5])
                elif lsp[0] == "Emittances99":
                    current["emit99_x"] = float(lsp[3])
                    current["emit99_y"] = float(lsp[4])
                    current["emit99_z"] = float(lsp[5])
                elif lsp[0] == "Twiss":
                    if lsp[1] == "Alpha" and lsp[3] == "(XXp,":
                        current["alpha_x"] = float(lsp[6])
                        current["alpha_y"] = float(lsp[7])
                    elif lsp[1] == "Alpha" and lsp[3] == "(ZDp/p)":
                        current["alpha_z"] = float(lsp[5])
                    elif lsp[1] == "Beta":
                        current["beta_x"] = float(lsp[5])
                        current["beta_y"] = float(lsp[6])
                        current["beta_z"] = float(lsp[7])

    def _setUnits(self):
        """
        Set the units for each element in the element dictionary
        (empty string for all undefined)
        """
        for key in ["loc", "x_rms", "y_rms"]:
            self.units[key] = "m"
        for key in ["emit_x", "emit_y", "emit_z", "emit99_x", "emit99_y", "emit99_z"]:
            self.units[key] = "Pi.mm.mrad"
        for key in ["current"]:
            self.units[key] = "mA"
        for key in ["energy", "energy_rms"]:
            self.units[key] = "MeV"
        for key in ["phase", "phase_rms"]:
            self.units[key] = "deg"
        for key in ["beta_x", "beta_y", "beta_z"]:
            self.units[key] = "mm/mrad"

        for element in self.elements.values():
            for key in element:
                self.units.setdefault(key, "")

    def printTable(self):
        """
        Make a pretty print of the content

        The columns are taken from the first element; a header line with
        the names and one with the units is printed before the rows.
        """
        first = True
        rjust = 12
        keys = []
        for ekey in self._elementList:
            element = self.elements[ekey]

            # Print header if this is the first element..
            if first:
                keys = list(element)
                print("#", end=" ")
                print("NUM".rjust(rjust), end=" ")
                for key in keys:
                    print(key.rjust(rjust), end=" ")
                print()
                print("#", end=" ")
                print("".rjust(rjust), end=" ")
                for key in keys:
                    print(self.units[key].rjust(rjust), end=" ")
                print()
                first = False

            print("  " + str(ekey).rjust(rjust), end=" ")
            for key in keys:
                num = element[key]
                if isinstance(num, float):
                    strnum = "{:.5e}".format(num)
                else:
                    strnum = str(num)
                print(strnum.rjust(rjust), end=" ")
            print()

    def getElement(self, elementId):
        """
        Returns the element dictionary for the given ID
        """
        return self.elements[elementId]

    def getElementAtLoc(self, location, tol=1e-6):
        """
        Returns a list of elements at the location requested

        :param location: position to look for
        :param tol: maximum absolute distance between an element's "loc" and location
        """
        return [element for element in self.elements.values() if abs(element["loc"] - location) < tol]

    def getParameterFromAll(self, parameter):
        """
        Returns a list containing the given parameter from all DIAGS,
        in the order the DIAGs appear in the file.
        The special parameter "NUM" returns the DIAG numbers.
        """
        if parameter == "NUM":
            return list(self._elementList)
        return [self.elements[key][parameter] for key in self._elementList]
class envelope:
    """
    Read an envelope file
    Create one by saving envelope data plot
    to ascii

    Example::
        from ess import TraceWin
        from matplotlib import pyplot as plt
        data = TraceWin.envelope('envelope.txt')
        print(data.keys())
        for key in data:
            print(key, data.unit(key))
            if 'rms_' in key:
                plt.plot(data['position'], data[key]/max(data[key]), label=f"{key} [{data.unit(key)}]")
        plt.legend()
        plt.xlabel(f"Position [{data.unit('position')}]")
        plt.show()

    """

    def __init__(self, filename):
        self.filename = filename
        self.headers = ()
        self._units = []
        self._raw_data = None
        self._readAsciiFile()

    def _readAsciiFile(self):
        """
        Read the column names and units from the first line, the data below it.

        Header tokens: a plain name is one column; "base(a,b)" expands to the
        columns "base_a", "base_b"; "(a,b)" to "a", "b"; "unit(...)" lists
        the units; the token "centroid" is skipped.
        """
        import numpy

        self._raw_data = numpy.loadtxt(self.filename, skiprows=1)
        with open(self.filename, "r") as fin:
            header = fin.readline()

        headers = []
        for h in header.split():
            if "centroid" == h:
                continue
            elif "(" not in h:
                headers.append(h)
            elif "unit(" in h:
                self._units = tuple(h.split("(")[1][:-1].split(","))
            else:
                base, main = h.split("(")
                if base:
                    for k in main[:-1].split(","):
                        headers.append(f"{base}_{k}")
                else:
                    headers.extend(main[:-1].split(","))
        self.headers = tuple(headers)

    def __iter__(self):
        return iter(self.keys())

    def keys(self):
        """Return the column names, in file order."""
        return self.headers

    def where(self, column):
        """
        Return the index of column.

        :raises ValueError: if column is not a known column name
        """
        if column not in self.keys():
            raise ValueError(f"Wrong column name {column}")
        return self.keys().index(column)

    def __getitem__(self, column):
        """
        Get the data of the column specified

        """
        index = self.where(column)
        return self._raw_data[:, index]

    def unit(self, column):
        """
        Return the unit of column, derived from the unit(...) header entry.

        Units are looked up by position in that entry: 0 for positions,
        1 for columns containing "'", 2 for phase, 3 for time, 4 for energy.
        Columns containing "/" are unitless. Returns "" when the file has
        no (or too short a) unit entry.

        TODO gam-1 (currently hard-coded to "GeV")
        """
        if "'" in column:
            index = 1
        elif "/" in column:
            return ""
        elif column in ["time"]:
            index = 3
        elif "phase" in column:
            index = 2
        elif "energy" in column:
            index = 4
        elif "gam-1" == column:
            return "GeV"
        else:
            index = 0
        if index >= len(self._units):
            return ""
        return self._units[index]

class partran(dict):
    """
    Read partran1.out files..

    This class can also read tracewin.out (same format)

    The object is a dict mapping each column name to a 1-D numpy array;
    the first column is named "NUM". The full table is kept in self.data.
    """

    def __init__(self, filename):
        self.filename = filename
        self._readAsciiFile()

    def _readAsciiFile(self):
        """
        Read the column names and the data table.

        The column names come from the first line starting with "#" within
        the first 10 lines (its first token is replaced by "NUM"); all
        lines after it are read as numeric data.
        """
        import numpy

        with open(self.filename, "r") as stream:
            line = ""
            for _ in range(10):
                line = stream.readline()
                # lstrip/startswith: blank lines must not raise IndexError
                if line.lstrip().startswith("#"):
                    break
            self.columns = ["NUM"] + line.split()[1:]
            # ndmin=2 keeps column indexing valid for a single row or column
            self.data = numpy.loadtxt(stream, ndmin=2)

        self._dict = {}
        for i, column in enumerate(self.columns):
            self[column] = self.data[:, i]

    Class to read in the field map structures

    WARNING: Work in progress!!

    def __init__(self, filename):
        """
        Load the field map stored in filename.

        :param filename: path to the field map file
        """
        self._filename = filename
        self._load_data(self._filename)

    def _load_data(self, filename):
        """
        Read a field map file.

        Each header line holds either "n max" (start taken as 0) or
        "n min max" for one dimension, where n+1 is the number of mesh
        points. The first line with a single value is stored as self.norm;
        all remaining numbers are the map values.

        Sets self.header (all header numbers as floats, norm last),
        self.start / self.end (limits per dimension), self.norm, the mesh
        coordinates self.z (plus self.x, self.y for 2D/3D) and self.map.

        :raises ValueError: if the file does not exist or has no norm line
        """
        import os
        import numpy

        if not os.path.isfile(filename):
            raise ValueError("Cannot find file {}".format(filename))

        self.header = []
        self.start = []
        self.end = []
        numindexes = []

        with open(filename, "r") as fin:
            line = fin.readline().split()
            while len(line) > 1:
                self.header.extend(float(i) for i in line)
                numindexes.append(int(line[0]) + 1)
                if len(line) == 2:
                    self.start.append(0.0)
                    self.end.append(float(line[1]))
                else:
                    self.start.append(float(line[1]))
                    self.end.append(float(line[2]))
                line = fin.readline().split()

            if not line:
                raise ValueError("Missing normalisation line in {}".format(filename))
            self.norm = float(line[0])
            self.header.append(self.norm)

            self.map = numpy.loadtxt(fin).reshape(numindexes)

        if len(self.start) == 1:
            self.z = numpy.mgrid[self.start[0] : self.end[0] : numindexes[0] * 1j]
        elif len(self.start) == 2:
            self.z, self.x = numpy.mgrid[
                self.start[0] : self.end[0] : numindexes[0] * 1j,
                self.start[1] : self.end[1] : numindexes[1] * 1j,
            ]
        elif len(self.start) == 3:
            self.z, self.x, self.y = numpy.mgrid[
                self.start[0] : self.end[0] : numindexes[0] * 1j,
                self.start[1] : self.end[1] : numindexes[1] * 1j,
                self.start[2] : self.end[2] : numindexes[2] * 1j,
            ]

    def get_flat_fieldmap(self):
        """
        Return the map values as a 1-D array in C (row-major) order.
        """
        flat_length = self.map.size
        return self.map.reshape(flat_length)

    def interpolate(self, npoints, method="cubic"):
        """
        Interpolate the map onto a new regular mesh spanning the same limits.

        :param npoints: number of mesh points per dimension, tuple-like with
            one integer per map dimension (e.g. (101, 21) for 2D). A single
            integer uses the same number of points in every dimension.
        :param method: 'linear', 'nearest' or 'cubic', passed to
            scipy.interpolate.griddata
        """
        import numpy
        from scipy.interpolate import griddata

        ndim = len(self.start)
        if numpy.ndim(npoints) == 0:
            npoints = (int(npoints),) * ndim

        values = self.map.flatten()

        if ndim == 1:
            points = self.z[:]
            self.z = numpy.mgrid[self.start[0] : self.end[0] : npoints[0] * 1j]
            self.map = griddata(points, values, self.z, method=method)
        elif ndim == 2:
            points = numpy.array([self.z.flatten(), self.x.flatten()]).transpose()
            self.z, self.x = numpy.mgrid[
                self.start[0] : self.end[0] : npoints[0] * 1j,
                self.start[1] : self.end[1] : npoints[1] * 1j,
            ]
            self.map = griddata(points, values, (self.z, self.x), method=method)
        elif ndim == 3:
            points = numpy.array([self.z.flatten(), self.x.flatten(), self.y.flatten()]).transpose()
            self.z, self.x, self.y = numpy.mgrid[
                self.start[0] : self.end[0] : npoints[0] * 1j,
                self.start[1] : self.end[1] : npoints[1] * 1j,
                self.start[2] : self.end[2] : npoints[2] * 1j,
            ]
            self.map = griddata(points, values, (self.z, self.x, self.y), method=method)

        # update the interval counts in the header, using the same layout
        # as for 3D ("nz zmax / nx xmin xmax / ny ymin ymax").
        # NOTE(review): for 2D this assumes the first header line had 2 entries
        for index, n in zip((0, 2, 5), npoints[:ndim]):
            self.header[index] = n - 1

    def savemap(self, filename):
        """
        Write the field map to filename.

        Writes one "n s" line per map dimension (n = number of points - 1,
        s taken from self.size), then self.norm, then every map value in
        C order, one per line.

        NOTE(review): self.size is not assigned by the loader code visible
        here; confirm where it is set before relying on this method.
        """
        with open(filename, "w") as fout:
            for n, s in zip(self.map.shape, self.size):
                fout.write("{} {}\n".format(n - 1, s))
            fout.write("{}\n".format(self.norm))
            for value in self.map.reshape(self.map.size):
                fout.write("{}\n".format(value))
    """
    Read and modify TraceWin project files
    Example::

        p = project('SPK.ini')
        for diff in p.compare_to('MEBT.ini'):
            print(diff)
        p.set('main:beam1_energy', 89e6)
        p.save()
    """

    def __init__(self, project_fname=None, settings_fname=None):
        """
        :param project_fname: TraceWin project file to read (optional)
        :param settings_fname: JSON file holding the reference dictionary
            and the rules; defaults to the reverse-engineered reference
            file shipped with this package
        """
        import json
        import pkg_resources

        self._settings_fname = settings_fname
        self._project_fname = project_fname

        if settings_fname is None:
            self._refdict, self._rules = json.loads(pkg_resources.resource_string(__name__, "data/tw_project_file_reverse_engineered.json"))
        else:
            with open(settings_fname, "r") as fin:
                self._refdict, self._rules = json.loads(fin.read())

        self._dict = {}
        if self._project_fname is not None:
            self._read_settings()

    def _read_settings(self):
        """
        Decode every setting of the reference dictionary from the binary
        project file into self._dict, then check the rules (fixing what
        can be fixed).

        Reference entries are laid out as:
        "bool:list" -> [byte offsets, type, options]
        "int:list"  -> [offset, type, nbytes, {option: stored int}]
        "str"       -> [offset, type, max bytes]
        "d"/"f"/"i"/"?" -> [offset, type, nbytes]

        :raises ValueError: if a list selection cannot be decoded
        :raises TypeError: for an unknown type in the reference file
        """
        import struct
        import textwrap

        with open(self._project_fname, "rb") as f:
            hexlist = textwrap.wrap(f.read().hex(), 2)
        for key in self._refdict:
            o = self._refdict[key]
            if o[1] == "bool:list":
                # exactly one of the flag bytes must be set
                vals = [struct.unpack("?", bytes.fromhex(hexlist[i]))[0] for i in o[0]]
                if vals.count(True) != 1:
                    raise ValueError(f"Did not find {key} to be set correctly")
                self._dict[key] = o[-1][vals.index(True)]
            elif o[1] == "int:list":
                current = "".join(hexlist[o[0] : o[0] + o[2]])
                value = struct.unpack("i", bytes.fromhex(current))[0]
                found = False
                for k, v in o[3].items():
                    if v == value:
                        self._dict[key] = k
                        found = True
                if not found:
                    raise ValueError(f"Unknown setting {value} for {key}")
            elif o[1] == "str":
                string = self._find_string(hexlist, o[0], o[2])
                self._dict[key] = bytes.fromhex(string).decode("utf-8")
            else:
                current = "".join(hexlist[o[0] : o[0] + o[2]])
                if o[1] in ["d", "f", "i", "?"]:
                    self._dict[key] = struct.unpack(o[1], bytes.fromhex(current))[0]
                else:
                    raise TypeError(f"Unknown type for {key}: {o[1]}")

        self.check_rule(fix_if_possible=True)

    def _find_string(self, hexlist, start, maxlen):
        """
        Collect the hex bytes of a null-terminated string.

        :param hexlist: file content as a list of two-character hex strings
        :param start: index of the first byte of the string
        :param maxlen: maximum number of bytes to read
        :return: the concatenated hex bytes before the first "00"
            (or before maxlen / the end of hexlist)
        """
        from itertools import takewhile

        # slicing bounds the scan to the field and to the end of the data,
        # so a missing terminator can never index past either
        field = hexlist[start : start + maxlen]
        return "".join(takewhile(lambda byte: byte != "00", field))

    def print_settings(self, settings=None):
        """
        Print the name and current value of each requested setting.

        :param settings: iterable of setting names; all settings if None
        """
        names = self.keys() if settings is None else settings
        for name in names:
            value = self._dict[name]
            print(name, value)

    def keys(self):
        """Return a view of the names of all settings read from the project file."""
        settings = self._dict
        return settings.keys()

    def get(self, parameter):
        """
        Return the current value of the setting parameter.

        :raises KeyError: if parameter is not a known setting
        """
        value = self._dict[parameter]
        return value

    def get_type(self, parameter):
        """
        Return the type code of parameter, as given in the reference file.

        d : double value
        i : integer value,
        int:list : integer representing a list selection
        bool:list : booleans representing a list selection

        For int:list and bool:list, get_options() lists the values a
        user may pass to set().
        """
        entry = self._refdict[parameter]
        return entry[1]

    def get_options(self, parameter):
        """
        Return the allowed values of parameter from the reference file.

        When the last field of the reference entry is a dict
        (option -> stored value), its keys are returned as a list;
        otherwise that last field is returned unchanged.
        """
        options = self._refdict[parameter][-1]
        if not isinstance(options, dict):
            return options
        return list(options)

    def set(self, parameter, value):
        """
        Set the new value for parameter

        The value is validated against the current value and the type from
        the reference file, so that save() can encode it:
        booleans need a bool, numbers a number (an int for "i" fields),
        list selections one of get_options(), strings a str.

        :raises ValueError: if value does not fit the parameter
        """
        current = self.get(parameter)
        ptype = self.get_type(parameter)
        if isinstance(current, bool):
            if not isinstance(value, bool):
                raise ValueError(f"{parameter} should be True or False")
        elif isinstance(current, (float, int)):
            if not isinstance(value, (float, int)):
                raise ValueError(f"{parameter} should be a number")
            # struct.pack("i", ...) in save() rejects non-integers
            if ptype == "i" and not isinstance(value, int):
                raise ValueError(f"{parameter} should be an integer")
        elif ptype in ["bool:list", "int:list"]:
            opts = self.get_options(parameter)
            if value not in opts:
                raise ValueError(f"{parameter} should be one of {opts}")
        elif ptype == "str" and not isinstance(value, str):
            raise ValueError(f"{parameter} should be a string")
        self._dict[parameter] = value

    def _check_rule_same_sign(self, variables, explanation, fail_on_err):
        """
        Check that every setting in variables has the sign of the first one.

        Zero is compatible with either sign. On violation, raise ValueError
        if fail_on_err is set, otherwise print the message.
        """
        reference = variables[0]
        ref_value = self.get(reference)
        for other in variables[1:]:
            other_value = self.get(other)
            # |a + b| == |a| + |b| exactly when a and b are not of opposite sign
            if abs(ref_value + other_value) == abs(ref_value) + abs(other_value):
                continue
            errmsg = f"{other} and {reference} have opposite signs\nExplanation/logic: {explanation}"
            if fail_on_err:
                raise ValueError(errmsg)
            print(errmsg)

    def _check_rule_endswith(self, end, variables, explanation, fail_on_err, fix_if_possible):
        """
        Check that every (string) setting in variables ends with end.

        On violation: raise ValueError if fail_on_err, else append end to
        the value if fix_if_possible, else print the message.
        """
        for var in variables:
            value = self.get(var)
            if value.endswith(end):
                continue
            errmsg = f"{var} should end with {end}\nExplanation/logic: {explanation}"
            if fail_on_err:
                raise ValueError(errmsg)
            if fix_if_possible:
                self.set(var, value + end)
            else:
                print(errmsg)

    def _unset_string_rules(self, variable, value):
        """
        Strip from value every suffix that an "endswith:<suffix>" rule
        enforces for variable, giving the string as stored in the file.
        """
        for rule in self._rules:
            if variable not in rule[1]:
                continue
            parts = rule[0].split(":")
            if parts[0] != "endswith":
                continue
            suffix = parts[1]
            if value.endswith(suffix):
                value = value[: -len(suffix)]
        return value

    def check_rule(self, rule=None, fail_on_err=False, fix_if_possible=False):
        """
        Validate that the settings still obey rule, or all rules if rule is None.

        A rule is [name, variables, explanation]; supported names are
        "same-sign" and "endswith:<suffix>".

        :raises TypeError: for an unknown rule name
        :raises ValueError: on violation when fail_on_err is True
        """
        rules = self._rules if rule is None else [rule]
        for r in rules:
            name = r[0]
            if name == "same-sign":
                self._check_rule_same_sign(r[1], r[2], fail_on_err)
                continue
            if name.split(":")[0] != "endswith":
                raise TypeError(f"Unknown rule {r[0]}")
            suffix = name.split(":")[1]
            self._check_rule_endswith(suffix, r[1], r[2], fail_on_err, fix_if_possible)

    def save(self, fname=None):
        """
        Save the project file

        The original project file is used as a template: each setting is
        written at its offset from the reference dictionary, all other
        bytes are copied unchanged.

        If fname not given, overwrite original file

        :raises ValueError: if a rule is violated, or a string setting does
            not fit in its field
        """
        import struct
        from textwrap import wrap

        for rule in self._rules:
            self.check_rule(rule, fail_on_err=True)

        if fname is None:
            fname = self._project_fname

        with open(self._project_fname, "rb") as f:
            hexlist = wrap(f.read().hex(), 2)

        for key in self._dict:
            o = self._refdict[key]
            v = self._dict[key]

            if o[1] == "bool:list":
                # one flag byte per option; only the selected option is True
                for i, val in zip(o[0], o[-1]):
                    hexlist[i] = struct.pack("?", v == val).hex()
                continue

            if o[1] == "int:list":
                new_bytes = wrap(struct.pack("i", o[-1][v]).hex(), 2)
            elif o[1] == "str":
                v = self._unset_string_rules(key, v)
                new_bytes = wrap(v.encode(encoding="utf_8").hex(), 2)
                new_bytes.append("00")
                # refuse to spill into the neighbouring settings; the field
                # is assumed to include the terminator (TODO confirm)
                if len(new_bytes) > o[2]:
                    raise ValueError(f"{key} is too long: {len(new_bytes)} bytes including terminator, field holds {o[2]}")
            else:
                new_bytes = wrap(struct.pack(o[1], v).hex(), 2)

            for offset, byte in enumerate(new_bytes):
                hexlist[o[0] + offset] = byte

        with open(fname, "wb") as fout:
            fout.write(bytes.fromhex("".join(hexlist)))

    def compare_to(self, other):
        """
        List the settings whose values differ from those of another project.

        :param other: project object, or file path to other project file
        :return: list of [setting, value in this project, value in other]
        """
        if isinstance(other, str):
            other = project(other)
        return [[key, self.get(key), other.get(key)] for key in self.keys() if self.get(key) != other.get(key)]