'''
TriggerPlot is an extension of a matplotlib Axes that displays 'paged' data
(successive traces of fixed duration), optionally triggered by various
conditions, and with remanence (previous traces stay on screen and fade out).

Created on Jun 11, 2012
@author: manuel
'''

import numpy
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from Utils import enum

# Set to True to print tracing messages from TriggerPlot.
DEBUG = False

# Trigger modes understood by TriggerPlot.triggerMode.
triggerTypes = enum(
    AUTO=0,
    RISINGTHRESHOLD=1,
    FALLINGTHRESHOLD=2,
)


class TriggerPlot(Axes):
    """
    A matplotlib Axes that displays incoming samples as successive 'pages'
    (traces) of timeWidth seconds, optionally synchronised on a trigger
    condition, with the previous traces kept on screen and fading out.

    Feed data with processNewData(); configure the display through the
    properties (remanence, triggerMode, autoDefineTrigger, triggerLevel,
    linecolor, linewidth, units, scaling, offset, sampleFreq, timeWidth).

    linewidth, linecolor and sampleFreq set the initial trace style and
    sampling frequency; the other constructor arguments are forwarded to
    Axes.__init__().
    """

    def __init__(self, fig, rect,
                 axisbg=None,  # defaults to rc axes.facecolor
                 frameon=True,
                 sharex=None,  # use Axes instance's xaxis info
                 sharey=None,  # use Axes instance's yaxis info
                 label='',
                 xscale=None,
                 yscale=None,
                 linewidth=2.0,
                 linecolor='b',
                 sampleFreq=1000.):

        self._remanence = 5
        self._triggerMode = triggerTypes.RISINGTHRESHOLD
        self._autoDefineTrigger = True
        self._triggerLevel = 0.0
        self._linecolor = linecolor
        self._linewidth = linewidth
        self._units = 'V'
        self._scaling = 1.0  # converts the value in Volts into a value in 'units'.
        self._offset = 0.0
        self._sampleFreq = sampleFreq
        self._timeWidth = 1.0  # in seconds
        self.__updateBufferSizes()  # sets self._bufferSize and self._peekBufferSize

        self._peekBuffer = None  # init in self.reset()
        self.__relTriggerLevel = 0.75
        self.__triggerLabel = None  # init below, once the Axes exists
        self.lines = []  # init in self.reset()
        self.__bottomY = 0.0
        self.__topY = 1.0

        # initialize the underlying Axes first, then our own artists and buffers
        Axes.__init__(self, fig, rect, axisbg, frameon, sharex, sharey, label, xscale, yscale)
        self.__triggerLabel = self.text(0, 0, 'T', color='w', backgroundcolor='k', va='center', ha='right',
                                        size='x-small')
        self.__initialize()

    def getRemanence(self):
        """
        the number of 'paged' displays to keep on screen.
        Older data fade progressively to background
        """
        return self._remanence

    def getTriggerMode(self):
        """
        define how the trace is triggered:
            AUTO
                a new trace is started without delay
            RISINGTHRESHOLD
                a new trace starts where the signal rises
                above the trigger level
            FALLINGTHRESHOLD
                a new trace starts where the signal falls
                below the trigger level
        """
        return self._triggerMode

    def isAutoDefineTrigger(self):
        """
        define whether the trigger level is defined
        automatically (75% of the peak to peak amplitude
        of the buffered data) or by the user
        """
        return self._autoDefineTrigger

    def getTriggerLevel(self):
        """value used for the trigger if it is not set automatically"""
        return self._triggerLevel

    def getLinecolor(self):
        """
        define the color of the most recent trace.
        If remanence is enabled, the other traces progressively fade to background
        """
        return self._linecolor

    def getLinewidth(self):
        """
        define the width of the most recent trace.
        If remanence is enabled, the other traces are half that size
        """
        return self._linewidth

    def getUnits(self):
        """the units displayed on the Y axis"""
        return self._units

    def getScaling(self):
        """
        the scaling factor in UNITS/Volts used to convert the
        value in Volts into a value in UNITS
        """
        return self._scaling

    def getOffset(self):
        """a value in UNITS giving the value of the signal when it is at 0V"""
        return self._offset

    def getSampleFreq(self):
        """the frequency (in Hz) at which the signal is sampled"""
        return self._sampleFreq

    def getTimeWidth(self):
        """the width (in seconds) of the screen"""
        return self._timeWidth

    def setRemanence(self, value):
        if DEBUG: print('in TriggerPlot.setRemanence(%s) [%s]' % (value, self))
        if value <= 0:
            value = 1
        self._remanence = value
        self.__initialize()  # also resizes the peek buffer

    def setTriggerMode(self, value):
        if DEBUG: print('in TriggerPlot.setTriggerMode(%s) [%s]' % (value, self))
        self._triggerMode = value
        self.__initialize()

    def setAutoDefineTrigger(self, value):
        if DEBUG: print('in TriggerPlot.setAutoDefineTrigger(%s) [%s]' % (value, self))
        self._autoDefineTrigger = value

    def setTriggerLevel(self, value):
        """
        Set the trigger level (in UNITS) and move the 'T' marker to it.
        When the level lies outside the current y limits, the marker is
        pinned to the nearest limit and drawn on a grey background.
        """
        if DEBUG: print('in TriggerPlot.setTriggerLevel(%s) [%s]' % (value, self))
        self._triggerLevel = value
        xPos = self.get_xlim()[0]
        yMin, yMax = self.get_ylim()
        yPos = value
        background = 'k'
        if yPos > yMax:
            yPos = yMax
            background = '0.5'
        if yPos < yMin:
            yPos = yMin
            background = '0.5'
        self.__triggerLabel.set_position((xPos, yPos))
        self.__triggerLabel.set_backgroundcolor(background)

    def setLinecolor(self, value):
        if DEBUG: print('in TriggerPlot.setLineColor(%s) [%s]' % (value, self))
        self._linecolor = value
        self.__initialize()

    def setLinewidth(self, value):
        if DEBUG: print('in TriggerPlot.setLineWidth(%s) [%s]' % (value, self))
        self._linewidth = value
        self.__initialize()

    def setUnits(self, value):
        if DEBUG: print('in TriggerPlot.setUnits(%s) [%s]' % (value, self))
        self._units = value

    def setScaling(self, value):
        if DEBUG: print('in TriggerPlot.setScaling(%s) [%s]' % (value, self))
        self._scaling = value

    def setOffset(self, value):
        if DEBUG: print('in TriggerPlot.setOffset(%s) [%s]' % (value, self))
        self._offset = value

    def setSampleFreq(self, value):
        if DEBUG: print('in TriggerPlot.setSampleFreq(%s) [%s]' % (value, self))
        self._sampleFreq = value
        self.__initialize()  # recomputes the buffer sizes and clears the display

    def setTimeWidth(self, value):
        self._timeWidth = value
        self.__initialize()  # recomputes the buffer sizes and clears the display

    def reset(self):
        """Clear all traces (keeping one empty current trace) and the peek buffer."""
        # makes sure one line is always present
        self.lines = []
        self.lines = self.plot([], '-', color=self._linecolor, linewidth=self._linewidth)
        # reset peek buffer; NaN marks 'no sample received yet'
        self._peekBuffer = numpy.empty(self._peekBufferSize)
        self._peekBuffer.fill(numpy.nan)

    def __updateBufferSizes(self):
        """
        Recompute, as integer sample counts, the length of one trace
        (sampleFreq * timeWidth, at least 1) and of the peek buffer, which
        spans every trace that can be on screen (remanence old ones + current).
        """
        self._bufferSize = max(int(round(self._sampleFreq * self._timeWidth)), 1)
        self._peekBufferSize = self._bufferSize * (int(self._remanence) + 1)

    def __initialize(self):
        """
        Re-apply all settings: recompute the buffer sizes, reset the axes
        limits, clear the traces and show/orient the trigger marker for the
        current trigger mode. Raises LookupError for an unknown trigger mode.
        """
        self.__updateBufferSizes()
        self.set_xlim(0, self._timeWidth)
        if not self.get_autoscaley_on():
            # autoscale is off: restore the last y limits applied through set_ylim()
            self.set_ylim(self.__bottomY, self.__topY)
        self.reset()
        if self._triggerMode == triggerTypes.AUTO:
            self.__triggerLabel.set_visible(False)
        elif self._triggerMode == triggerTypes.RISINGTHRESHOLD:
            self.__triggerLabel.set_visible(True)
            self.__triggerLabel.set_rotation(0)  # undo a previous falling-edge rotation
        elif self._triggerMode == triggerTypes.FALLINGTHRESHOLD:
            self.__triggerLabel.set_visible(True)
            self.__triggerLabel.set_rotation(180)
        else:
            raise LookupError("Invalid Trigger type %s" % str(self._triggerMode))

    def __addDataToPeekBuffer(self, inData):
        '''
        Push inData into the peek buffer, a sliding window holding the most
        recent samples received (NaN where nothing has been received yet).
        It is used to define the trigger level automatically.
        '''
        nbPoints = len(inData)
        peekSize = len(self._peekBuffer)
        if nbPoints == 0:
            return  # nothing to add (and a [-0:] slice would address the whole buffer)
        if nbPoints >= peekSize:
            # the chunk alone fills the window: keep only its most recent samples
            self._peekBuffer[:] = inData[nbPoints - peekSize:]
        else:
            self._peekBuffer[:-nbPoints] = self._peekBuffer[nbPoints:]  # shift content nbPoints to the left
            self._peekBuffer[-nbPoints:] = inData

    def __rescaleData(self, inData):
        """Convert values in Volts into values in UNITS."""
        return inData * self._scaling + self._offset

    def __sampleTimes(self, nbPoints):
        """Time (in seconds, from the trace start) of each of nbPoints consecutive samples."""
        return numpy.arange(nbPoints) / float(self._sampleFreq)

    def processNewData(self, inData):
        """
        Add a chunk of new samples (in Volts) to the display.

        The samples are rescaled to UNITS, pushed into the peek buffer and
        appended to the current trace. Once the current trace holds
        sampleFreq * timeWidth samples, the remaining samples are searched
        for a trigger event (see triggerMode). If one is found, a new trace
        (at most one screen long) is started from it; the oldest trace is
        dropped first when remanence + 1 traces are already shown. Samples
        before the trigger, or beyond one screen, are not displayed.
        """
        if DEBUG: print("#### in TriggerPlot.processNewData() ####")
        inData = self.__rescaleData(numpy.asarray(inData, dtype=float))
        self.__addDataToPeekBuffer(inData)

        currentLine = self.lines[-1]
        yData = currentLine.get_ydata()
        currPos = len(yData)
        # number of incoming samples that still fit on the current trace (never negative)
        nbFit = min(len(inData), max(self._bufferSize - currPos, 0))
        yData = numpy.append(yData, inData[:nbFit])
        currentLine.set_data(self.__sampleTimes(len(yData)), yData)
        currPos += nbFit
        if currPos < self._bufferSize:
            return

        # the current trace is full: look for a trigger in the leftover samples
        if self._autoDefineTrigger:
            self.__autoDefineThreshold()
        # a trace never holds more than one screen of samples
        ret = self.__waitForTrigger(inData[nbFit:])[:self._bufferSize]
        if len(ret) > 0:
            if len(self.lines) >= self._remanence + 1:
                self.lines.pop(0)  # drop the oldest trace
            self.add_line(Line2D(self.__sampleTimes(len(ret)), ret, color=self._linecolor))
        self.__updateLineColor()
        if self.get_autoscaley_on():
            self.relim()
            self.autoscale_view(tight=None, scalex=False, scaley=True)

    def __updateLineColor(self):
        """
        Fade the traces: of n traces, the i-th oldest (0-based) gets alpha
        (i + 1) / n and half the line width; the most recent one is opaque
        and drawn at the full line width.
        """
        nbLines = len(self.lines)
        for i, line in enumerate(self.lines):
            line.set_color(self._linecolor)
            line.set_alpha((i + 1) * (1.0 / nbLines))
            line.set_lw(self._linewidth / 2.0)
        self.lines[-1].set_lw(self._linewidth)

    def __autoDefineThreshold(self):
        """
        Set the trigger level at 75% of the peak-to-peak range of the peek
        buffer (measured from the minimum for a rising trigger, from the
        maximum for a falling one). Samples not received yet (NaN) are
        ignored; nothing changes while no sample has been received.
        Raises LookupError for an unknown trigger mode.
        """
        received = self._peekBuffer[numpy.isfinite(self._peekBuffer)]
        if len(received) == 0:
            return
        minValue = received.min()
        maxValue = received.max()
        overallRange = maxValue - minValue
        if self._triggerMode == triggerTypes.AUTO:
            pass
        elif self._triggerMode == triggerTypes.RISINGTHRESHOLD:
            self.setTriggerLevel(minValue + self.__relTriggerLevel * overallRange)
        elif self._triggerMode == triggerTypes.FALLINGTHRESHOLD:
            self.setTriggerLevel(maxValue - self.__relTriggerLevel * overallRange)
        else:
            raise LookupError("Invalid Trigger type %s" % str(self._triggerMode))

    def __waitForTrigger(self, inData):
        """
        Return the part of inData from the first trigger event onward, or []
        when no trigger event occurs in inData.

        AUTO triggers immediately (inData is returned unchanged).
        RISINGTHRESHOLD triggers on the first sample above the trigger level
        whose predecessor is at or below it; FALLINGTHRESHOLD on the first
        sample below the level whose predecessor is at or above it.
        Raises LookupError for an unknown trigger mode.
        """
        if self._triggerMode == triggerTypes.AUTO:
            return inData
        inData = numpy.asarray(inData)
        level = self._triggerLevel
        before, after = inData[:-1], inData[1:]
        if self._triggerMode == triggerTypes.RISINGTHRESHOLD:
            crossings = numpy.flatnonzero((before <= level) & (after > level))
        elif self._triggerMode == triggerTypes.FALLINGTHRESHOLD:
            crossings = numpy.flatnonzero((before >= level) & (after < level))
        else:
            raise LookupError("Invalid Trigger type %s" % str(self._triggerMode))
        if len(crossings) == 0:
            return []
        # crossing indices refer to 'after', which is inData shifted by one sample
        return inData[crossings[0] + 1:]

    def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kw):
        """
        Axes.set_ylim() that also remembers the resulting limits, so they can
        be restored when the display is re-initialised with y autoscaling off.
        Returns whatever Axes.set_ylim() returns.
        """
        if DEBUG: print('in TriggerPlot.set_ylim(%s,%s) [%s]' % (bottom, top, self))
        result = Axes.set_ylim(self, bottom, top, emit, auto, **kw)
        # read the limits back so that tuple arguments, keyword forms and None
        # ('keep current') are all recorded as the actual numeric limits
        self.__bottomY, self.__topY = self.get_ylim()
        return result

    remanence = property(getRemanence, setRemanence, None)
    triggerMode = property(getTriggerMode, setTriggerMode, None)
    autoDefineTrigger = property(isAutoDefineTrigger, setAutoDefineTrigger, None)
    triggerLevel = property(getTriggerLevel, setTriggerLevel, None)
    linecolor = property(getLinecolor, setLinecolor, None)
    linewidth = property(getLinewidth, setLinewidth, None)
    units = property(getUnits, setUnits, None)
    scaling = property(getScaling, setScaling, None)
    offset = property(getOffset, setOffset, None)
    sampleFreq = property(getSampleFreq, setSampleFreq, None)
    timeWidth = property(getTimeWidth, setTimeWidth, None)
