Program Listing for File unrol_radial.py

Return to documentation for file (/Users/robertshaw/devfiles/libecpint/src/generated/radial/unrol_radial.py)

MAX_UNROL_AM = 4

# DO NOT EDIT BELOW HERE

from sympy import *

class Qijk:
    def __init__(self, Ival = 0, Jval = 0, Kval = 0):
        self.i = Ival
        self.j = Jval
        self.k = Kval
        self.size = 2*self.i + self.j + 1
        self.start = self.k-self.i-self.j
        self.end = self.k+self.i
        self.subq = []
        self.terms = []
        self.bases = []
        self.f = []
        self.ga = []
        self.gb = []
        self.h = []
        for i in range(self.size):
            self.bases.append([])

    def print(self):
        print("Q", self.i, self.j, self.k)
        for i in range(len(self.bases)):
            print(i+self.start, ":")
            for j in range(len(self.bases[i])):
                print(self.bases[i][j])

    def print_simple(self):
         print("Q", self.i, self.j, self.k)
         for i in range(len(self.bases)):
            print(i+self.start, ":", simplify(self.bases[i]))

    def print_fgh(self):
        print("Q", self.i, self.j, self.k)
        for i in range(len(self.f)):
            ix = 2*i+self.start
            if ix > 0:
                print("F", ix, ":", simplify(self.f[i]))
        for i in range(len(self.gb)):
            ix = 2*i + self.start + 1
            if ix > 0:
                print("GB", ix, ":", simplify(self.gb[i]))
        for i in range(len(self.ga)):
            ix = 2*i + self.start + 1
            if ix > 0:
                print("GA", ix, ":", simplify(self.ga[i]))
        for i in range(len(self.h)):
            ix = 2*i+self.start + 2
            if ix > 0:
                print("H", ix, ":", simplify(self.h[i]))

    def write_code(self, f):
        print("\t\t\t\t\t\t\t\t\tcase", self.i*10000+self.j*100+self.k, ": {", file=f)

        for i in range(len(self.f)):
            ix = 2*i+self.start
            if ix == 2:
                simp = simplify(self.f[i])
                if simp!= 0:
                    print("\t\t\t\t\t\t\t\t\t\tresult = (", simp, ") * values[0];", file=f)
            elif ix > 2:
                simp = simplify(self.f[i])
                if simp!= 0:
                    print("\t\t\t\t\t\t\t\t\t\tresult += (", simp, ") * values[", ix-2, "];", file=f)

        for i in range(len(self.gb)):
            ix = 2*i + self.start + 1
            if ix == 1:
                simp = simplify(self.gb[i])
                if simp!= 0:
                    print("\t\t\t\t\t\t\t\t\t\tresult += (", simp, ") * G1B;", file=f)
            elif ix > 1:
                simp = simplify(self.gb[i])
                if simp!= 0:
                    print("\t\t\t\t\t\t\t\t\t\tresult += (", simp, ") * values[", ix-2, "];", file=f)

        for i in range(len(self.ga)):
            ix = 2*i + self.start + 1
            if ix == 1:
                simp = simplify(self.ga[i])
                if simp != 0:
                    print("\t\t\t\t\t\t\t\t\t\tresult += (", simp, ") * G1A;", file=f)

        for i in range(len(self.h)):
            ix = 2*i+self.start + 2
            if ix == 2:
                simp = simplify(self.h[i])
                if simp!= 0:
                    print("\t\t\t\t\t\t\t\t\t\tresult += (", simp, ") * H2;", file=f)

        print("\t\t\t\t\t\t\t\t\t\tbreak;", file=f)
        print("\t\t\t\t\t\t\t\t\t}", file=f)

    def simplify(self):
        simple_bases = []
        for i in range(len(self.bases)):
            x = Symbol('x')
            y = Symbol('y')
            z = Symbol('z')
            z = 0
            for j in range(len(self.bases[i])):
                z = z + parse(self.bases[i][j])
            simple_bases.append(z)
        self.bases = simple_bases

    def sort(self):
        for i in range(len(self.bases)):
            if i % 2 == 0:
                self.f.append(self.bases[i])
            else:
                self.gb.append(self.bases[i])

    def eliminate(self):
        x = Symbol('x')
        y = Symbol('y')
        z = Symbol('z')
        p = Symbol('p')
        z = 0

        if self.start < 1:
            if self.end < 1:
                self.bases.append(z)
            if self.end < 2:
                self.bases.append(z)

            w = Symbol('w')

            N = self.start
            ix = 0
            gaix = -1
            hix  = -1

            while (N < 1):
                z = self.bases[ix]
                w = self.bases[ix+2]
                w = w + (2 * p / (N-1))*z
                self.bases[ix+2] = w

                w = self.bases[ix+1]
                w = w - (2 * y / (N-1))*z
                self.bases[ix+1] = w

                if ix % 2 == 0:
                    if gaix > -1:
                        w = self.ga[gaix]
                        w = w - (2*x / (N-1))*z
                        self.ga[gaix] = w
                    else:
                        w = -(2*x / (N-1))*z
                        self.ga.append(w)
                        gaix += 1

                    if hix > -1:
                        z = self.h[hix]
                        w = (2 * p / (N-1))*z
                        self.h.append(w)
                        hix += 1

                        w = self.ga[gaix]
                        w = w - (2 * y / (N-1))*z
                        self.ga[gaix] = w

                        w = self.bases[ix+1]
                        w = w - (2 * x / (N-1))*z
                        self.bases[ix+1] = w
                else:
                    if hix > -1:
                        w = self.h[hix]
                        w = w - (2*x / (N-1))*z
                        self.h[hix] = w
                    else:
                        w = - (2*x / (N-1))*z
                        self.h.append(w)
                        hix += 1

                    if gaix > -1:
                        z = self.ga[gaix]
                        w = (2 * p /(N-1)) * z
                        self.ga.append(w)
                        gaix += 1

                        w = self.h[hix]
                        w = w - (2 * y / (N-1))*z
                        self.h[hix] = w

                        w = self.bases[ix+1]
                        w = w - (2*x / (N-1))*z
                        self.bases[ix+1] = w

                N += 1
                ix += 1

def parse(term):
    x = Symbol('x')
    y = Symbol('y')
    p = Symbol('p')
    bits = term.split(',')
    z = Symbol('z')
    z = 1
    for bit in bits:
        bi = bit[:2]
        if bi == "mu":
            ix = 2
            i = 0
            j = 0
            k = 0

            I = bit[ix]
            if I == "-":
                ix += 1
                i = -int(bit[ix])
            else:
                i = int(I)

            ix += 1
            J = bit[ix]
            if J == "-":
                ix += 1
                j = -int(bit[ix])
            else:
                j = int(J)

            ix += 1
            K = bit[ix]
            if K == "-":
                ix += 1
                k = -int(bit[ix])
            else:
                k = int(K)

            z = z * (2 + j - i - k)/(2*x)
        elif bi == "nu":
            z = z * (-y/x)
        elif bi == "xi":
            z = z * p/x
        elif bi == "rh":
            j = int(bit[3])
            z = z * (1 - 2*j)/(2*y)
        elif bi == "om":
            z = z * -1 / (2*y)
    return z

def unrol(q):
    if (q.i == 0 and q.j == 0):
        return
    elif (q.i > 0):
        q1 = Qijk(Ival = q.i-1, Jval = q.j, Kval = q.k-1)
        q.subq.append(q1)
        q.terms.append("mu" + str(q.i) + str(q.j) + str(q.k))

        q2 = Qijk(Ival = q.i-1, Jval = q.j-1, Kval = q.k)
        q.subq.append(q2)
        q.terms.append("nu")

        q3 = Qijk(Ival = q.i-1, Jval = q.j, Kval = q.k+1)
        q.subq.append(q3)
        q.terms.append("xi")
    elif(q.j > 1):
        q1 = Qijk(Ival = 0, Jval = q.j-2, Kval = q.k)
        q.subq.append(q1)
        q.terms.append("sigma")

        q2 = Qijk(Ival = 0, Jval = q.j-1, Kval = q.k-1)
        q.subq.append(q2)
        q.terms.append("rho" + str(q.j))
    else:
        q1 = Qijk(Ival = 0, Jval = 0, Kval = q.k)
        q.subq.append(q1)
        q.terms.append("ups")

        q2 = Qijk(Ival = 0, Jval = 0, Kval = q.k-1)
        q.subq.append(q2)
        q.terms.append("om")

    for i in range(len(q.subq)):
        unrol(q.subq[i])

    return

def collect(q, Q, term):
    if (q.i == 0 and q.j == 0):
        Q.bases[q.k-Q.start].append(term)
    else:
        for i in range(len(q.subq)):
            collect(q.subq[i], Q, term + q.terms[i] + ",")

def algebraic_unrol(i, j, k):
    q = Qijk(Ival = i, Jval = j, Kval = k)
    unrol(q)
    collect(q, q, "")
    q.simplify()
    return q


f = open('radial_gen.part2', 'w')
print("", file=f)
for j in range(MAX_UNROL_AM+1):
    for i in range(j+1):
        for k in range(1, 3*MAX_UNROL_AM+1-i-j):
            if (i + j + k) % 2 == 0:
                q = algebraic_unrol(i, j, k)
                q.eliminate()
                q.sort()
                q.write_code(f)
                print("", file=f)