From a85c8f405e3c84c9324f43a429d2ff4188e93bdf Mon Sep 17 00:00:00 2001
From: Yngve Levinsen <yngve.levinsen@esss.se>
Date: Thu, 16 Aug 2018 14:37:56 +0200
Subject: [PATCH] edits in TraceWin.field_map using meshgrid for the data added
 interpolation function

---
 ess/TraceWin.py | 59 ++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 54 insertions(+), 5 deletions(-)

diff --git a/ess/TraceWin.py b/ess/TraceWin.py
index fa66906..dc44acf 100644
--- a/ess/TraceWin.py
+++ b/ess/TraceWin.py
@@ -1125,10 +1125,12 @@ class field_map:
             raise ValueError("Cannot find file {}".format(filename))
         fin = open(filename, 'r')
         l = fin.readline().split()
+        self.header = []
         self.start = []
         self.end = []
         numindexes = []
         while len(l) > 1:
+            [self.header.append(float(i)) for i in l]
             numindexes.append(int(l[0]) + 1)
             if len(l) == 2:
                 self.start.append(0.0)
@@ -1138,15 +1140,61 @@ class field_map:
                 self.end.append(float(l[2]))
             l = fin.readline().split()
 
-        self.z = numpy.arange(self.start[0], self.end[0], (self.end[0] - self.start[0]) / numindexes[0])
-        if len(self.start) > 1:
-            self.y = numpy.arange(self.start[1], self.end[1], (self.end[1] - self.start[1]) / numindexes[1])
-        if len(self.start) > 2:
-            self.x = numpy.arange(self.start[2], self.end[2], (self.end[2] - self.start[2]) / (numindexes[2]))
+        if len(self.start) == 1:
+            self.z = numpy.mgrid[self.start[0]:self.end[0]:numindexes[0]*1j]
+            print(new_z,self.z)
+        elif len(self.start) == 2:
+            self.z, self.x = numpy.mgrid[self.start[0]:self.end[0]:numindexes[0]*1j,
+                                         self.start[1]:self.end[1]:numindexes[1]*1j]
+        elif len(self.start) == 3:
+            self.z, self.x, self.y = numpy.mgrid[self.start[0]:self.end[0]:numindexes[0]*1j, 
+                                                 self.start[1]:self.end[1]:numindexes[1]*1j, 
+                                                 self.start[2]:self.end[2]:numindexes[2]*1j]
 
         self.norm = float(l[0])
+        self.header.append(self.norm)
         self.map = numpy.loadtxt(fin).reshape(numindexes)
 
+    def get_flat_fieldmap(self):
+        totmapshape = 1
+        for i in self.map.shape:
+            totmapshape *= i
+        return self.map.reshape(totmapshape)
+
+    def interpolate(self, npoints: tuple, method='cubic'):
+        '''
+        Interpolate the map into a new mesh
+        Each value should be an integer with the number of mesh points in each dimension
+        intervals should be tuple-like with same number of elements
+        as the map dimension, e.g. [0.8,0.8] for 2D
+        Can also be a float if you want same interpolation factor in all planes
+
+        method can be 'linear', 'nearest' or 'cubic'
+        '''
+        import numpy
+        from scipy.interpolate import griddata
+
+        values=self.map.flatten()
+
+        if len(self.start) == 1:
+            points=self.z[:]
+            self.z = numpy.mgrid[self.start[0]:self.end[0]:npoints[0]*1j]
+            self.map=griddata(points,values,self.z)
+        if len(self.start) == 2:
+            points=numpy.array([self.z.flatten(), self.x.flatten()]).transpose()
+            self.z, self.x = numpy.mgrid[self.start[0]:self.end[0]:npoints[0]*1j,
+                                         self.start[1]:self.end[1]:npoints[1]*1j]
+            self.map=griddata(points,values,(self.z,self.x))
+        if len(self.start) == 3:
+            points=numpy.array([self.z.flatten(), self.x.flatten(), self.y.flatten()]).transpose()
+            self.z, self.x, self.y = numpy.mgrid[self.start[0]:self.end[0]:npoints[0]*1j,
+                                                 self.start[1]:self.end[1]:npoints[1]*1j,
+                                                 self.start[2]:self.end[2]:npoints[2]*1j]
+            self.map=griddata(points,values,(self.z,self.x,self.y))
+            self.header[0]=npoints[0]-1
+            self.header[2]=npoints[1]-1
+            self.header[5]=npoints[2]-1
+
     def savemap(self, filename):
         fout = open(filename, 'w')
         for n, s in zip(self.map.shape, self.size):
@@ -1158,3 +1206,4 @@ class field_map:
         data = self.map.reshape(totmapshape)
         for j in data:
             fout.write('{}\n'.format(j))
+
-- 
GitLab