## Automatically adapted for numpy.oldnumeric Jul 23, 2007 by 

import unittest
try:
    from scenario.interpolators import *
except:
    print "ERROR: failed to import scenario.interpolators"
    
import types
from mglutil.math.rotax import rotax
import math
from math import *
import numpy.oldnumeric as Numeric


def sameIntSequences(list1, list2):
    """Return True if two sequences have the same length and equal items.

    :param list1: first sequence (list, tuple or array)
    :param list2: second sequence
    :return: True when both sequences have the same length and every pair
        of corresponding items compares equal, False otherwise.
    """
    # zip() silently stops at the shorter input, so without this check a
    # truncated (or over-long) result would be reported as matching.
    if len(list1) != len(list2):
        return False
    for a, b in zip(list1, list2):
        if a != b:
            return False
    return True

def sameFloatSequences(list1, list2):
    """Return True if two float sequences match to two decimal places.

    Items are compared after formatting each with ``"%.2f"``, so values
    that round to the same two-decimal string are considered equal.  On
    the first mismatch the offending pair (or the two lengths) is printed
    to help diagnose the failing test.

    :param list1: first sequence of numbers
    :param list2: second sequence of numbers
    :return: True when both sequences have the same length and every pair
        of corresponding items formats identically with ``"%.2f"``.
    """
    # zip() silently stops at the shorter input; check lengths explicitly
    # so a truncated result cannot pass as a match.
    if len(list1) != len(list2):
        print("length mismatch: %d != %d" % (len(list1), len(list2)))
        return False
    for a, b in zip(list1, list2):
        sa = "%.2f" % a
        sb = "%.2f" % b
        if sa != sb:
            # single-argument print() behaves the same on Python 2 and 3
            print("%s %s" % (sa, sb))
            return False
    return True

def adding(x):
    """Return ``x + x``; used as a FunctionInterpolator callback in the tests."""
    doubled = x + x
    return doubled
     
def withargs(x, y, height, width):
    """Return the product ``x * y * height * width``.

    Used in the tests to exercise FunctionInterpolator with extra
    positional and keyword arguments.
    """
    return x * y * height * width
    
def simple(x):
    """Identity callback: hand back ``x`` unchanged."""
    result = x
    return result

def sinfunc(mod, x):
    """Return ``sin(mod)`` passed through ``round()``.

    NOTE(review): ``x`` is accepted but never used; the result depends only
    on ``mod``. Confirm this is intended.
    """
    angle_sine = math.sin(mod)
    return round(angle_sine)
    
def cosfunc(x):
    """Return the cosine of ``x`` (``x`` in radians, as for ``math.cos``)."""
    return math.cos(x)

class InterpolatorTest(unittest.TestCase):
    """Unit tests for the interpolator classes in scenario.interpolators.

    Each test builds an interpolator, checks ``getValue`` at a few
    fractions in [0, 1], then calls ``configure`` with new end values and
    checks again.  Float results are compared after formatting to two
    decimals via ``sameFloatSequences``.
    """

    def test_IntScalarInterpolators(self):
        """IntScalarInterpolator endpoints, rounding and configure()."""
        ip = IntScalarInterpolator(0, 10)
        val = ip.getValue(0.0)
        self.assertEqual(val, ip.firstVal)
        val = ip.getValue(1.0)
        self.assertEqual(val, ip.lastVal)
        # The expected values imply rounding to the nearest integer
        # (3.4 -> 3, 3.8 -> 4).
        self.assertEqual(ip.getValue(0.3), 3)
        self.assertEqual(ip.getValue(0.34), 3)
        self.assertEqual(ip.getValue(0.38), 4)
        # test configure()
        ip.configure(firstVal=25, lastVal=30)
        val = ip.getValue(0.0)
        self.assertEqual(val, 25)
        val = ip.getValue(1.0)
        self.assertEqual(val, 30)
        self.assertEqual(ip.valueRange, 5)
        self.assertEqual(ip.getValue(0.5), 28)

    def test_FloatScalarInterpolator(self):
        """FloatScalarInterpolator endpoints, midpoints and configure()."""
        ip = FloatScalarInterpolator(0., 10.)
        val = ip.getValue(0.0)
        self.assertEqual(val, ip.firstVal)
        val = ip.getValue(1.0)
        self.assertEqual(val, ip.lastVal)
        # Compare formatted strings to avoid float representation noise.
        val = "%.2f" % ip.getValue(0.35)
        self.assertEqual(val, "3.50")
        val = "%.2f" % ip.getValue(0.5)
        self.assertEqual(val, "5.00")
        # test configure()
        ip.configure(firstVal=25.5, lastVal=30.5)
        val = "%.2f" % ip.getValue(0.0)
        self.assertEqual(val, "25.50")
        val = "%.2f" % ip.getValue(1.0)
        self.assertEqual(val, "30.50")
        self.assertEqual("%.2f" % ip.valueRange, "5.00")
        val = "%.2f" % ip.getValue(0.56)
        self.assertEqual(val, "28.30")

    def test_IntVectorInterpolator(self):
        """IntVectorInterpolator on 4-component integer vectors."""
        starts = (0, 10, -5, -15)
        ends = (10, 0, -15, -5)
        ip = IntVectorInterpolator(starts, ends)
        val = ip.getValue(0.0)
        self.assertEqual(len(val), 4)
        self.assertTrue(sameIntSequences(val, ip.firstVal))

        val = ip.getValue(1.0)
        self.assertTrue(sameIntSequences(val, ip.lastVal))

        val = ip.getValue(0.5)
        self.assertTrue(sameIntSequences(val, [5, 5, -10, -10]))
        val = ip.getValue(0.32)
        self.assertTrue(sameIntSequences(val, [3, 7, -8, -12]))

        val = ip.getValue(0.38)
        self.assertTrue(sameIntSequences(val, [4, 6, -9, -11]))

        # test configure(): only lastVal changes, firstVal stays `starts`
        ends = [10, 20, 0, 0]
        ip.configure(lastVal=Numeric.array(ends))
        self.assertTrue(sameIntSequences(ip.valueRange, [10, 10, 5, 15]))
        self.assertTrue(sameIntSequences(ip.getValue(0.5), [5, 15, -3, -8]))
        self.assertTrue(sameIntSequences(ip.getValue(0.32), [3, 13, -3, -10]))
        self.assertTrue(sameIntSequences(ip.getValue(0.38), [4, 14, -3, -9]))

    def test_IntVarScalarInterpolator(self):
        """IntVarScalarInterpolator between a scalar and a 3-item list."""
        ip = IntVarScalarInterpolator(1, [2, 4, 6])
        val = ip.getValue(0.5)
        self.assertTrue(sameIntSequences(val, [2, 3, 4]))
        val = ip.getValue(0.6)
        self.assertTrue(sameIntSequences(val, [2, 3, 4]))
        val = ip.getValue(0.67)
        self.assertTrue(sameIntSequences(val, [2, 3, 4]))
        # configure(): swap the ends (list first, scalar last)
        ip.configure(firstVal=[2, 4, 6], lastVal=1)
        val = ip.getValue(0.5)
        self.assertTrue(sameIntSequences(val, [2, 3, 4]))
        val = ip.getValue(0.6)
        self.assertTrue(sameIntSequences(val, [1, 2, 3]))
        val = ip.getValue(0.67)
        self.assertTrue(sameIntSequences(val, [1, 2, 3]))
        self.assertEqual(ip.nbvar, 3)

    def test_FloatVarScalarInterpolator(self):
        """FloatVarScalarInterpolator between a scalar and a 3-item list."""
        ip = FloatVarScalarInterpolator(1, [2, 4, 6])
        val = ip.getValue(0.5)
        self.assertTrue(sameFloatSequences(val, [1.5, 2.5, 3.5]))
        val = ip.getValue(0.6)
        self.assertTrue(sameFloatSequences(val, [1.60, 2.799, 4.0]))
        val = ip.getValue(0.67)
        self.assertTrue(sameFloatSequences(val, [1.6699, 3.010, 4.34999]))

        # configure(): swap the ends (list first, scalar last)
        ip.configure(firstVal=[2, 4, 6], lastVal=1)
        val = ip.getValue(0.5)
        self.assertTrue(sameFloatSequences(val, [1.5, 2.5, 3.5]))
        val = ip.getValue(0.6)
        self.assertTrue(sameFloatSequences(val, [1.399, 2.200, 3.0]))
        val = ip.getValue(0.67)
        self.assertTrue(sameFloatSequences(val, [1.330, 1.9899, 2.6499]))
        self.assertEqual(ip.nbvar, 3)

    def test_FloatVectorInterpolator(self):
        """FloatVectorInterpolator on 4-component float vectors."""
        starts = (0.0, 10.0, -5.0, -15.0)
        ends = (10.0, 0.0, -15.0, -5.0)
        ip = FloatVectorInterpolator(starts, ends)
        val = ip.getValue(0.0)
        self.assertEqual(len(val), 4)
        self.assertTrue(sameFloatSequences(val, ip.firstVal))

        val = ip.getValue(1.0)
        self.assertTrue(sameFloatSequences(val, ip.lastVal))

        val = ip.getValue(0.5)
        self.assertTrue(sameFloatSequences(val, [5.0, 5.0, -10.0, -10.0]))

        val = ip.getValue(0.32)
        self.assertTrue(sameFloatSequences(val, [3.20, 6.799, -8.199, -11.80]))

        val = ip.getValue(0.38)
        self.assertTrue(sameFloatSequences(val, [3.799, 6.20, -8.80, -11.199]))
        # test configure(): only lastVal changes, firstVal stays `starts`
        ends = [10, 20, 0, 0]
        ip.configure(lastVal=Numeric.array(ends))
        self.assertTrue(sameFloatSequences(ip.valueRange, [10.0, 10.0, 5.0, 15.0]))
        self.assertTrue(sameFloatSequences(ip.getValue(0.5), [5, 15, -2.5, -7.5]))
        self.assertTrue(sameFloatSequences(ip.getValue(0.32),
                                           [3.2, 13.1999, -3.3999, -10.1999]))
        self.assertTrue(sameFloatSequences(ip.getValue(0.38),
                                           [3.7999, 13.80, -3.10, -9.30]))
        self.assertEqual(ip.nbvar, 4)

    def test_VarVectorInterpolator(self):
        """VarVectorInterpolator between lists of 3-vectors of differing length."""
        val1 = [[1, 1, 1], ]
        val2 = [[6, 6, 6], [8, 8, 8], [10, 10, 10]]
        ip = VarVectorInterpolator(val1, val2)
        val = ip.getValue(0.)

        self.assertTrue(sameIntSequences(val, val1))
        val = ip.getValue(1.)
        self.assertTrue(sameIntSequences(val, val2))
        # Intermediate values are arrays; compare them flattened.
        val = ip.getValue(0.5)
        self.assertTrue(isinstance(val, Numeric.ArrayType))
        self.assertTrue(sameFloatSequences(val.ravel(), Numeric.array([[3.5, 3.5, 3.5],
                                                                       [4.5, 4.5, 4.5],
                                                                       [5.5, 5.5, 5.5]], 'f').ravel()))
        val = ip.getValue(0.6)
        self.assertTrue(isinstance(val, Numeric.ArrayType))
        self.assertTrue(sameFloatSequences(val.ravel(), Numeric.array([[4.0, 4.0, 4.0],
                                                                       [5.2, 5.2, 5.2],
                                                                       [6.4, 6.4, 6.4]], 'f').ravel()))
        val = ip.getValue(0.67)
        self.assertTrue(isinstance(val, Numeric.ArrayType))
        self.assertTrue(sameFloatSequences(val.ravel(), Numeric.array([[4.35, 4.35, 4.35],
                                                                       [5.69, 5.69, 5.69],
                                                                       [7.03, 7.03, 7.03]], 'f').ravel()))

        # configure(): swap the ends
        ip.configure(firstVal=val2, lastVal=val1)
        val = ip.getValue(0.5)
        self.assertTrue(isinstance(val, Numeric.ArrayType))
        self.assertTrue(sameFloatSequences(val.ravel(), Numeric.array([[3.5, 3.5, 3.5],
                                                                       [4.5, 4.5, 4.5],
                                                                       [5.5, 5.5, 5.5]], 'f').ravel()))

        val = ip.getValue(0.6)
        self.assertTrue(isinstance(val, Numeric.ArrayType))
        self.assertTrue(sameFloatSequences(val.ravel(), Numeric.array([[3.0, 3.0, 3.0],
                                                                       [3.8, 3.8, 3.8],
                                                                       [4.6, 4.6, 4.6]], 'f').ravel()))

        val = ip.getValue(0.67)
        self.assertTrue(isinstance(val, Numeric.ArrayType))
        self.assertTrue(sameFloatSequences(val.ravel(), Numeric.array([[2.65, 2.65, 2.65],
                                                                       [3.31, 3.31, 3.31],
                                                                       [3.97, 3.97, 3.97]], 'f').ravel()))

    def test_RotationInterpolator(self):
        """Smoke test: RotationInterpolator.getValue()/configure() run.

        M1 is the 4x4 identity; M2 and M3 come from rotax() about the
        y axis.  No expected values are asserted yet.
        """
        M1 = Numeric.identity(4).astype('f')
        M2 = rotax([0., 0., 0.], [0., 1., 0.], math.pi / 2)
        ip = RotationInterpolator(M1.ravel(), M2.ravel())
        val = ip.getValue(0.3)
        # TODO: validate `val`.  A FloatVectorInterpolator over the same
        # flattened matrices gives a plain linear blend that could serve
        # as a reference once the expected relationship is known.
        ip1 = FloatVectorInterpolator(M1.ravel(), M2.ravel())
        val1 = ip1.getValue(0.3)

        # test configure()
        M3 = rotax([0., 0., 0.], [0., 1., 0.], math.pi).ravel()
        ip.configure(lastVal=M3)
        val = ip.getValue(0.3)

    def test_FunctionInterpolator(self):
        """FunctionInterpolator with plain functions and bound extra args."""
        func1 = FunctionInterpolator(function=simple)
        val = func1.getValue(0.5)
        self.assertEqual(val, 0.5)

        func1 = FunctionInterpolator(function=adding)
        val = func1.getValue(0.5)
        self.assertEqual(val, 1)

        # function given as (callable, posArgs, namedArgs):
        # withargs(0.5, 2, height=340, width=20) == 6800
        func1 = FunctionInterpolator(function=(withargs, (2,), {'height': 340, 'width': 20}))
        val = func1.getValue(0.5)
        self.assertEqual(val, 6800.0)

        # test configure(): with firstVal=5 and lastVal=15 the expected values
        # below correspond to withargs being called with 5, 15 and 10.
        func1.configure(firstVal=5.0, lastVal=15.0)
        val0 = func1.getValue(0)
        self.assertEqual(val0, 68000.0)
        val1 = func1.getValue(1)
        self.assertEqual(val1, 204000.0)
        val = func1.getValue(0.5)
        self.assertEqual(val, 136000.0)

        func1.configure(function=simple)
        val0 = func1.getValue(0)
        self.assertEqual(val0, 5.0)
        val1 = func1.getValue(1)
        self.assertEqual(val1, 15.0)
        val = func1.getValue(0.5)
        self.assertEqual(val, 10.0)

    def test_CompositeInterpolator(self):
        """CompositeInterpolator combining vector, scalar and int interpolators."""
        ip = CompositeInterpolator(interpolators=(FloatVectorInterpolator, FloatScalarInterpolator),
                                   firstVal=([1, 1, 1], 1,), lastVal=([10, 10, 10], 10))
        val = ip.getValue(0.5)
        self.assertEqual(len(val), 2)

        self.assertTrue(sameFloatSequences(val[0], [5.5, 5.5, 5.5]))
        self.assertEqual(val[1], 5.5)
        val = ip.getValue(0.7)
        self.assertEqual(len(val), 2)
        self.assertTrue(sameFloatSequences(val[0], [7.299, 7.299, 7.299]))
        self.assertTrue(sameFloatSequences([val[1]], [7.299, ]))

        # configure(): replace the interpolator list with three entries
        ip.configure(interpolators=(FloatScalarInterpolator, IntScalarInterpolator, FloatVectorInterpolator),
                     firstVal=(1.0, 1, [1, 1, 1]), lastVal=(10.0, 10, [10, 10, 10]))
        self.assertEqual(len(ip.interpolators), 3)
        val = ip.getValue(0.5)
        self.assertEqual(len(val), 3)
        self.assertTrue(sameFloatSequences([val[0], ], [5.5, ]))
        self.assertEqual(val[1], 6)
        self.assertTrue(sameFloatSequences(val[2], [5.5, 5.5, 5.5]))

        val = ip.getValue(0.7)
        self.assertEqual(len(val), 3)
        self.assertTrue(sameFloatSequences([val[0], ], [7.2999]))
        self.assertEqual(val[1], 7)
        self.assertTrue(sameFloatSequences(val[2], [7.299, 7.299, 7.299]))
                                   
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main(module='__main__')
