from matplotlib.backends.backend_wxagg import Toolbar, FigureCanvasWxAgg
from matplotlib.figure import Figure
import comedi #@UnresolvedImport
import matplotlib
import numpy
import os
import struct
import sys
import threading
import wx.lib.newevent
import time
# Select the wx/Agg matplotlib backend and interactive mode.
# NOTE(review): matplotlib.backends.backend_wxagg is already imported at the
# top of the file, before this call; matplotlib generally expects use() to run
# before any backend is imported, so this may only emit a warning -- confirm
# against the matplotlib version in use.
matplotlib.use("WXAgg")
matplotlib.interactive(True)

# wx window ids of the buttons (MainWindow and PreferencesFrame)
ID_STARTBUTTON = 101
ID_STOPBUTTON = 102
ID_CLEARBUTTON = 103
ID_OKBUTTON = 104
ID_CONFIGUREBUTTON = 105

#default parameters
# NOTE(review): DEVICE is not referenced by the code shown here;
# PreferencesFrame hard-codes the same "/dev/comedi0" default.
DEVICE = '/dev/comedi0'

#global variables
BUFFER_SIZE = 1000        # bytes requested per os.read() in readerThread
PLOT_BUFFER_SIZE = 5000   # points per trace in TriggerPlot
DEFAULT_LINE_WIDTH = 1.0  # line width of the traces; the newest one is drawn twice as wide
MARGIN = 0.1              # NOTE(review): not referenced by the code shown here
DEBUG = True              # enables the diagnostic print statements

def enum(**enums):
    '''
    Create a lightweight enumeration.

    Each keyword argument becomes a class attribute of a freshly created
    class named 'Enum', which is returned.
    '''
    members = dict(enums)
    return type('Enum', (), members)

def deinterleaveData(inData, nChan):
    '''
    Split interleaved samples into one row per channel.

    inData is a flat sequence ordered scan by scan:
        [a1, b1, c1, a2, b2, c2, ..., aN, bN, cN, a(N+1), b(N+1)]
    The complete scans are returned as an integer numpy array of shape
    (nChan, N):
        [[a1, a2, ..., aN],
         [b1, b2, ..., bN],
         [c1, c2, ..., cN]]
    The samples of a trailing incomplete scan (here [a(N+1), b(N+1)]) are
    returned separately as a slice of inData; that slice is empty when
    len(inData) is a multiple of nChan.

    Returns (outData, remainData).
    '''
    # number of leading samples that form complete scans
    usable = len(inData) // nChan * nChan
    scans = numpy.array(inData[:usable], dtype = int)
    # one scan per row, then transpose to one channel per row
    return scans.reshape((-1, nChan)).T, inData[usable:]

def convertInVolts(inData, inMaxChanValues, inChanRanges):
    '''
    Convert raw integer samples read from comedi to physical values.

    Each raw value v of channel i is mapped linearly onto that channel's range:
        v * (max - min) / maxdata + min

    arguments:
    inData: (nChan, N) raw samples, one row per channel (numpy array or any
        nested sequence numpy.array() accepts).
    inMaxChanValues: per-channel maximum raw value, as returned by
        comedi_get_maxdata(); at least nChan entries.
    inChanRanges: per-channel range objects exposing .min and .max, as
        returned by comedi_get_range(); at least nChan entries.

    Returns a new float numpy array with the same shape as inData.
    '''
    # convert first so that plain nested lists are accepted as well
    out = numpy.array(inData, dtype = float)
    nChan = out.shape[0]
    for i in range(nChan):
        chanRange = inChanRanges[i]
        low = float(chanRange.min)
        span = float(chanRange.max - chanRange.min)
        out[i] = out[i] * (span / float(inMaxChanValues[i])) + low
    return out

# Trigger modes for TriggerPlot.trigger: AUTO starts a new trace from the
# remaining samples as soon as the previous trace is full; RISINGTHRESHOLD
# waits until the signal rises above TriggerPlot.triggerLevel.
triggerTypes = enum(AUTO = 0, RISINGTHRESHOLD = 1)

class TriggerPlot(matplotlib.axes.Axes):
    '''
    Axes that draws one signal as a stack of fixed-length traces, like a
    storage oscilloscope.

    processNewData() appends samples to the newest trace until it holds
    PLOT_BUFFER_SIZE points. A new trace is then started from the remaining
    samples, either immediately (trigger == triggerTypes.AUTO) or from the
    first sample above triggerLevel (trigger == triggerTypes.RISINGTHRESHOLD).
    At most remanence + 1 traces are kept; older traces are drawn more
    transparently and the newest one with a doubled line width.
    '''
    def __init__(self, fig, rect,
                 axisbg = None, # defaults to rc axes.facecolor
                 frameon = True,
                 sharex = None, # use Axes instance's xaxis info
                 sharey = None, # use Axes instance's yaxis info
                 label = '',
                 xscale = None,
                 yscale = None):
        matplotlib.axes.Axes.__init__(self, fig, rect, axisbg, frameon, sharex, sharey, label, xscale, yscale)
        self.remanence = 5  # number of older traces kept besides the newest one
        self.trigger = triggerTypes.RISINGTHRESHOLD
        self.__bufferSize = PLOT_BUFFER_SIZE  # points per trace
        # circular buffer of the most recent samples; only its min/max are
        # used, to place the trigger level automatically
        self.__peekBufferSize = PLOT_BUFFER_SIZE*(self.remanence+1)
        self.__peekBuffer = numpy.zeros(self.__peekBufferSize)
        self.__peekBufferPos = 0
        # NOTE(review): this rebinds Axes.lines (matplotlib's own list of line
        # artists) to the list returned by plot(). processNewData() pops from
        # it and reads lines[-1], presumably relying on add_line() appending to
        # this same list in older matplotlib -- confirm for the version in use.
        self.lines = self.plot([], '-')
        # when True, triggerLevel is recomputed from the peek buffer each time
        # a trace fills up, overriding any value given to setThreshold()
        self.autoDefineTrigger = True
        self.__triggerLevelAutoDefined = False
        self.triggerLevel = 0.95
        # fraction of the observed min..max range where the auto trigger sits
        self.__relTriggerLevel = 0.75

    def cla(self):
        '''Clear the axes and start again with a single empty trace.'''
        matplotlib.axes.Axes.cla(self)
        self.lines = self.plot([], '-')

    def __addDataToPeekBuffer(self, inData):
        '''
        Copy inData into the circular peek buffer, overwriting the oldest
        samples. Chunks longer than the buffer keep only their newest samples.
        '''
        nbPoints = len(inData)
        size = self.__peekBufferSize
        if nbPoints >= size:
            # the whole buffer is replaced by the newest samples
            self.__peekBuffer[:] = inData[nbPoints - size:]
            self.__peekBufferPos = 0
            return
        pos = self.__peekBufferPos
        nbFirst = min(nbPoints, size - pos)
        self.__peekBuffer[pos:pos + nbFirst] = inData[:nbFirst]
        # whatever did not fit before the end wraps around to the start
        nbRest = nbPoints - nbFirst
        self.__peekBuffer[:nbRest] = inData[nbFirst:]
        self.__peekBufferPos = (pos + nbPoints) % size

    def processNewData(self, inData):
        '''
        Append a chunk of samples of this channel to the plot.

        Samples fill the newest trace up to the per-trace size. Once it is
        full, the oldest trace is dropped if remanence + 1 traces exist, and
        the samples left over are passed through the trigger; if it returns
        any, they start a new trace. While no trigger is found, the full
        trace stays full and every following chunk goes to the trigger.
        '''
        self.__addDataToPeekBuffer(inData)

        nbPoints = len(inData)
        yData = self.lines[-1].get_ydata()
        currPos = len(yData)
        maxPoints = self.__bufferSize - currPos
        if nbPoints <= maxPoints:
            maxPoints = nbPoints
        yData = numpy.append(yData, inData[:maxPoints])
        # x coordinates are the sample indices 1..len(yData)
        self.lines[-1].set_data(numpy.linspace(1, len(yData), len(yData)), yData)
        currPos += maxPoints

        if currPos > self.__bufferSize - 1:
            # the newest trace is full: make room and try to start a new one
            if len(self.lines) >= self.remanence + 1:
                self.lines.pop(0)

            self.__autoDefineThreshold()
            ret = self.__waitForTrigger(inData[maxPoints:])
            if len(ret)>0:
                line = matplotlib.lines.Line2D(numpy.linspace(1,
                                                              len(ret),
                                                              len(ret)),
                                               ret)
                self.add_line(line)
            self.__updateLineColor()

    def __updateLineColor(self):
        '''Fade older traces and draw the newest one with a doubled line width.'''
        nbLines = len(self.lines)
        for i, line in enumerate(self.lines):
            # clamp: with more than 10 traces the formula would exceed 1.0
            line.set_alpha(min(1.0, 0.1 + i * (1.0 / nbLines)))
            line.set_lw(DEFAULT_LINE_WIDTH)
        self.lines[-1].set_lw(2.0*self.lines[-1].get_lw())

    def setThreshold(self, inValue):
        '''
        Set the trigger level. Note that it is overwritten the next time a
        trace fills up while autoDefineTrigger is True.
        '''
        self.triggerLevel = inValue

    def __autoDefineThreshold(self):
        '''
        If autoDefineTrigger is set, place triggerLevel at __relTriggerLevel
        of the way from the minimum to the maximum of the peek buffer.
        '''
        if not self.autoDefineTrigger:
            return
        # NOTE(review): the buffer starts zero-filled, so until it has been
        # filled once, 0 takes part in the min/max below.
        minValue = self.__peekBuffer.min()
        maxValue = self.__peekBuffer.max()
        overallRange = (maxValue-minValue)
        self.triggerLevel = minValue + self.__relTriggerLevel * overallRange
        self.__triggerLevelAutoDefined = True

    def __waitForTrigger(self, inData):
        '''
        Return the part of inData that should start a new trace.

        AUTO: all of inData.
        RISINGTHRESHOLD: inData from its first sample above triggerLevel,
        unless that is inData[0] (no rising crossing is seen inside the
        chunk), in which case [] and the chunk is discarded. While the level
        is still waiting to be auto-defined, inData is returned unchanged.
        Raises LookupError for an unknown trigger type.
        '''
        if self.trigger==triggerTypes.AUTO:
            return inData
        elif self.trigger==triggerTypes.RISINGTHRESHOLD:
            if (not self.autoDefineTrigger) or self.__triggerLevelAutoDefined:
                a, = numpy.where(inData>self.triggerLevel)
                if len(a)>0 and a[0]>0 and inData[a[0]-1]<=self.triggerLevel:
                    return inData[a[0]:]
                else:
                    return []
            else:
                return inData
        else:
            raise LookupError("Invalid Trigger type %s"%str(self.trigger))


# This creates an UpdateCanvasEvent event class, which the worker (reader)
# thread instantiates and posts to the main thread. Only the main thread may
# update the GUI, otherwise crashes can occur.
# NewEvent() returns the event class and the binder passed to Bind().
(UpdateCanvasEvent, EVT_UPDATE_CANVAS) = wx.lib.newevent.NewEvent()

class MainWindow(wx.Frame):
    def __init__(self):
        wx.Frame.__init__(self, None, -1, "Free Period Calculation")

        self.arrayLock = threading.Semaphore(1)
        self.canvasLock = threading.Semaphore(1)
        self.dev = None
        self.subdev = None
        self.fd = None
        self.nbChans = None
        self.helper = None
        self.maxChanValues = []
        self.chanRanges = []
            
        #create the preferences frame
        self.prefwindow = self.createPreferencesWindow()
        #get default configuration
        self.configuration = self.prefwindow.getConfiguration()
        self.nbChans = 2
        
        self.figure = Figure((5, 4), 75)
        self.canvas = FigureCanvasWxAgg(self, -1, self.figure)
        self.trigPlot1 = TriggerPlot(self.figure, (0.1, 0.5+(0.5-0.5*0.8), 0.8, 0.5*0.8))
        self.trigPlot2 = TriggerPlot(self.figure, (0.1, 0.0+(0.5-0.5*0.8), 0.8, 0.5*0.8))
        self.figure.add_axes(self.trigPlot1)
        self.figure.add_axes(self.trigPlot2)
        self.toolbar = Toolbar(self.canvas)
        self.toolbar.Realize()

        self.sizerHoriz = wx.BoxSizer(wx.HORIZONTAL)

        self.startButton = wx.Button(self, ID_STARTBUTTON, "Start")
        wx.EVT_BUTTON(self, ID_STARTBUTTON, self.pressStartButton);

        self.stopButton = wx.Button(self, ID_STOPBUTTON, "Stop")
        self.stopButton.Enable(0)
        wx.EVT_BUTTON(self, ID_STOPBUTTON, self.onStopButton);

        self.clearButton = wx.Button(self, ID_CLEARBUTTON, "Clear")
        wx.EVT_BUTTON(self, ID_CLEARBUTTON, self.onClearButton);

        self.prefButton = wx.Button(self, ID_CONFIGUREBUTTON, "Configure")
        wx.EVT_BUTTON(self, ID_CONFIGUREBUTTON, self.onConfigureButton)

        self.sizerHoriz.Add(self.startButton, 0, wx.ALL, 5)
        self.sizerHoriz.Add(self.stopButton, 0, wx.ALL, 5)
        self.sizerHoriz.Add(self.clearButton, 0, wx.ALL, 5)
        self.sizerHoriz.Add(self.prefButton, 0, wx.ALL, 5)

        # Use some sizers to see layout options
        self.sizerVert = wx.BoxSizer(wx.VERTICAL)
        self.sizerVert.Add(self.sizerHoriz, 0, wx.EXPAND)
        self.sizerVert.Add(self.canvas, 1, wx.EXPAND)
        self.sizerVert.Add(self.toolbar, 0, wx.GROW)

        #Bind the update canvas event to some function:
        self.Bind(EVT_UPDATE_CANVAS, self.onUpdateCanvas)
        
        self.Bind(wx.EVT_CLOSE, self.onCloseWindow)

        #Layout sizers
        self.SetSizer(self.sizerVert)
        self.SetAutoLayout(1)
        self.sizerVert.Fit(self)
        self.initialize()
        self.Show(1)

    def GetToolBar(self):
        return self.toolbar

    def pressStartButton(self, event):
        if DEBUG: print "Start button has been pressed"
        self.startButton.Enable(0)
        self.prefButton.Enable(0)
        self.configuration = self.prefwindow.getConfiguration()
        
        #configure the comedi device
        self.dev = comedi.comedi_open(str(self.configuration['device']))
        if not self.dev:
            print "FATAL ERROR: cannot open comedi device %s: %s" % \
                (self.configuration['device'], comedi.comedi_strerror(comedi.comedi_errno()))
            sys.exit(1)
        #and get an appropriate subdevice
        self.subdev = comedi.comedi_find_subdevice_by_type(self.dev, comedi.COMEDI_SUBD_AI, 0)
        if self.subdev < 0:
            print "FATAL ERROR: cannot find a suitable analog input subdevice: %s" % \
                comedi.comedi_strerror(comedi.comedi_errno())
            sys.exit(1)
        #get a file-descriptor for reading
        #get a file-descriptor for reading
        self.fd = comedi.comedi_fileno(self.dev)
        if self.fd <= 0: raise Exception("Error obtaining Comedi device file descriptor: %s" % (comedi.comedi_strerror(comedi.comedi_errno())))
        
        #create channel list
        myChanList = comedi.chanlist(self.nbChans)
        for i in range(self.nbChans):
            myChanList[i] = comedi.cr_pack(
                self.configuration['analog'][i][0], 
                self.configuration['analog'][i][1], 
                self.configuration['analog'][i][2])
        #create a command structure
        cmd = comedi.comedi_cmd_struct()
        ret = comedi.comedi_get_cmd_generic_timed(self.dev, 
                                                  self.subdev, 
                                                  cmd, 
                                                  self.nbChans, 
                                                  self.configuration['analog_period_ns'])
        if ret: raise Exception("Error comedi_get_cmd_generic failed")
        cmd.chanlist = myChanList # adjust for our particular context
        cmd.chanlist_len = self.nbChans
        cmd.scan_end_arg = self.nbChans
        cmd.stop_src = comedi.TRIG_NONE #never stop
        
        #test our comedi command a few times.
        ret = 0
        for i in range(2):
            ret = comedi.comedi_command_test(self.dev, cmd)
            if ret < 0: raise Exception("comedi_command_test failed: %s" % (comedi.comedi_strerror(comedi.comedi_errno())))

        #Start the command
        ret = comedi.comedi_command(self.dev, cmd)
        if ret <> 0: raise Exception("comedi_command failed... %s" % (comedi.comedi_strerror(comedi.comedi_errno())))
        
        for chanInfo in self.configuration['analog']:
            self.maxChanValues.append(comedi.comedi_get_maxdata(self.dev,self.subdev,chanInfo[0]))
            self.chanRanges.append(comedi.comedi_get_range(self.dev,self.subdev,chanInfo[0],chanInfo[1]))
            
        #start reader thread
        if DEBUG: print "creating reader thread"
        self.helper = readerThread(self, self.fd, self.nbChans)        
        self.helper.start()
        
        self.stopButton.Enable(1)
        if DEBUG: print "Start button complete."

    def onStopButton(self, event):
        self.stopButton.Enable(0)
        if self.helper is not None:
            self.helper.reading = False
            self.helper.join()
            if DEBUG: print "Background thread stopped."
            self.helper = None
        comedi.comedi_cancel(self.dev, self.subdev)
        self.startButton.Enable(1)
        self.prefButton.Enable(1)

    def onClearButton(self, event):
        self.arrayLock.acquire()
        self.currPos = 0
        self.trigPlot1.cla()
        self.trigPlot2.cla()
        self.arrayLock.release()
        evt = UpdateCanvasEvent()
        wx.PostEvent(self, evt)

    def onConfigureButton(self, event):
        self.prefwindow.Show(1)
        self.prefwindow.MakeModal(1)

    def initialize(self):
        self.onClearButton(None)
    
    def processNewData(self, inData):
        self.arrayLock.acquire()
        inData = convertInVolts(inData,self.maxChanValues,self.chanRanges)
        self.trigPlot1.processNewData(inData[0,:])
        self.trigPlot2.processNewData(inData[1,:])
        self.arrayLock.release()

    def onUpdateCanvas(self, evt):
        #if DEBUG: print "Updating the canvas..."
        self.canvasLock.acquire()
        self.canvas.draw()
        self.trigPlot1.relim()
        self.trigPlot1.autoscale_view()
        self.trigPlot2.relim()
        self.trigPlot2.autoscale_view()
        self.toolbar.update()
        self.canvasLock.release()
    
    def onCloseWindow(self, evt):
        self.onStopButton(evt)
        self.Destroy()

    def createPreferencesWindow(self):
        prefwindow = PreferencesFrame(self, "preferences")
        prefwindow.CentreOnParent(wx.BOTH)
        return prefwindow

class readerThread(threading.Thread):
    def __init__(self, frame, fd, nbChans, bufferSize=BUFFER_SIZE):
        if DEBUG: print "in readerThread __init__"
        threading.Thread.__init__(self)
        
        self.frame = frame
        self.fd = fd
        self.nbChans = nbChans
        self.bufferSize = bufferSize
        self.reading = False
        
    def run(self):
        if DEBUG: print "Reader thread running..."
        self.reading = True
        data = []
        n = 0
        while (self.reading):
            try:
                #if DEBUG: print "reading..."
                line = os.read(self.fd,self.bufferSize)
                n = len(line)/2 # 2 bytes per 'H'
                #if DEBUG: print "read %s(...) (%d bytes). data is (%d items) %s"%(line[:5],len(line),len(data),str(data))
                unpack = struct.unpack('%dH'%n,line)
                #if DEBUG: print "unpacking... got %d values"%len(unpack)
                data.extend(unpack)
                #if DEBUG: print "appending to data. data is now (%d items) %s"%(len(data),str(data))
                #if DEBUG: print "de-interleaving..."
                out,data = deinterleaveData(data,self.nbChans)
                #if DEBUG: print "returned a (%d,%d) array and kept %d for next round"%(out.shape[0],out.shape[1],len(data))
                self.frame.processNewData(out)
                evt = UpdateCanvasEvent()
                wx.PostEvent(self.frame, evt)
                time.sleep(0.1)
            except IOError:
                raise Exception("Fatal Error: %s" % (comedi.comedi_strerror(comedi.comedi_errno())))

class PreferencesFrame(wx.Frame):
    def __init__(self, parent, title):
        wx.Frame.__init__(self, parent, -1, title)

        #p = wx.Panel(self, -1)
        vertSizer = wx.BoxSizer(wx.VERTICAL)
        titleLabel = wx.StaticText(self, -1, "Comedi Configuration")
        okButton = wx.Button(self, ID_OKBUTTON, "Save Changes")
        wx.EVT_BUTTON(self, ID_OKBUTTON, self.PressOkButton)

        sizer = wx.GridSizer(6, 2, 2, 2) #rows, columns, hgap, vgap

        vertSizer.Add(titleLabel, 0, wx.ALIGN_CENTER | wx.ALL, 8)
        vertSizer.Add(sizer)
        vertSizer.Add(okButton, 0, wx.ALIGN_RIGHT | wx.ALL, 15)

        deviceNameLabel = wx.StaticText(self, -1, "Device: ")
        self.deviceNameTextBox = wx.TextCtrl(self, -1, "/dev/comedi0")
        sizer.Add(deviceNameLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.deviceNameTextBox)

        xInputLabel = wx.StaticText(self, -1, "X Input Channel: ")
        self.xInputBox = wx.Choice(self, -1, choices = ['0', '1', '2', '3', '4', '5', '6', '7'])
        self.xInputBox.SetSelection(0)
        sizer.Add(xInputLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.xInputBox)

        xInputRangeLabel = wx.StaticText(self, -1, "X Input Range: ")
        self.xInputRangeBox = wx.Choice(self, -1, choices = ['0', '1', '2', '3'])
        self.xInputRangeBox.SetSelection(0)
        sizer.Add(xInputRangeLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.xInputRangeBox)

        xInputTypeLabel = wx.StaticText(self, -1, "X Input Type: ")
        self.xInputTypeBox = wx.Choice(self, -1, choices = ['AREF_GROUND', 'AREF_COMMON', 'AREF_DIFF'])
        self.xInputTypeBox.SetSelection(0)
        sizer.Add(xInputTypeLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.xInputTypeBox)

        yInputLabel = wx.StaticText(self, -1, "Y Input Channel: ")
        self.yInputBox = wx.Choice(self, -1, choices = ['0', '1', '2', '3', '4', '5', '6', '7'])
        self.yInputBox.SetSelection(1)
        sizer.Add(yInputLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.yInputBox)

        yInputRangeLabel = wx.StaticText(self, -1, "Y Input Range: ")
        self.yInputRangeBox = wx.Choice(self, -1, choices = ['0', '1', '2', '3'])
        self.yInputRangeBox.SetSelection(0)
        sizer.Add(yInputRangeLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.yInputRangeBox)

        yInputTypeLabel = wx.StaticText(self, -1, "Y Input Type: ")
        self.yInputTypeBox = wx.Choice(self, -1, choices = ['AREF_GROUND', 'AREF_COMMON', 'AREF_DIFF'])
        self.yInputTypeBox.SetSelection(0)
        sizer.Add(yInputTypeLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.yInputTypeBox)

        inputPeriodLabel = wx.StaticText(self, -1, "Samples per second: ")
        self.inputPeriodBox = wx.TextCtrl(self, -1, "10000")
        sizer.Add(inputPeriodLabel, 0, wx.ALIGN_RIGHT)
        sizer.Add(self.inputPeriodBox)

        self.SetSizer(vertSizer)
        self.Bind(wx.EVT_CLOSE, self.OnCloseWindow)
        self.Fit()

    def PressOkButton(self, event):
        self.Close(1)

    def OnCloseWindow(self, event):
        self.MakeModal(False)
        self.Show(0)
        print self.getConfiguration()
        #self.Destroy()

    def getConfiguration(self):
        config = dict()
        config["device"] = self.deviceNameTextBox.GetValue()

        config["analog_period_ns"] = int(1e9 / float(self.inputPeriodBox.GetValue()))

        x_channel = int(self.xInputBox.GetStringSelection())
        x_range = int(self.xInputRangeBox.GetStringSelection())
        x_aref_string = self.xInputTypeBox.GetStringSelection()
        x_aref = 0
        if x_aref_string == 'AREF_GROUND':
            x_aref = comedi.AREF_GROUND
        elif x_aref_string == 'AREF_COMMON':
            x_aref = comedi.AREF_COMMON
        elif x_aref_string == 'AREF_DIFF':
            x_aref = comedi.AREF_DIFF
        else:
            x_aref = comedi.AREF_OTHER

        y_channel = int(self.yInputBox.GetStringSelection())
        y_range = int(self.yInputRangeBox.GetStringSelection())
        y_aref_string = self.yInputTypeBox.GetStringSelection()
        y_aref = 0
        if y_aref_string == 'AREF_GROUND':
            y_aref = comedi.AREF_GROUND
        elif y_aref_string == 'AREF_COMMON':
            y_aref = comedi.AREF_COMMON
        elif y_aref_string == 'AREF_DIFF':
            y_aref = comedi.AREF_DIFF
        else:
            y_aref = comedi.AREF_OTHER

        config["analog"] = [[x_channel, x_range, x_aref], [y_channel, y_range, y_aref]]
        return config

# Launch the GUI only when run as a script, so importing this module has no
# side effects beyond definitions.
if __name__ == "__main__":
    app = wx.PySimpleApp()
    frame = MainWindow()
    app.MainLoop()
