
import numpy as np

from ...graphicsItems.LinearRegionItem import LinearRegionItem
from ...Qt import QtCore, QtWidgets
from ...widgets.TreeWidget import TreeWidget
from ..Node import Node
from . import functions
from .common import CtrlNode


class ColumnSelectNode(Node):
    """Select named columns from a record array or MetaArray.

    The control widget lists the columns found in the most recent input;
    checking a column creates an output terminal of the same name that
    carries that column's data.
    """
    nodeName = "ColumnSelect"

    def __init__(self, name):
        Node.__init__(self, name, terminals={'In': {'io': 'in'}})
        self.columns = set()  # names of the checked columns (one output terminal each)
        self.columnList = QtWidgets.QListWidget()
        self.axis = 0  # MetaArray axis that columns are taken from
        self.columnList.itemChanged.connect(self.itemChanged)

    @staticmethod
    def _isMetaArray(data):
        """Return True if *data* reports that it implements the MetaArray interface."""
        return hasattr(data, 'implements') and data.implements('MetaArray')

    @staticmethod
    def _isRecordArray(data):
        """Return True if *data* is an ndarray with named fields."""
        return isinstance(data, np.ndarray) and data.dtype.fields is not None

    def process(self, In, display=True):
        """Return ``{column_name: column_data}`` for every selected column.

        Raises:
            Exception: if *In* is neither a MetaArray nor an ndarray with
                named fields.
        """
        isMeta = self._isMetaArray(In)
        if not isMeta and not self._isRecordArray(In):
            # Validate before updateList() so unsupported input (e.g. None or a
            # plain ndarray) produces this message instead of an AttributeError.
            self['In'].setValueAcceptable(False)
            raise Exception("Input must be MetaArray or ndarray with named fields")

        if display:
            # Also refreshes self.axis and drops selected columns missing from In.
            self.updateList(In)

        if isMeta:
            return {c: In[self.axis:c] for c in self.columns}
        return {c: In[c] for c in self.columns}

    def ctrlWidget(self):
        return self.columnList

    def updateList(self, data):
        """Rebuild the column list widget from the columns available in *data*.

        Selected columns that *data* no longer provides are deselected and
        their output terminals removed.
        """
        cols = []
        if self._isMetaArray(data):
            # Use the first axis that has named columns, preserving their order.
            for ax, names in data.listColumns().items():
                if len(names) > 0:
                    self.axis = ax
                    cols = list(names)
                    break
        else:
            cols = list(data.dtype.fields.keys())

        rem = {c for c in self.columns if c not in cols}
        for c in rem:
            self.removeTerminal(c)
        self.columns -= rem

        # Block signals so repopulating the list does not fire itemChanged.
        self.columnList.blockSignals(True)
        self.columnList.clear()
        for c in cols:
            item = QtWidgets.QListWidgetItem(c)
            item.setFlags(QtCore.Qt.ItemFlag.ItemIsEnabled|QtCore.Qt.ItemFlag.ItemIsUserCheckable)
            if c in self.columns:
                item.setCheckState(QtCore.Qt.CheckState.Checked)
            else:
                item.setCheckState(QtCore.Qt.CheckState.Unchecked)
            self.columnList.addItem(item)
        self.columnList.blockSignals(False)

    def itemChanged(self, item):
        """Add or remove the output terminal for a column when it is (un)checked."""
        col = str(item.text())
        if item.checkState() == QtCore.Qt.CheckState.Checked:
            if col not in self.columns:
                self.columns.add(col)
                self.addOutput(col)
        else:
            if col in self.columns:
                self.columns.remove(col)
                self.removeTerminal(col)
        self.update()

    def saveState(self):
        state = Node.saveState(self)
        state['columns'] = list(self.columns)
        return state

    def restoreState(self, state):
        """Restore the selected columns and recreate their output terminals."""
        Node.restoreState(self, state)
        self.columns = set(state.get('columns', []))
        for c in self.columns:
            self.addOutput(c)



class RegionSelectNode(CtrlNode):
    """Returns a slice from a 1-D array. Connect the 'widget' output to a plot to display a region-selection widget."""
    nodeName = "RegionSelect"
    uiTemplate = [
        ('start', 'spin', {'value': 0, 'step': 0.1}),
        ('stop', 'spin', {'value': 0.1, 'step': 0.1}),
        ('display', 'check', {'value': True}),
        ('movable', 'check', {'value': True}),
    ]

    def __init__(self, name):
        # One LinearRegionItem per terminal connected to 'widget'. Created
        # before CtrlNode.__init__ (original order kept; presumably so handlers
        # that read it never see it missing).
        self.items = {}
        CtrlNode.__init__(self, name, terminals={
            'data': {'io': 'in'},
            'selected': {'io': 'out'},
            'region': {'io': 'out'},
            'widget': {'io': 'out', 'multi': True}
        })
        self.ctrls['display'].toggled.connect(self.displayToggled)
        self.ctrls['movable'].toggled.connect(self.movableToggled)

    def displayToggled(self, b):
        """Show or hide every region item."""
        for item in self.items.values():
            item.setVisible(b)

    def movableToggled(self, b):
        """Allow or forbid dragging of every region item."""
        for item in self.items.values():
            item.setMovable(b)

    def process(self, data=None, display=True):
        """Update the region widgets and return the selection.

        Returns a dict with:
            'selected': the part of *data* between start and stop (None when
                that output is unconnected or *data* is None);
            'widget': dict mapping each connected terminal to its region item;
            'region': ``[start, stop]``.
        """
        s = self.stateGroup.state()
        region = [s['start'], s['stop']]

        if display:
            conn = self['widget'].connections()
            # Forget items whose connection was removed so they are neither
            # returned nor affected by the display/movable toggles any more.
            for c in [c for c in self.items if c not in conn]:
                del self.items[c]
            for c in conn:
                # Only nodes exposing getPlot() can host a region item.
                getPlot = getattr(c.node(), 'getPlot', None)
                plot = getPlot() if getPlot is not None else None
                if plot is None:
                    continue
                if c in self.items:
                    self.items[c].setRegion(region)
                else:
                    item = LinearRegionItem(values=region)
                    self.items[c] = item
                    item.sigRegionChanged.connect(self.rgnChanged)
                    item.setVisible(s['display'])
                    item.setMovable(s['movable'])

        if self['selected'].isConnected():
            if data is None:
                sliced = None
            elif (hasattr(data, 'implements') and data.implements('MetaArray')):
                # MetaArray value-based slice along axis 0.
                sliced = data[0:s['start']:s['stop']]
            else:
                # Record-array input: keep rows with start <= time < stop.
                mask = (data['time'] >= s['start']) & (data['time'] < s['stop'])
                sliced = data[mask]
        else:
            sliced = None
        return {'selected': sliced, 'widget': self.items, 'region': region}

    def rgnChanged(self, item):
        """Copy a dragged region back into the start/stop controls and re-run."""
        region = item.getRegion()
        self.stateGroup.setState({'start': region[0], 'stop': region[1]})
        self.update()
        
        
class TextEdit(QtWidgets.QTextEdit):
    """QTextEdit that runs a callback when it loses focus with modified text."""

    def __init__(self, on_update):
        super().__init__()
        # Zero-argument callable invoked after the text changed.
        self.on_update = on_update
        # Text seen at the last focus-out that triggered a callback; None
        # initially, so the first focus-out always fires.
        self.lastText = None

    def focusOutEvent(self, ev):
        current = self.toPlainText()
        modified = current != self.lastText
        if modified:
            self.lastText = current
            self.on_update()
        super().focusOutEvent(ev)


class EvalNode(Node):
    """Return the output of a string evaluated/executed by the python interpreter.
    The string may be either an expression or a python script, and inputs are accessed as the name of the terminal. 
    For expressions, a single value may be evaluated for a single output, or a dict for multiple outputs.
    For a script, the text will be executed as the body of a function."""
    nodeName = 'PythonEval'

    def __init__(self, name):
        Node.__init__(self, name, 
            terminals = {
                'input': {'io': 'in', 'renamable': True, 'multiable': True},
                'output': {'io': 'out', 'renamable': True, 'multiable': True},
            },
            allowAddInput=True, allowAddOutput=True)

        self.ui = QtWidgets.QWidget()
        self.layout = QtWidgets.QGridLayout()
        # Re-run the node whenever edited code loses focus.
        self.text = TextEdit(self.update)
        # QTextEdit.setTabStopWidth() was deprecated in Qt 5.10 and removed in
        # Qt 6; use its replacement setTabStopDistance() when available.
        if hasattr(self.text, 'setTabStopDistance'):
            self.text.setTabStopDistance(30)
        else:
            self.text.setTabStopWidth(30)
        self.text.setPlainText("# Access inputs as args['input_name']\nreturn {'output': None} ## one key per output terminal")
        self.layout.addWidget(self.text, 1, 0, 1, 2)
        self.ui.setLayout(self.layout)

    def ctrlWidget(self):
        return self.ui

    def setCode(self, code):
        """Replace the editor contents with *code*, stripping common indentation.

        Unindenting allows nicer inline (indented) code when calling this
        method. Only non-blank lines count toward the common indent.
        """
        ind = []
        lines = code.split('\n')
        for line in lines:
            stripped = line.lstrip()
            if len(stripped) > 0:
                ind.append(len(line) - len(stripped))
        if len(ind) > 0:
            ind = min(ind)
            code = '\n'.join([line[ind:] for line in lines])

        self.text.clear()
        self.text.insertPlainText(code)

    def code(self):
        return self.text.toPlainText()

    def process(self, display=True, **args):
        """Evaluate the node's code against the current input values.

        The text is first tried as one expression (newlines joined by spaces).
        If that is a SyntaxError, the text becomes the body of
        ``fn(**args)``, which is executed; its return value is the output.

        SECURITY: this runs arbitrary Python. Flowchart files containing this
        node execute their stored code when restored — never load untrusted ones.
        """
        code = self.text.toPlainText()
        # Names visible to an expression: every input by terminal name, plus
        # the full ``args`` dict, ``self`` and ``display``.
        namespace = {'self': self, 'display': display, 'args': args}
        namespace.update(args)
        try:
            try:
                output = eval(code.replace('\n', ' '), globals(), namespace)
            except SyntaxError:
                body = "\n".join("    " + line for line in code.split('\n'))
                script = "def fn(**args):\n" + body + "\noutput=fn(**args)\n"
                exec(script, globals(), namespace)
                output = namespace['output']
        except Exception:
            # Report which node failed for both the eval and the exec path.
            print(f"Error processing node: {self.name()}")
            raise
        return output

    def saveState(self):
        state = Node.saveState(self)
        state['text'] = self.text.toPlainText()
        return state

    def restoreState(self, state):
        Node.restoreState(self, state)
        self.setCode(state['text'])
        self.restoreTerminals(state['terminals'])
        self.update()

        
class ColumnJoinNode(Node):
    """Concatenates record arrays and/or adds new columns"""
    nodeName = 'ColumnJoin'

    def __init__(self, name):
        Node.__init__(self, name, terminals = {
            'output': {'io': 'out'},
        })

        self.ui = QtWidgets.QWidget()
        self.layout = QtWidgets.QGridLayout()
        self.ui.setLayout(self.layout)

        # Top-level tree rows mirror the input terminals; their order is the
        # join order (moving a row re-runs the node).
        self.tree = TreeWidget()
        self.addInBtn = QtWidgets.QPushButton('+ Input')
        self.remInBtn = QtWidgets.QPushButton('- Input')

        self.layout.addWidget(self.tree, 0, 0, 1, 2)
        self.layout.addWidget(self.addInBtn, 1, 0)
        self.layout.addWidget(self.remInBtn, 1, 1)

        self.addInBtn.clicked.connect(self.addInput)
        self.remInBtn.clicked.connect(self.remInput)
        self.tree.sigItemMoved.connect(self.update)

    def ctrlWidget(self):
        return self.ui

    def addInput(self):
        """Create a new input terminal and a matching tree row."""
        term = Node.addInput(self, 'input', renamable=True, removable=True, multiable=True)
        item = QtWidgets.QTreeWidgetItem([term.name()])
        # Cross-link row and terminal so either can find the other.
        item.term = term
        term.joinItem = item
        self.tree.addTopLevelItem(item)

    def remInput(self):
        """Remove the input whose row is selected; does nothing if none is."""
        sel = self.tree.currentItem()
        if sel is None:
            return
        term = sel.term
        term.joinItem = None
        sel.term = None
        self.tree.removeTopLevelItem(sel)
        self.removeTerminal(term)
        self.update()

    def process(self, display=True, **args):
        """Join the inputs in tree order.

        Inputs with named fields are passed through as arrays; any other value
        is passed to functions.concatenateColumns as a ``(name, None, value)``
        tuple named after its terminal.
        """
        order = self.order()
        vals = []
        for name in order:
            if name not in args:
                continue
            val = args[name]
            if isinstance(val, np.ndarray) and len(val.dtype) > 0:
                vals.append(val)
            else:
                vals.append((name, None, val))
        return {'output': functions.concatenateColumns(vals)}

    def order(self):
        """Return input terminal names in the order the tree displays them."""
        return [str(self.tree.topLevelItem(i).text(0)) for i in range(self.tree.topLevelItemCount())]

    def saveState(self):
        state = Node.saveState(self)
        state['order'] = self.order()
        return state

    def restoreState(self, state):
        Node.restoreState(self, state)
        inputs = self.inputs()

        ## Node.restoreState should have created all of the terminals we need
        ## However: to maintain support for some older flowchart files, we need
        ## to manually add any terminals that were not taken care of.
        for name in [n for n in state['order'] if n not in inputs]:
            Node.addInput(self, name, renamable=True, removable=True, multiable=True)
        inputs = self.inputs()

        # Saved order first, then any inputs the saved order did not mention.
        order = [name for name in state['order'] if name in inputs]
        for name in inputs:
            if name not in order:
                order.append(name)

        self.tree.clear()
        for name in order:
            term = self[name]
            item = QtWidgets.QTreeWidgetItem([name])
            item.term = term
            term.joinItem = item
            self.tree.addTopLevelItem(item)

    def terminalRenamed(self, term, oldName):
        """Keep the tree row's label in sync with a renamed terminal."""
        Node.terminalRenamed(self, term, oldName)
        item = term.joinItem
        item.setText(0, term.name())
        self.update()
        
        
class Mean(CtrlNode):
    """Calculate the mean of an array across an axis.

    An axis setting of -1 passes ``axis=None`` (reduce over all elements).
    """
    nodeName = 'Mean'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        axis = self.stateGroup.state()['axis']
        if axis == -1:
            axis = None  # -1 is the spin box's sentinel for "no axis"
        return data.mean(axis=axis)


class Max(CtrlNode):
    """Calculate the maximum of an array across an axis.

    An axis setting of -1 passes ``axis=None`` (reduce over all elements).
    """
    nodeName = 'Max'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        requested = self.stateGroup.state()['axis']
        reduceAxis = None if requested == -1 else requested
        return data.max(axis=reduceAxis)


class Min(CtrlNode):
    """Calculate the minimum of an array across an axis.

    An axis setting of -1 passes ``axis=None`` (reduce over all elements).
    """
    nodeName = 'Min'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        chosen = self.stateGroup.state()['axis']
        if chosen != -1:
            return data.min(axis=chosen)
        return data.min(axis=None)


class Stdev(CtrlNode):
    """Calculate the standard deviation of an array across an axis.

    An axis setting of -1 passes ``axis=None`` (reduce over all elements).
    """
    nodeName = 'Stdev'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': -1, 'max': 1000000}),
    ]

    def processData(self, data):
        axisSetting = self.stateGroup.state()['axis']
        return data.std(axis=(None if axisSetting == -1 else axisSetting))


class Index(CtrlNode):
    """Select an index from an array axis.
    """
    nodeName = 'Index'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': 0, 'max': 1000000}),
        ('index', 'intSpin', {'value': 0, 'min': 0, 'max': 1000000}),
    ]

    def processData(self, data):
        state = self.stateGroup.state()
        axis, index = state['axis'], state['index']
        if axis != 0:
            return data.take(index, axis=axis)
        # Plain subscripting on axis 0 also supports non-ndarray sequences.
        return data[index]
        

class Slice(CtrlNode):
    """Select a slice from an array axis.

    Uses the usual Python ``start:stop:step`` semantics along the chosen axis.
    """
    nodeName = 'Slice'
    uiTemplate = [
        ('axis', 'intSpin', {'value': 0, 'min': 0, 'max': 1e6}),
        ('start', 'intSpin', {'value': 0, 'min': -1e6, 'max': 1e6}),
        ('stop', 'intSpin', {'value': -1, 'min': -1e6, 'max': 1e6}),
        ('step', 'intSpin', {'value': 1, 'min': -1e6, 'max': 1e6}),
    ]

    def processData(self, data):
        s = self.stateGroup.state()
        ax = s['axis']
        sl = slice(s['start'], s['stop'], s['step'])
        if ax == 0:
            # allow support for non-ndarray sequence types
            return data[sl]
        # NumPy requires a *tuple* of slices for multi-axis indexing; indexing
        # with a list of slices was deprecated and now raises an error.
        index = [slice(None)] * data.ndim
        index[ax] = sl
        return data[tuple(index)]
        

class AsType(CtrlNode):
    """Convert an array to a different dtype.

    'float128' is offered only when this NumPy build defines it. It is absent
    on platforms where long double is not 128-bit, e.g. Windows.
    """
    nodeName = 'AsType'
    uiTemplate = [
        ('dtype', 'combo', {'values': [
            t for t in ['float', 'int', 'float32', 'float64', 'float128', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
            if t != 'float128' or hasattr(np, 'float128')
        ], 'index': 0}),
    ]

    def processData(self, data):
        s = self.stateGroup.state()
        return data.astype(s['dtype'])
