Skip to content
Snippets Groups Projects
lib_tw.py 59.3 KiB
Newer Older
    Classes and functions for TraceWin.

    - Notes for clean-up:

      * The old LATTICE and MATCH classes can be deleted once trajectory
        correction has been established with the new classes.
from __future__ import print_function

# ---- Lib
Yngve Levinsen's avatar
Yngve Levinsen committed
import numpy
from struct import pack, unpack
from itertools import chain

from subprocess import check_output
Yngve Levinsen's avatar
Yngve Levinsen committed
import os
import sys
Yngve Levinsen's avatar
Yngve Levinsen committed
from . import lib_tw_elem
# -------- Classes
# ---- Lattice and project


class LATTICE:
    """
    In-memory representation of a TraceWin lattice (.dat) file.

    Each non-empty line of the lattice file (after stripping ";" comments),
    up to the first END command, becomes one element/command object from
    lib_tw_elem; the objects are kept in file order in ``self.lst``.
    Methods propagate per-entry bookkeeping (idx, idx_elem, s, freq, apt,
    gamma) and export the lattice to TraceWin, MAD-X, FLUKA, MARS and BDSIM
    formats.
    """
    # NOTE(review): this copy of the class appears to have lost lines when it
    # was captured (closing docstring quotes, a closing brace, try:/else:
    # lines, loop headers, return statements). The visible places are
    # flagged below; compare with version control before relying on it.

    def __init__(self, file_name_lat, file_name_fmap=[], freq=352.21, gamma=1.0):
        """
        Parse a TraceWin lattice file into element/command objects.

        Lines are classified by their upper-cased keyword: FIELD_MAP ->
        lib_tw_elem.FIELD_MAP, keywords listed in dic_cls -> the mapped
        class, keywords containing DIAG -> lib_tw_elem.DIAG, anything else
        presumably -> lib_tw_elem.COMM. Parsing stops at the first END.

        :param file_name_lat: name of lattice file
        :type file_name_lat: str
        :param file_name_fmap: list of field map file(-s); a single str is
            also accepted. The default list is only rebound or iterated,
            never mutated, so the shared default is harmless.
        :type file_name_fmap: list or str
        :param freq: RF frequency, used by update_idx as the starting freq
            of the first entry (presumably MHz -- TODO confirm)
        :type freq: float
        :param gamma: relativistic gamma, used by update_gamma as the
            starting gamma of the first entry
        :type gamma: float

        # In case file_name_fmap is str
        if isinstance(file_name_fmap, str):
            file_name_fmap = [file_name_fmap]
        # Keyword (upper case) -> lib_tw_elem class used to parse that line
        dic_cls = {
Yngve Levinsen's avatar
Yngve Levinsen committed
            "DRIFT": lib_tw_elem.DRIFT,
            "QUAD": lib_tw_elem.QUAD,
            "THIN_STEERING": lib_tw_elem.THIN_STEERING,
            "GAP": lib_tw_elem.GAP,
            "DTL_CEL": lib_tw_elem.DTL_CEL,
Yngve Levinsen's avatar
Yngve Levinsen committed
            "BEND": lib_tw_elem.BEND,
            "EDGE": lib_tw_elem.EDGE,
            "APERTURE": lib_tw_elem.APERTURE,
            "DIAG_POSITION": lib_tw_elem.DIAG_POSITION,
Yngve Levinsen's avatar
Yngve Levinsen committed
            "STEERER": lib_tw_elem.STEERER,
            "CHOPPER": lib_tw_elem.CHOPPER,
Yngve Levinsen's avatar
Yngve Levinsen committed
            "ADJUST": lib_tw_elem.ADJUST,
            "FREQ": lib_tw_elem.FREQ,
            "MARKER": lib_tw_elem.MARKER,
Yngve Levinsen's avatar
Yngve Levinsen committed
            "ERROR_BEAM_STAT": lib_tw_elem.ERROR_BEAM_STAT,
            "ERROR_BEAM_DYN": lib_tw_elem.ERROR_BEAM_DYN,
            "ERROR_QUAD_NCPL_STAT": lib_tw_elem.ERROR_QUAD_NCPL_STAT,
            "ERROR_QUAD_CPL_STAT": lib_tw_elem.ERROR_QUAD_CPL_STAT,
            "ERROR_CAV_NCPL_STAT": lib_tw_elem.ERROR_CAV_NCPL_STAT,
            "ERROR_CAV_NCPL_DYN": lib_tw_elem.ERROR_CAV_NCPL_DYN,
            "ERROR_CAV_CPL_STAT": lib_tw_elem.ERROR_CAV_CPL_STAT,
            "ERROR_CAV_CPL_DYN": lib_tw_elem.ERROR_CAV_CPL_DYN,
            "ERROR_STAT_FILE": lib_tw_elem.ERROR_STAT_FILE,
            "ERROR_DYN_FILE": lib_tw_elem.ERROR_DYN_FILE,
        # NOTE(review): the closing brace of dic_cls appears to be missing here.
        # Field map dict: file name without directory/extension -> FIELD_MAP_DATA
        dic_fmap = {}
        for file_name_fmap_i in file_name_fmap:
            name_fmap = ".".join(
                file_name_fmap_i.split("/")[-1].split(".")[:-1]
            )  # Remove / and extension
Yngve Levinsen's avatar
Yngve Levinsen committed
            dic_fmap[name_fmap] = lib_tw_elem.FIELD_MAP_DATA(file_name_fmap_i)
        # In case the lattice file is in a different folder:
        basefolder = os.path.dirname(file_name_lat)
        # NOTE(review): the next line is indented one level deeper than this
        # one, so a guarding statement (perhaps "if basefolder:") appears to
        # be missing. _update_field_map_dict is not defined in the visible
        # part of this file.
            _update_field_map_dict(dic_fmap, basefolder)

        # Go through the lat file
        with open(file_name_lat) as file:
            lst = []
            for lin in file:
                lin = lin.partition(";")[0]  # Remove comment
                if lin.split() != []:  # Remove empty line
                    # Split into optional "name:" label, keyword and parameters
                    if ":" in lin:
                        name = lin.partition(":")[0].split()[0]
                        typ = lin.partition(":")[2].split()[0].upper()
                        para = lin.partition(":")[2].split()[1:]
                    else:
                        name = ""
                        typ = lin.split()[0].upper()
                        para = lin.split()[1:]
                    # Map to a class
                    if typ == "FIELD_MAP":
Yngve Levinsen's avatar
Yngve Levinsen committed
                        lst.append(lib_tw_elem.FIELD_MAP(name, typ, para, dic_fmap))
                    elif typ in dic_cls.keys():
                        lst.append(dic_cls[typ](name, typ, para))
                    elif "DIAG" in typ:
Yngve Levinsen's avatar
Yngve Levinsen committed
                        lst.append(lib_tw_elem.DIAG(name, typ, para))
Yngve Levinsen's avatar
Yngve Levinsen committed
                    # NOTE(review): an "else:" appears to be missing before the
                    # COMM fallback below; as shown, DIAG lines would also be
                    # appended a second time as COMM.
                        lst.append(lib_tw_elem.COMM(name, typ, para))

                    # in case of field map path, update dictionary with new path
                    if typ == "FIELD_MAP_PATH":
                        _update_field_map_dict(dic_fmap, para[0])
                    # Stop at END; the END line itself was already appended above
                    if typ == "END":
                        break

        # Instances
        self.gamma = gamma
        self.freq = freq
        self.lst = lst
        self.fmap = dic_fmap  # Field maps by name; read by get_bdsim

        # Assign idx, idx_elem, s, freq, apt
        self.update_idx()

    def get_correctors_idx(self, i):
        """
        Get correctors associated to corrector index i

        Walks self.lst and, for every ADJUST / ADJUST_STEERER command whose
        first parameter equals i, collects the entry that immediately
        follows that command in the list.
        found = False
        correctors = []
        for element in self.lst:
            if found:
                # WARNING I worry that there could be an inactive comment/element between the ADJUST and actual corrector
                # NOTE(review): logging is not imported in the visible module header.
                logging.debug(
                    "Found element {} for corrector family {}".format(element.typ, i)
                )
                correctors.append(element)
                found = False
Yngve Levinsen's avatar
Yngve Levinsen committed
            if isinstance(element, lib_tw_elem.COMM):
                if element.typ in ["ADJUST", "ADJUST_STEERER"]:
                    if int(element.para[0]) == i:  # next element is the corrector
                        found = True
        # NOTE(review): no "return correctors" is visible; as shown this
        # method would return None.
    def get_elem_idx(self, i):
        """
        Get a TraceWin index number

        Intended to return the entry of self.lst whose idx_elem equals
        i - 1 (update_adj uses the result as an element and writes into its
        para list).

        Note: We start counting from 0, TW starts from 1
        # NOTE(review): the loop over self.lst and the return statement of
        # this method appear to be missing; only the condition is visible.
            if element.idx_elem == i - 1:
    def get_steerer_for(self, idx_elem):
        """
        Returns the steerer object for an element (e.g. quad)

        Intended to return the STEERER entry directly preceding the entry
        whose TraceWin number (its idx_elem + 1) equals idx_elem.
        """
        previous = None
        # NOTE(review): the "for element in self.lst:" header and the return
        # of previous appear to be missing. Also, previous is None on the
        # first iteration, so previous.typ would raise AttributeError if the
        # very first entry matched.
            if element.idx_elem + 1 == idx_elem:
                if previous.typ == "STEERER":
            previous = element
    def update_idx(self):
        """
        Assign/update idx, idx_elem, s and freq of every entry in self.lst,
        then fill in entries whose apt is None.

        The first entry is seeded with idx = idx_elem = -1, s = 0.0 and
        freq = self.freq; every later entry starts from its predecessor's
        values, and each entry's own update_idx() is then called.
        """

        # Assign/update idx, idx_elem, s, freq
        for i in range(len(self.lst)):
            if i == 0:
                self.lst[0].idx = -1
                self.lst[0].idx_elem = -1
                self.lst[0].s = 0.0
                self.lst[0].freq = self.freq
            if i != 0:
                self.lst[i].idx = self.lst[i - 1].idx
                self.lst[i].idx_elem = self.lst[i - 1].idx_elem
                self.lst[i].s = self.lst[i - 1].s
                self.lst[i].freq = self.lst[i - 1].freq
            self.lst[i].update_idx()
        # Assign apt (using apt of the previous elem for diag elem)
        # Entries with apt None and idx_elem == 0 take the first non-None apt
        # found in self.lst[1:]; other entries with apt None copy the apt of
        # the entry directly before them (the inner loop breaks at once).
        # NOTE(review): the try: matching the first inner except clause and
        # the bodies of all three "except IndexError:" clauses (presumably
        # pass) appear to be missing in this copy.
        for i in range(len(self.lst)):
            try:
Yngve Levinsen's avatar
Yngve Levinsen committed
                if self.lst[i].apt is None:
                    if self.lst[i].idx_elem == 0:
                        for k in range(1, len(self.lst)):
Yngve Levinsen's avatar
Yngve Levinsen committed
                                if self.lst[k].apt is not None:
                                    self.lst[i].apt = self.lst[k].apt
                                    break
                            except IndexError:
                    else:
                        for k in range(i)[::-1]:
                            try:
                                self.lst[i].apt = self.lst[k].apt
                                break
                            except IndexError:
            except IndexError:
    def update_gamma(self):
        """
        Propagate gamma along self.lst: the first entry starts from
        self.gamma, every later entry from its predecessor's gamma, and each
        entry's own update_gamma() is then called.

        NOTE(review): the body of the except clause below (presumably pass)
        appears to be missing in this copy.
        """

        for i in range(len(self.lst)):
            if i == 0:
                self.lst[0].gamma = self.gamma
            if i != 0:
                self.lst[i].gamma = self.lst[i - 1].gamma
            try:
                self.lst[i].update_gamma()
            except IndexError:
    def update_st(self, file_name):
        """
            Assign/update steerer values from "Steerer_Values.txt".

            Each line is split on whitespace and its first three tokens are
            dropped. In the rest, the token before ":" is read as the
            1-based TraceWin element number, and the tokens following
            "BLx=" / "BLy=" as the integrated strengths. A line without ":"
            reuses the element number of a previous line.

            THIN_STEERING entries get BLx/BLy directly; for other entries the
            nearest preceding STEERER entry gets Bx/By = BL / (entry's L).
        # Extract BLx and BLy from the file
        with open(file_name, "r") as file:
            BLx = {}
            BLy = {}
            for lin in file:
                lin = lin.split()[3:]
                for i in range(len(lin)):
                    if lin[i] == ":":
                        idx_elem = int(lin[i - 1]) - 1  # "-1" since idx=0,1,2,...
                    if lin[i] == "BLx=":
                        BLx[idx_elem] = float(lin[i + 1])
                    if lin[i] == "BLy=":
                        BLy[idx_elem] = float(lin[i + 1])
        # Assign/update steerers
        # NOTE(review): the lines below are indented one level deeper than a
        # direct loop body, so a guard (perhaps a membership check on BLx or
        # a try:) appears to be missing; without it BLx[idx_elem] would raise
        # KeyError for entries absent from the file.
        for i in range(len(self.lst)):
                idx_elem = self.lst[i].idx_elem
                if self.lst[i].typ == "THIN_STEERING":
                    self.lst[i].BLx = BLx[idx_elem]
                    self.lst[i].BLy = BLy[idx_elem]
                if self.lst[i].typ != "THIN_STEERING":
                    for k in range(i)[::-1]:
                        if self.lst[k].typ == "STEERER":
                            self.lst[k].Bx = BLx[idx_elem] / self.lst[i].L
                            self.lst[k].By = BLy[idx_elem] / self.lst[i].L
                            break
    def update_adj(self, file_name="Adjusted_Values.txt"):
        """
            Assign/update correction values from "Adjusted_Values.txt".

            Judging from the parsing below, each line carries a bracketed
            corrector-scheme number as its second token, followed by
            triplets (element number, <ignored>, value) -- TODO confirm
            against a real Adjusted_Values.txt.

            WARNING: many corner cases, what happens if multiple adjust
                     commands have corrected same element for example?
                     Use with care!
        """
        with open(file_name, "r") as file:
            # First we read in all corrections into dictionaries
            # values: scheme i -> {element number k: [values in file order]}
            # counts: element number -> how many of its values were consumed
            values = {}
            counts = {}
            # NOTE(review): the loop header over the file's lines appears to be
            # missing here.
                lin = lin.split()
                i = int(lin[1][1:-1])
                # NOTE(review): i is an int here, so i == "ERROR" is always False;
                # the body of this if (presumably continue) is also missing.
                if i == "ERROR" and lin[0] == "BEAM":
                settings = {}
                # NOTE(review): under Python 3, "/" returns a float and range()
                # rejects it; "//" would be needed there.
                for j in range(len(lin[2:]) / 3):
                    k = int(lin[2 + 3 * j])
                    val = float(lin[4 + 3 * j])
                    if k in settings:
                        settings[k].append(val)
                    else:
                        settings[k] = [val]
                values[i] = settings
                # NOTE(review): the loop header that binds j for the next line
                # (presumably over settings) appears to be missing.
                    counts[j] = 0

        # now we will do all the ADJUST_STEERER ones
        corr_next = False
Yngve Levinsen's avatar
Yngve Levinsen committed
        # NOTE(review): the "for el in self.lst:" header appears to be missing.
            if isinstance(el, lib_tw_elem.COMM) and el.typ == "ADJUST_STEERER":
                i = int(el.para[0])
                if i in values:
                    corr_next = values[i]
Yngve Levinsen's avatar
Yngve Levinsen committed
            # NOTE(review): the branch header for the STEERER handling below
            # appears to be missing (presumably it checks corr_next).
                if isinstance(el, lib_tw_elem.STEERER):
                    # NOTE(review): under Python 3, dict.values() cannot be
                    # indexed; list(corr_next.values())[0] would be needed.
                    vals = corr_next.values()[0]
                    el.Bx = vals[0]
                    el.By = vals[1]
                corr_next = False
        corr_next = False
        # The TraceWin index number of the current element:
        current = -1
        # NOTE(review): the "for el in self.lst:" header appears to be missing.
            if el.idx_elem != -1:
                current = el.idx_elem + 1
Yngve Levinsen's avatar
Yngve Levinsen committed
            if isinstance(el, lib_tw_elem.COMM) and el.typ == "ADJUST":
                # The index of the corrector scheme
                i = int(el.para[0])
                # the parameter column to vary:
                j = int(el.para[1]) - 1
                # The TraceWin element index of the element we will vary:
                k = current + 1
                # This corrector might not be used:
                if i in values:
                    # We will correct the next active element in lattice:
                    value = values[i][k][counts[k]]
                    # We have now used one value of the total in the current corrector:
                    counts[k] += 1
                    vary = self.get_elem_idx(current + 1)
                    vary.para[j] = value
    def get_tw(self, file_name):
        """Write the lattice as a TraceWin file: one get_tw() line per entry."""
        with open(file_name, "w") as file:
            for lat_i in self.lst:
                file.write(lat_i.get_tw() + "\n")
    def get_madx(self, file_name_elem="elem.madx", file_name_seq="seq.madx"):
        """
        Write MAD-X element definitions to file_name_elem, then a line
        "linac:line=(...)" listing their names to file_name_seq.

        update_gamma() is called first when the last entry's gamma is still
        1.0 (treated as not assigned yet). Entries whose get_madx() raises
        AttributeError are meant to be skipped.
        NOTE(review): the try: matching the except clause below appears to
        be missing in this copy.
        """
        if self.lst[-1].gamma == 1.0:
            self.update_gamma()  # Assign gamma, if not done yet
        with open(file_name_elem, "w") as fname:
            for lat_i in self.lst:
                    fname.write(lat_i.get_madx() + "\n")
                except AttributeError:
                    pass
        # Element names = text before the first ":" of each line just written
        with open(file_name_elem, "r") as fname:
            lst_name = [lin.split(":")[0] for lin in fname]
        with open(file_name_seq, "w") as fname:
            fname.write("linac:line=({});\n".format(",".join(lst_name)))
    def get_fluka(self, file_name="elem.dat"):
        """
        Write one get_fluka() line per entry to file_name, skipping entries
        whose get_fluka() raises AttributeError; update_gamma() is called
        first when the last entry's gamma is still 1.0.
        NOTE(review): the try: matching the except clause below appears to
        be missing in this copy.
        """
        if self.lst[-1].gamma == 1.0:
            self.update_gamma()  # Assign gamma, if not done yet
        with open(file_name, "w") as fname:
            for lat_i in self.lst:
                    fname.write(lat_i.get_fluka() + "\n")
                except AttributeError:
                    pass
    def get_mars(self, file_name="elem.dat"):
        """
        Write one get_mars() line per entry to file_name, skipping entries
        whose get_mars() raises AttributeError; update_gamma() is called
        first when the last entry's gamma is still 1.0.
        NOTE(review): the try: matching the except clause below appears to
        be missing in this copy.
        """
        if self.lst[-1].gamma == 1.0:
            self.update_gamma()  # Assign gamma, if not done yet
        with open(file_name, "w") as fname:
            for lat_i in self.lst:
                    fname.write(lat_i.get_mars() + "\n")
                except AttributeError:
                    pass
    def get_bdsim(self, output_folder="bdsim"):
        """
        Write a BDSIM input file, output_folder/bdsim.dat, creating the
        folder if needed.

        The file gets the get_bdsim() output of every field map in
        self.fmap (each given a path inside output_folder), followed by the
        get_bdsim() output of every entry of self.lst that defines one.
        update_gamma() is called first when the last entry's gamma is 1.0.
        """
        if not os.path.exists(output_folder):
            os.makedirs(output_folder)
        file_name = os.path.join(output_folder, "bdsim.dat")
        if self.lst[-1].gamma == 1.0:
            self.update_gamma()  # Assign gamma, if not done yet
        with open(file_name, "w") as fname:
            for fmap in self.fmap:
                fname.write(
                    self.fmap[fmap].get_bdsim(os.path.join(output_folder, fmap)) + "\n"
                )
            for lat_i in self.lst:
                if hasattr(lat_i, "get_bdsim"):
                    fname.write(lat_i.get_bdsim() + "\n")
class PROJECT:
    # NOTE(review): the opening and closing triple quotes of this class
    # docstring appear to be missing in this copy; the three indented lines
    # below are docstring text, not code.
    # Builds a TraceWin64_noX11 command line for one project file, with
    # optional dat_file=, path_cal=, random_seed= and hide arguments.
        - This is for running multiple simulations 1-by-1 under 1 project.
        - Maybe not very useful...
        2015.10.15
    def __init__(self, file_name_proj="proj.ini"):
        """
        :param file_name_proj: project file placed first on the command line
            built by exe()
        """

        # Instances (Add more options as needed.)
        # exe() appends each option below only when it is not None. seed must
        # be a str because it is concatenated directly, and any non-None
        # flag_hide -- even False -- adds "hide".
        self.file_name_proj = file_name_proj
        self.file_name_lat = None
        self.path_cal = None
        self.seed = None
        self.flag_hide = None

    def exe(self):
        """
        Build the TraceWin command line from the current options.

        NOTE(review): no statement that actually runs opt_exe (e.g. via the
        imported check_output) is visible in this copy; as shown, the method
        builds the string and returns None. If it is run through a shell,
        paths containing spaces or shell metacharacters would need quoting.
        """

        opt_exe = "TraceWin64_noX11 " + self.file_name_proj
        if self.file_name_lat is not None:
            opt_exe += " dat_file=" + self.file_name_lat
        if self.path_cal is not None:
            opt_exe += " path_cal=" + self.path_cal
        if self.seed is not None:
            opt_exe += " random_seed=" + self.seed
        if self.flag_hide is not None:
            opt_exe += " hide"
Yngve Levinsen's avatar
Yngve Levinsen committed
        # if self.path_cal!=None:
        #     if os.isdir(self.path_cal)==False: system('mkdir '+self.path_cal)

# ---- Data related


class PARTRAN:
    # NOTE(review): the opening triple quotes (and probably a summary line) of
    # this class docstring appear to be missing in this copy; the three
    # indented lines below are docstring text, not code.
    # Reads a whitespace-separated table whose header line has a first token
    # containing "##" and names columns such as z(m), x0, SizeX, ex, npart,
    # Powlost (presumably TraceWin's PARTRAN output -- TODO confirm).
        - The list not complete. Add parameters as needed.
        - 2016.02.17: Changed how to identify the line of indices.
        - 2016.02.17: Added a logic to avoid #/0 for LEBT.
    """

    def __init__(self, file_name):
        """
        :param file_name: path of the table file to read
        """

        # Consts to convert phs to z.
        # c is in units of 1e8 m/s and freq is hard-coded to 352.21
        # (presumably MHz); with the 1e5 factor below, z comes out in mm if
        # so -- TODO confirm. Files run at other frequencies get a wrong z.
        c = 2.99792458
        freq = 352.21

        # Extract indices.
        # Locate the header line and record each column's index by name.
        # NOTE(review): lin[0] raises IndexError on an empty line.
        with open(file_name) as file:
            for lin in file.readlines():
                lin = lin.split()
                if "##" in lin[0]:
                    idx_s = lin.index("z(m)")
                    idx_gamma = lin.index("gama-1")
                    idx_x = lin.index("x0")
                    idx_y = lin.index("y0")
                    idx_phs = lin.index("p0")
                    idx_sigx = lin.index("SizeX")
                    idx_sigy = lin.index("SizeY")
                    idx_sigz = lin.index("SizeZ")
                    idx_sigp = lin.index("SizeP")
                    idx_alfx = lin.index("sxx'")
                    idx_alfy = lin.index("syy'")
                    idx_alfz = lin.index("szdp")
                    idx_epsx = lin.index("ex")
                    idx_epsy = lin.index("ey")
                    idx_epsz = lin.index("ezdp")
                    idx_epsp = lin.index("ep")
                    idx_halx = lin.index("hx")
                    idx_haly = lin.index("hy")
                    idx_halp = lin.index("hp")
                    idx_Nptcl = lin.index("npart")
                    idx_loss = lin.index("Powlost")
                    break
        # Extract data.
        # Every line after the header line is parsed as a row of floats;
        # after the transpose, data[idx] is one column.
        # NOTE(review): under Python 3, map() is lazy, so data would hold map
        # objects and numpy.array would not build a numeric 2-D array;
        # list(map(float, lin)) would be needed there.
        with open(file_name) as file:
            data = []
            flag = 0
            for lin in file.readlines():
                lin = lin.split()
                if flag == 1:
                    data.append(map(float, lin))
                if "##" in lin[0]:
                    flag = 1
            data = numpy.array(data).transpose()

        # Instances
        self.s = data[idx_s]
        self.x = data[idx_x]
        self.y = data[idx_y]
        self.phs = data[idx_phs]
        self.sigx = data[idx_sigx]
        self.sigy = data[idx_sigy]
        self.sigz = data[idx_sigz]
        self.sigp = data[idx_sigp]
        self.epsx = data[idx_epsx]
        self.epsy = data[idx_epsy]
        self.epsz = data[idx_epsz]
        self.epsp = data[idx_epsp]
        self.halx = data[idx_halx]
        self.haly = data[idx_haly]
        self.halz = data[idx_halp]  # NOTE(review): same column ("hp") as halp
        self.halp = data[idx_halp]
        self.Nptcl = data[idx_Nptcl]
        self.loss = data[idx_loss]
        # Temporarily replace zero emittances with inf so the Twiss divisions
        # below give 0 instead of dividing by zero (e.g. in the LEBT).
        # NOTE(review): the "for i in ..." header of this loop appears to be
        # missing in this copy.
            if self.epsx[i] == 0.0:
                self.epsx[i] = numpy.inf
            if self.epsy[i] == 0.0:
                self.epsy[i] = numpy.inf
            if self.epsz[i] == 0.0:
                self.epsz[i] = numpy.inf
            if self.epsp[i] == 0.0:
                self.epsp[i] = numpy.inf
        # Additional instances
        # Relativistic factors, z from phase (phs / 360 of beta*lambda), and
        # Twiss beta/alpha from sizes, correlations and emittances.
        self.gamma = data[idx_gamma] + 1.0
        self.beta = numpy.sqrt(1.0 - 1.0 / self.gamma ** 2)
        self.z = -self.phs * self.beta * (c / freq * 1e5) / 360.0
        self.betx = self.sigx ** 2 / self.epsx * self.beta * self.gamma
        self.bety = self.sigy ** 2 / self.epsy * self.beta * self.gamma
        self.betz = self.sigz ** 2 / self.epsz * self.beta * self.gamma ** 3
        self.betp = self.sigp ** 2 / self.epsp
        self.alfx = -data[idx_alfx] / self.epsx * self.beta * self.gamma
        self.alfy = -data[idx_alfy] / self.epsy * self.beta * self.gamma
        self.alfz = -data[idx_alfz] / self.epsz * self.beta * self.gamma ** 3
        self.alfp = -self.alfz
        # Restore the zero emittances replaced above.
        # NOTE(review): the loop header appears to be missing here as well.
            if self.epsx[i] == numpy.inf:
                self.epsx[i] = 0.0
            if self.epsy[i] == numpy.inf:
                self.epsy[i] = 0.0
            if self.epsz[i] == numpy.inf:
                self.epsz[i] = 0.0
            if self.epsp[i] == numpy.inf:
                self.epsp[i] = 0.0
        # Convert to list (not necessary?)
        self.s = self.s.tolist()
        self.gamma = self.gamma.tolist()
        self.beta = self.beta.tolist()
        self.x = self.x.tolist()
        self.y = self.y.tolist()
        self.z = self.z.tolist()
        self.phs = self.phs.tolist()
        self.sigx = self.sigx.tolist()
        self.sigy = self.sigy.tolist()
        self.sigz = self.sigz.tolist()
        self.sigp = self.sigp.tolist()
        self.betx = self.betx.tolist()
        self.bety = self.bety.tolist()
        self.betz = self.betz.tolist()
        self.betp = self.betp.tolist()
        self.alfx = self.alfx.tolist()
        self.alfy = self.alfy.tolist()
        self.alfz = self.alfz.tolist()
        self.alfp = self.alfp.tolist()
        self.epsx = self.epsx.tolist()
        self.epsy = self.epsy.tolist()
        self.epsz = self.epsz.tolist()
        self.epsp = self.epsp.tolist()
        self.halx = self.halx.tolist()
        self.haly = self.haly.tolist()
        self.halz = self.halz.tolist()
        self.halp = self.halp.tolist()
        self.Nptcl = self.Nptcl.tolist()
        self.loss = self.loss.tolist()

    def loss_den(self, file_name_dt="", dlt_dt=5e-6):
        """
        Return loss_elem2den(self.s, self.loss, file_name_dt, dlt_dt).

        NOTE(review): loss_elem2den is not defined in the visible part of
        this file, so its output cannot be described from here.
        """

        return loss_elem2den(self.s, self.loss, file_name_dt, dlt_dt)
    # NOTE(review): the class statement and the opening triple quotes of the
    # docstring below appear to be missing in this copy; the indented lines
    # up to the closing triple quotes are docstring text, not code. This
    # class reads the binary .dst layout that x2dst writes.
    # TraceWin seems to use each particle's own beta and gamma, so the
    # (z, z') conversion below is based on that assumption.
        Class for a TraceWin's .dst file.
        - TraceWin seems using beta and gamma for each particle
          so the conversion to (z,z') is based on this assumption.
        2015.10.06
    """

    def __init__(
        self, file_name, unit_x="cm", unit_px="rad", unit_z="rad", unit_pz="MeV"
    ):
        """
        Read a binary .dst particle distribution.

        Layout read below (native byte order): 2 bytes skipped, uint32
        particle count, float64 beam current, float64 frequency, 1 byte
        skipped, Nptcl x 6 float64 coordinates, float64 mass.

        :param unit_x: "cm" (stored unit) or "mm" for columns 0 and 2
        :param unit_px: "rad" (stored unit) or "mrad" for columns 1 and 3
        :param unit_z: "rad" (stored unit), "deg" or "mm" for column 4
        :param unit_pz: "MeV" (stored; presumably kinetic energy, as gamma is
            1 + column5 / mass) or "mrad" for column 5

        Sets self.x as an (Nptcl, 6) array, plus self.mass, self.freq and
        self.Ibeam.
        """
        c = 2.99792458  # speed of light in 1e8 m/s (presumably, see unit_z "mm")

        # Read the file
        # NOTE(review): the file is opened in text mode although its content
        # is binary; "rb" would be the safe mode -- left unchanged here.
        with open(file_name) as file:
            numpy.fromfile(file, dtype=numpy.uint8, count=2)
Yngve Levinsen's avatar
Yngve Levinsen committed
            Nptcl = numpy.fromfile(file, dtype=numpy.uint32, count=1)[0]
            Ibeam = numpy.fromfile(file, dtype=numpy.float64, count=1)[0]
            freq = numpy.fromfile(file, dtype=numpy.float64, count=1)[0]
            numpy.fromfile(file, dtype=numpy.uint8, count=1)
Yngve Levinsen's avatar
Yngve Levinsen committed
            # NOTE(review): the start of this assignment (presumably "x = (")
            # appears to be missing; x ends up as a (6, Nptcl) array.
                numpy.fromfile(file, dtype=numpy.float64, count=Nptcl * 6)
                .reshape(Nptcl, 6)
                .transpose()
            )
Yngve Levinsen's avatar
Yngve Levinsen committed
            mass = numpy.fromfile(file, dtype=numpy.float64, count=1)[0]
        # Adjust units
        # Per-particle gamma/beta come from the stored column 5, before any
        # unit conversion of that column.
        gamma = 1.0 + x[5] / mass
        beta = numpy.sqrt(1 - 1 / gamma ** 2)
        if unit_x == "mm":
            x[0] = x[0] * 1e1
            x[2] = x[2] * 1e1
        if unit_px == "mrad":
            x[1] = x[1] * 1e3
            x[3] = x[3] * 1e3
        if unit_z == "deg":
Yngve Levinsen's avatar
Yngve Levinsen committed
            x[4] = x[4] * 180 / numpy.pi
        if unit_z == "mm":
Yngve Levinsen's avatar
Yngve Levinsen committed
            # Phase (rad) -> length: -phase * beta * c / (2*pi*freq), scaled
            # by 1e5 (gives mm if freq is in MHz -- TODO confirm)
            x[4] = -x[4] * c * beta / (2 * numpy.pi * freq) * 1e5
        if unit_pz == "mrad":
Yngve Levinsen's avatar
Yngve Levinsen committed
            # Energy deviation from the mean, scaled by
            # 1 / (mass * beta^2 * gamma^3), in mrad
            x[5] = (x[5] - numpy.mean(x[5])) / (mass * beta ** 2 * gamma ** 3) * 1e3

        # Instances
        self.x = x.transpose()
        self.mass = mass
        self.freq = freq
        self.Ibeam = Ibeam


class DENSITY:
    # NOTE(review): the opening triple quotes (and probably a summary line) of
    # this class docstring appear to be missing in this copy; the indented
    # lines below up to the closing triple quotes are docstring text.
    # Reads a binary file of per-element records (presumably a TraceWin
    # density file -- TODO confirm) and exposes the quantities below.
        - Note instances are not identical for Nrun=1 and Nrun>1.
        - Be careful with # of steps for an envelope file.
        - When Nrun>1, ave and rms in a file are sum and squared sum.
        - Double check before a production !!!!
        - Dim of arrays:

          Nelem(Nidx): idx_elem, s, Nptcl, Ibeam

          4 x Nelem(Nidx): apt                         # x, y, dx, dy
          3 x Nelem(Nidx): accpt                       # phs+, phs-, ken
          7 x Nelem(Nidx): cent_(..), sig, xmax, xmin  # x, y, phs, ken, r, z, dpp
          3 x Nelem(Nidx): eps                         # x, y, phs

          Nrun x Nelem: loss
          Nelem       : loss_num_(..), loss_pow_(..)
Yngve Levinsen's avatar
Yngve Levinsen committed

          Nelem x Nstep: den
          * 2016.03.29: Adapted to ver 9 (apt includes shifts)
    """

    def __init__(self, file_name):
        """
        :param file_name: path of the binary file to read

        Records are read one after another until reading fails; fields that
        only exist in newer file versions (ver) are read conditionally.
        Nrun and Nstep are taken from the last record read.
        """

        # -- Empty arrays

        idx_elem = []
        s = []
        apt = []
        accpt = []
        Nptcl = []
        Ibeam = []
        cent_ave = []
        cent_rms = []
        cent_max = []
        cent_min = []
        sig_ave = []
        sig_rms = []
        eps_ave = []
        eps_rms = []
        xmax = []
        xmin = []
        loss_num = []
        loss_pow = []
        loss_num_ave = []
        loss_pow_ave = []
        loss_num_rms = []
        loss_pow_rms = []
        loss_num_max = []
        loss_pow_max = []
        loss_num_min = []
        loss_pow_min = []
        den = []
        den_pow = []

        # -- Extract data
        # NOTE(review): a binary file opened in text mode; "rb" would be safer.
        with open(file_name) as file:
            while True:
                try:
                    # Partran and envelope
                    ver, year, flag_long = numpy.fromfile(
                        file, dtype=numpy.uint16, count=3
                    )
                    Nrun = numpy.fromfile(file, dtype=numpy.uint32, count=1)[0]
                    idx_elem.append(
                        numpy.fromfile(file, dtype=numpy.uint32, count=1)[0]
                    )
                    Ibeam.append(numpy.fromfile(file, dtype=numpy.float32, count=1)[0])
                    s.append(numpy.fromfile(file, dtype=numpy.float32, count=1)[0])
                    # NOTE(review): an "else:" appears to be missing before the
                    # second apt.append; as shown, ver >= 9 records get both a
                    # 4-value and a 2-value apt entry and older ones get none.
                    if ver >= 9:
                        apt.append(numpy.fromfile(file, dtype=numpy.float32, count=4))
                        apt.append(numpy.fromfile(file, dtype=numpy.float32, count=2))
                    Nstep = numpy.fromfile(file, dtype=numpy.uint32, count=1)[0]
                    cent_ave.append(numpy.fromfile(file, dtype=numpy.float32, count=7))
                    cent_rms.append(numpy.fromfile(file, dtype=numpy.float32, count=7))
                    xmax.append(numpy.fromfile(file, dtype=numpy.float32, count=7))
                    xmin.append(numpy.fromfile(file, dtype=numpy.float32, count=7))
                    if ver > 5:
                        sig_ave.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=7)
                        )
                        sig_rms.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=7)
                        )
                    if ver >= 6:
                        cent_min.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=7)
                        )
                        cent_max.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=7)
                        )
                    if ver >= 7:
                        eps_ave.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=3)
                        )
                        eps_rms.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=3)
                        )
                    if ver >= 8:
                        accpt.append(numpy.fromfile(file, dtype=numpy.float32, count=3))
                    Nptcl.append(numpy.fromfile(file, dtype=numpy.uint64, count=1)[0])
                    # Partran only
                    # (loss and density blocks exist only when Nptcl > 0)
                    if Nptcl[-1] > 0:
                        loss_num.append([])
                        loss_pow.append([])
                        den.append([])
                        den_pow.append([])
                        # One loss count and one lost power per run
                        for n in range(Nrun):
                            loss_num[-1].append(
                                numpy.fromfile(file, dtype=numpy.uint64, count=1)[0]
                            )
                            # NOTE(review): the closing parenthesis of this
                            # append appears to be missing.
                            loss_pow[-1].append(
                                numpy.fromfile(file, dtype=numpy.float32, count=1)[0]
                        loss_num_ave.append(sum(loss_num[-1]))
                        loss_num_rms.append(
                            numpy.fromfile(file, dtype=numpy.uint64, count=1)[0]
                        )
                        loss_num_min.append(
                            numpy.fromfile(file, dtype=numpy.uint64, count=1)[0]
                        )
                        loss_num_max.append(
                            numpy.fromfile(file, dtype=numpy.uint64, count=1)[0]
                        )
Yngve Levinsen's avatar
Yngve Levinsen committed
                        loss_pow_ave.append(sum(loss_pow[-1]))
                        loss_pow_rms.append(
                            numpy.fromfile(file, dtype=numpy.float64, count=1)[0]
                        )
                        loss_pow_min.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=1)[0]
                        )
                        loss_pow_max.append(
                            numpy.fromfile(file, dtype=numpy.float32, count=1)[0]
                        )
                        # 7 density histograms of Nstep bins; the integer
                        # width depends on flag_long
                        for k in range(7):
                            if flag_long == 1:
                                den[-1].append(
                                    numpy.fromfile(
                                        file, dtype=numpy.uint64, count=Nstep
                                    )
                                )
                            else:
                                den[-1].append(
                                    numpy.fromfile(
                                        file, dtype=numpy.uint32, count=Nstep
                                    )
                                )
                        # 3 power-density histograms, only with beam current
                        if Ibeam[-1] > 0:
                            for k in range(3):
                                den_pow[-1].append(
                                    numpy.fromfile(
                                        file, dtype=numpy.float32, count=Nstep
                                    )
                                )
                    # print Nrun,Nptcl[-1],idx_elem[-1]  # Diag
                # NOTE(review): the body of this except (presumably break)
                # appears to be missing. Also, at end of file the three-way
                # unpacking of an empty array would presumably raise
                # ValueError rather than IndexError -- confirm how the loop
                # actually terminates.
                except IndexError:
        # -- Reshape arrays
        # (per-record lists -> quantity-major arrays, e.g. 7 x Nelem)

        apt = numpy.swapaxes(apt, 1, 0)
        accpt = numpy.swapaxes(accpt, 1, 0)
        cent_ave = numpy.swapaxes(cent_ave, 1, 0)
        cent_rms = numpy.swapaxes(cent_rms, 1, 0)
        cent_max = numpy.swapaxes(cent_max, 1, 0)
        cent_min = numpy.swapaxes(cent_min, 1, 0)
        sig_ave = numpy.swapaxes(sig_ave, 1, 0)
        sig_rms = numpy.swapaxes(sig_rms, 1, 0)
        eps_ave = numpy.swapaxes(eps_ave, 1, 0)
        eps_rms = numpy.swapaxes(eps_rms, 1, 0)
        xmax = numpy.swapaxes(xmax, 1, 0)
        xmin = numpy.swapaxes(xmin, 1, 0)
        # Whether loss/density data exist is decided from the first record only
        if Nptcl[0] > 0:
            loss_num = numpy.swapaxes(loss_num, 1, 0)
            loss_pow = numpy.swapaxes(loss_pow, 1, 0)
            den = numpy.swapaxes(den, 1, 0)
            den_pow = numpy.swapaxes(den_pow, 1, 0)

        # -- Take care ave and rms
        # The file stores sums / squared sums over Nrun runs (see the class
        # docstring); convert them to mean and rms.

        cent_ave = cent_ave / Nrun
        cent_rms = numpy.sqrt(cent_rms / Nrun)
        sig_ave = sig_ave / Nrun
        sig_rms = numpy.sqrt(sig_rms / Nrun)
        eps_ave = eps_ave / Nrun
        eps_rms = numpy.sqrt(eps_rms / Nrun)
        if Nptcl[0] > 0:
            loss_num_ave = 1.0 * numpy.array(loss_num_ave) / Nrun
            loss_num_rms = numpy.sqrt(1.0 * numpy.array(loss_num_rms) / Nrun)
            loss_pow_ave = numpy.array(loss_pow_ave) / Nrun
            loss_pow_rms = numpy.sqrt(numpy.array(loss_pow_rms) / Nrun)

        # -- Change units, m => mm, pi-m-rad => pi-mm-mrad
        # Rows 0, 1, 4, 5 are x, y, r, z in the component order listed in the
        # class docstring (x, y, phs, ken, r, z, dpp).

        apt *= 1e3
        eps_ave *= 1e6
        eps_rms *= 1e6
        for k in (0, 1, 4, 5):
            cent_ave[k] *= 1e3
            cent_rms[k] *= 1e3
            cent_max[k] *= 1e3
            cent_min[k] *= 1e3
            sig_ave[k] *= 1e3
            sig_rms[k] *= 1e3
            xmax[k] *= 1e3
            xmin[k] *= 1e3

        # -- Define std (around to avoid sqrt(-eps))

        if Nrun > 1:
            cent_std = numpy.sqrt(numpy.around(cent_rms ** 2 - cent_ave ** 2, 12))
            sig_std = numpy.sqrt(numpy.around(sig_rms ** 2 - sig_ave ** 2, 12))
            eps_std = numpy.sqrt(numpy.around(eps_rms ** 2 - eps_ave ** 2, 12))
            cent_std = numpy.nan_to_num(cent_std)  # Replace nan with 0
            sig_std = numpy.nan_to_num(sig_std)  # Replace nan with 0
            eps_std = numpy.nan_to_num(eps_std)  # Replace nan with 0
            if Nptcl[0] > 0:
                loss_num_std = numpy.sqrt(
                    numpy.around(loss_num_rms ** 2 - loss_num_ave ** 2, 16)
                )
                # NOTE(review): the closing parenthesis of this call appears
                # to be missing.
                loss_pow_std = numpy.sqrt(
                    numpy.around(loss_pow_rms ** 2 - loss_pow_ave ** 2, 16)
                loss_num_std = numpy.nan_to_num(loss_num_std)  # Replace nan with 0
                loss_pow_std = numpy.nan_to_num(loss_pow_std)  # Replace nan with 0

        # -- Convert to list (just in case...)

        apt = apt.tolist()
        accpt = accpt.tolist()
        # Move the first accpt row to the end (presumably to match the
        # phs+, phs-, ken order listed in the class docstring)
        accpt.append(accpt[0])
        del accpt[0]
        cent_ave = cent_ave.tolist()
        cent_rms = cent_rms.tolist()
        cent_max = cent_max.tolist()
        cent_min = cent_min.tolist()
        sig_ave = sig_ave.tolist()
        sig_rms = sig_rms.tolist()
        eps_ave = eps_ave.tolist()
        eps_rms = eps_rms.tolist()
        xmax = xmax.tolist()
        xmin = xmin.tolist()
        if Nptcl[0] > 0:
            loss_num = loss_num.tolist()
            loss_pow = loss_pow.tolist()
            loss_num_ave = loss_num_ave.tolist()
            loss_pow_ave = loss_pow_ave.tolist()
            loss_num_rms = loss_num_rms.tolist()
            loss_pow_rms = loss_pow_rms.tolist()
            den = den.tolist()
            den_pow = den_pow.tolist()

        if Nrun > 1:
            cent_std = cent_std.tolist()
            sig_std = sig_std.tolist()
            eps_std = eps_std.tolist()
            if Nptcl[0] > 0:
                loss_num_std = loss_num_std.tolist()
                loss_pow_std = loss_pow_std.tolist()
        # -- Outputs
        # Single-run files expose cent/sig/eps and the first run's losses;
        # multi-run files expose the _ave/_rms/_std (and min/max) variants.
        self.idx_elem = idx_elem
        self.s = s
        self.apt = apt
        self.accpt = accpt
        self.Nptcl = Nptcl
        self.Ibeam = Ibeam
        self.Nrun = Nrun
        self.Nstep = Nstep
        if Nrun == 1:
            self.cent = cent_ave
            self.sig = sig_ave
            self.eps = eps_ave
            self.xmax = xmax
            self.xmin = xmin
            if Nptcl[0] > 0:
                self.loss_num = loss_num[0]
                self.loss_pow = loss_pow[0]
                self.den = den
                self.den_pow = den_pow
        else:
            self.cent_ave = cent_ave
            self.cent_rms = cent_rms
            self.cent_std = cent_std
            self.cent_max = cent_max
            self.cent_min = cent_min
            self.sig_ave = sig_ave
            self.sig_rms = sig_rms
            self.sig_std = sig_std
            self.eps_ave = eps_ave
            self.eps_rms = eps_rms
            self.eps_std = eps_std
            self.xmax = xmax
            self.xmin = xmin
            if Nptcl[0] > 0:
                self.loss_num = loss_num
                self.loss_pow = loss_pow
                self.loss_num_ave = loss_num_ave
                self.loss_pow_ave = loss_pow_ave
                self.loss_num_rms = loss_num_rms
                self.loss_pow_rms = loss_pow_rms
                self.loss_num_std = loss_num_std
                self.loss_pow_std = loss_pow_std
                self.loss_num_max = loss_num_max
                self.loss_pow_max = loss_pow_max
                self.loss_num_min = loss_num_min
                self.loss_pow_min = loss_pow_min
                self.den = den
                self.den_pow = den_pow

        # -- Option outputs
        # For each distinct idx_elem value, the position of its last record
        # (list order follows set iteration, not element order).

        self.idx_4_elem_end = [
            len(idx_elem) - 1 - idx_elem[::-1].index(i) for i in list(set(idx_elem))
        ]


# -------- Functions

# ---- Dist related


def x2dst(x, mass, freq, Ibeam, path_file="part_dtl1_new.dst"):
    """
        Output a TraceWin's .dst file from x and etc.
        Input: x (Nptcl,6)

        The file is written in binary mode, in native byte order: two header
        bytes (125, 100), the particle count (C int), the beam current and
        the frequency (float64), one more byte (125), the Nptcl x 6
        coordinates as float64 in row order, the mass (float64) and a
        trailing newline byte.

        :param x: particle coordinates, an (Nptcl, 6) sequence of rows
        :param mass: particle mass written after the coordinates
        :param freq: frequency written in the header
        :param Ibeam: beam current written in the header
        :param path_file: output file path

        For binary characters see https://docs.python.org/2/library/struct.html

        2014.10.03
    """
    # Flatten x row by row: x[0][0..5], x[1][0..5], ...
    flat = list(chain(*x))

    # Collect the pieces and join once: repeated "+=" on bytes copies the
    # whole buffer every time, which is quadratic in the particle count.
    # A single "<n>d" pack gives the same bytes as n separate pack("d").
    parts = [
        pack("b", 125),
        pack("b", 100),
        pack("i", len(x)),
        pack("d", Ibeam),
        pack("d", freq),
        pack("b", 125),  # Typo in the manual !!!!
        pack("%dd" % len(flat), *flat),
        pack("d", mass),
        b"\n",  # trailing byte kept for compatibility with earlier output
    ]

    # Binary mode: pack() returns bytes, which a text-mode file rejects on
    # Python 3 (and text mode may translate bytes on some platforms). The
    # with-block also closes the file if writing fails.
    with open(path_file, "wb") as fname:
        fname.write(b"".join(parts))
def plt2x(path_file):
        Extract x and etc from a TraceWin's binary .plt file.
        The units are (cm,rad,cm,rad,rad,MeV,loss).