import numpy as np
import matplotlib.pyplot as plt

from classes.TW_Utility import TW_Utility
from classes.TW_Segment_Regressor import TW_Segment_Regressor
from classes.TW_Segment import TW_Segment, TW_Point

import os
import uuid
from pickle import dump, load
#from cloudpickle import dump, load

class TW_MultipleSegment_Regressor:

    def __init__(self, name = None):
        self.currentRegressor = None

        self.name = name
        if(name is None):
            self.name = str(uuid.uuid4())

        self.XRange = []
        self.XLabels = []
        self.YRange = []
        self.Ylabel:str = None
        self.Year:int = None
        self.splittedRange = False
        return

    def processDataset(self, X, Y):
        minSegmentError = None
        bestRegressor = None

        for threshold in range(5, 25, 5):
            regressor = TW_Segment_Regressor(X, Y, threshold)
            regressor.process()
            if minSegmentError is None or minSegmentError > regressor.totalError():
                minSegmentError = regressor.totalError()
                bestRegressor = regressor

        self.currentRegressor = bestRegressor
        self.XRange = X
        self.YRange = Y
        return bestRegressor
    
    def procesSplitedDataset(self, X, Y):
        regressors = {}
        errors = {}

        for i in range(0, len(X)):
            for threshold in range(5, 25, 5):
                x = np.array(X[i])
                y = np.array(Y[i])
                regressor = TW_Segment_Regressor(x, y, threshold)
                regressor.process()
                
                regressors[str(i) + '_' + str(threshold)] = regressor
                errors[str(i) + '_' + str(threshold)] = regressor.totalError()

        rangeErrors = {}
        for threshold in range(5, 25, 5):
            rangeError = [value for key, value in errors.items() if ('_' + str(threshold)) in key]
            rangeErrors[threshold] = np.sum(rangeError)

        if len(errors) > 0:
            errors = dict(sorted(errors.items(), key=lambda item:item[1]))
            firstElementKey = next(enumerate(errors.keys()))[1]

            self.currentRegressor =  regressors[firstElementKey]
            self.XRange = X
            self.YRange = Y
            self.splittedRange = True
            return regressors[firstElementKey]
        else:
            return None
        
    #TODO vérifier qu'il n'est pas de tableau de tableau dans les foreacastedValues
    def predictUsingHistoricalData(self, previousSegmentRegressor:object, xPredictFrom, xPredictTo = None, diffThreshold = 15):
        if TW_Utility.empty_or_none(previousSegmentRegressor.XRange):
            raise Exception("Previous Segement Regressor must have been initialized with a process call")
        
        if not TW_Utility.is_array(xPredictFrom):
            xPredictFrom = [xPredictFrom]
        xPredictFrom = np.array(xPredictFrom)

        if TW_Utility.is_array(xPredictFrom) and len(xPredictFrom) > 1:
            raise Exception("X must be a unique feature or an array of one feature with N Dimensions")

        if not TW_Utility.empty_or_none(xPredictTo) and not TW_Utility.is_array(xPredictTo):
            xPredictTo = [xPredictTo]
        xPredictTo = np.array(xPredictTo)

        if not TW_Utility.empty_or_none(xPredictTo) and TW_Utility.is_array(xPredictTo) and len(xPredictTo) > 1:
            raise Exception("X must be a unique feature or an array of one feature with N Dimensions")
        
        xPredictFromNbrColumns = TW_Utility.np_number_columns(xPredictFrom)

        forecastedValuesX = []
        forecastedValuesY = []
        forecastedSegments = []
        lastYearForecastSegments = []

        matchedSegment, predictedValue = previousSegmentRegressor.predict(xPredictFrom)
        if matchedSegment is not None:
            lastYearForecastSegments =  [s for s in previousSegmentRegressor.currentRegressor.segments if s.identifier >= matchedSegment.identifier]

        if len(lastYearForecastSegments) >= 1:
            currentSegment, predictedValue = self.predict(xPredictFrom)
            currentSegmentFirstPoint = currentSegment.points[0]
            currentSegmentLastPoint = currentSegment.points[len(currentSegment.points) -1]
            lastYearForecastSegmentsFirstPoint = lastYearForecastSegments[0].points[0]
            lastYearForecastSegmentsLastPoint = lastYearForecastSegments[0].points[len(lastYearForecastSegments[0].points) -1]

            diff = currentSegment.getEuclidianDistance(currentSegmentLastPoint.getPoint(), lastYearForecastSegmentsLastPoint.getPoint())
            if diff > diffThreshold:
                forecastedSegments.append(currentSegment)
                previousForecastDiffX = currentSegmentLastPoint.getX() - lastYearForecastSegmentsLastPoint.getX()
                previousForecastDiffY = currentSegmentLastPoint.getY() - lastYearForecastSegmentsLastPoint.getY()
                lastYearForecastSegments = TW_Segment.rebaseSegmentsFromIndexAndCoordinate(lastYearForecastSegments, 1, previousForecastDiffX, previousForecastDiffY)
            else:
                previousForecastDiffX = currentSegmentFirstPoint.getX() - lastYearForecastSegmentsFirstPoint.getX()
                previousForecastDiffY = currentSegmentFirstPoint.getY() - lastYearForecastSegmentsFirstPoint.getY()
                lastYearForecastSegment = TW_Segment.rebaseSegmentPoint(lastYearForecastSegments[0], previousForecastDiffX, previousForecastDiffY, TW_Segment.FIRST_POINT)
                forecastedSegments.append(lastYearForecastSegment)
           
            for i in range(1, len(lastYearForecastSegments)):
                forecastedSegments.append(lastYearForecastSegments[i])

            lastForecastedSegment = forecastedSegments[len(forecastedSegments) -1]
            lastForecastedPointX = forecastedSegments[len(forecastedSegments) -1].getX()
            
            #vérification que l'on inclue toujours le point à prédire
            lastPointToForecast = None
            if xPredictTo is not None:
                if len(xPredictTo.shape) > 1:
                    if previousSegmentRegressor.currentRegressor._isCoordinateAfterLastSegment(lastForecastedSegment, xPredictTo):
                        lastPointToForecastY = lastForecastedSegment.regression.predict(xPredictTo)

                        if not isinstance(lastPointToForecastY, np.ndarray):
                            lastPointToForecastY = np.array(lastPointToForecastY)

                        if TW_Utility.np_number_columns(lastPointToForecastY) == 1 and len(lastPointToForecastY.shape) > 1:
                            lastPointToForecastY = TW_Utility.array_last_value(lastPointToForecastY)

                        lastPointToForecast = TW_Point(xPredictTo[0], lastPointToForecastY)
                else:
                    lastForecastedPointX = lastForecastedPointX[len(lastForecastedPointX) -1]
                    diff = xPredictTo[0] - lastForecastedPointX
                    if diff > 0:
                        lastPointToForecast = TW_Point(xPredictTo[0], lastForecastedSegment.regression.predict(xPredictTo))

            if lastPointToForecast is not None:
                lastForecastedSegment.addPoint(lastPointToForecast)
                forecastedSegments[len(forecastedSegments) -1] = lastForecastedSegment

        #remplissage des résultats
        for i in range(0,len(forecastedSegments)):
            predictedXValues = forecastedSegments[i].getX()
            predictedYValues = forecastedSegments[i].getY()

            for j in range(0,len(predictedXValues)):
                shouldSkipValue = False

                if xPredictFromNbrColumns == 1 and predictedXValues[j] < xPredictFrom:
                    shouldSkipValue = True
                elif xPredictFromNbrColumns > 1:
                    origin = np.zeros(xPredictFromNbrColumns)
                    if TW_Utility._euclidianDistance(origin, predictedXValues[j]) < TW_Utility._euclidianDistance(origin, xPredictFrom[0]):
                        shouldSkipValue = True

                if xPredictTo is not None and xPredictFromNbrColumns == 1 and predictedXValues[j] > xPredictTo:
                    shouldSkipValue = True
                elif xPredictTo is not None and xPredictFromNbrColumns > 1:
                    origin = np.zeros(xPredictFromNbrColumns)
                    if TW_Utility._euclidianDistance(origin, predictedXValues[j]) > TW_Utility._euclidianDistance(origin, xPredictTo[0]):
                        shouldSkipValue = True

                if shouldSkipValue == False:
                    predictedXValue = TW_Utility.array_last_value(predictedXValues[j]) if TW_Utility.is_array(predictedXValues[j]) else predictedXValues[j]
                    forecastedValuesX.append(predictedXValue)

                    predictedYValue = TW_Utility.array_last_value(predictedYValues[j]) if TW_Utility.is_array(predictedYValues[j]) else predictedYValues[j]
                    forecastedValuesY.append(predictedYValue)

        return forecastedSegments, forecastedValuesX, forecastedValuesY
    
    def predict(self, X):
        if TW_Utility.is_array(X) and len(X) > 1:
            raise Exception("X must be a unique feature with or an array of one feature with N Dimensions")

        predictedValue = None
        if self.currentRegressor is not None:
            matchedSegment, predictedValue = self.currentRegressor.predict(X)
        
        if predictedValue is not None and TW_Utility.is_array(predictedValue):
            predictedValue = TW_Utility.array_last_value(predictedValue)

        return matchedSegment, predictedValue
    
    def showGraphics(self, existingSegments = None, forecastedSegments = None):
        if TW_Utility.empty_or_none(existingSegments):
            existingSegments = []
        elif not TW_Utility.is_array(existingSegments):
            existingSegments = [existingSegments]

        if TW_Utility.empty_or_none(forecastedSegments):
            forecastedSegments = []
        elif not TW_Utility.is_array(forecastedSegments):
            forecastedSegments = [forecastedSegments]

        if len(existingSegments) == 0 and len(forecastedSegments) == 0:
            raise Exception("existingSegments And forecastedSegments cannot be null")

        nbrDimensionX = None
        if len(existingSegments) > 0:
            nbrDimensionX = TW_Utility.np_number_columns(existingSegments[0].points[0].getX())
        elif len(forecastedSegments) > 0:
            nbrDimensionX = TW_Utility.np_number_columns(forecastedSegments[0].points[0].getX())


        if nbrDimensionX > 2:
            raise Exception("cannot handle more than 3 dimensions")
        
        emptyPointX = np.zeros(nbrDimensionX)
        segmentPoints = []
        for i in range(0, len(forecastedSegments)):
            for j in range(0, len(forecastedSegments[i].points)):
                    if nbrDimensionX < 2:
                        segmentPoints.append((
                            forecastedSegments[i].points[j].X,
                            forecastedSegments[i].points[j].Y   
                        ))
                    else:
                        point = []
                        for k in range (0, nbrDimensionX):
                            point.append(forecastedSegments[i].points[j].X[k])
                        point.append(forecastedSegments[i].points[j].Y)

                        segmentPoints.append(
                            tuple(point)
                        )

        #do not add existingSegemnts after forecasted Ones
        for i in range(0, len(existingSegments)):
            for j in range(0, len(existingSegments[i].points)):
                if nbrDimensionX < 2:
                    existingPoints = [s for s in segmentPoints if TW_Utility._euclidianDistance([s[0]], emptyPointX) < TW_Utility._euclidianDistance([existingSegments[i].points[j].X], emptyPointX)]
                else:
                    existingPoints = []
                    #existingPoints = [s for s in segmentPoints if TW_Utility._euclidianDistance(TW_Utility.pickNFirstDimensionFromTupleOrArray(s, nbrDimensionX), emptyPointX) < TW_Utility._euclidianDistance(existingSegments[i].points[j].X, emptyPointX)]
                
                if len(existingPoints) < 1:
                    if nbrDimensionX < 2:
                        segmentPoints.append((
                            existingSegments[i].points[j].X,
                            existingSegments[i].points[j].Y   
                        ))
                    else:
                        point = []
                        for k in range (0, nbrDimensionX):
                            point.append(existingSegments[i].points[j].X[k])
                        point.append(existingSegments[i].points[j].Y)

                        segmentPoints.append(
                            tuple(point)
                        )

        fig = plt.figure(figsize=(8, 5))

        if nbrDimensionX < 2:
            segmentPoints = sorted(segmentPoints, key=lambda item: item[0])
            seriesList = TW_Utility.splitPointListIntoSerie(segmentPoints)
            plt.scatter(seriesList[0],seriesList[1],marker='o')
            plt.xlabel('x')
            plt.ylabel('y')
        else:
            ax = fig.add_subplot(projection = '3d')

            segmentPoints = sorted(segmentPoints, key=lambda item: TW_Utility._euclidianDistance(TW_Utility.pickNFirstDimensionFromTupleOrArray(item, nbrDimensionX), emptyPointX))
            seriesList = TW_Utility.splitPointListIntoSerie(segmentPoints)
            ax.scatter(seriesList[0],seriesList[2],seriesList[1], marker='o')
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

        plt.title('Graphique des données')
        plt.legend()
        
        for i in range(0, len(existingSegments)):
            forecastedSegmentsMatch = TW_Segment_Regressor.getSegmentFittingCoordinates(forecastedSegments, [existingSegments[i].start.X])
            if forecastedSegmentsMatch is None:
                if nbrDimensionX < 2:
                    x1 = existingSegments[i].start.X
                    x2 = existingSegments[i].end.X
                    y1 = existingSegments[i].start.Y
                    y2 = existingSegments[i].end.Y

                    plt.plot([x1,x2],[y1,y2], label='Segment ' + str(i))
                    plt.legend()
                else:
                    x1 = existingSegments[i].start.X[0]
                    x2 = existingSegments[i].end.X[0]
                    z1 = existingSegments[i].start.X[1]
                    z2 = existingSegments[i].end.X[1]
                    y1 = existingSegments[i].start.Y
                    y2 = existingSegments[i].end.Y

                    plt.plot([x1,x2],[y1,y2],[z1,z2], label='Segment ' + str(i))
                    plt.legend()
        
        for i in range(0, len(forecastedSegments)):
            existingSegmentsMatch = None
            extractedPoint = forecastedSegments[i].getX()
            
            for j in range(0,len(extractedPoint)):
                pointToConsider = extractedPoint[j]

                if TW_Utility.is_array(pointToConsider):
                    if not isinstance(pointToConsider, np.ndarray):
                        pointToConsider = np.array(pointToConsider)

                    if TW_Utility.np_number_columns(pointToConsider) > 1 and len(pointToConsider.shape) < 2:
                        pointToConsider = np.array([pointToConsider])

                #do not add forecasted segment if segment is segment is allready existing
                existingSegmentsMatch = TW_Segment_Regressor.getSegmentFittingCoordinates(existingSegments, pointToConsider)
                if existingSegmentsMatch is not None:
                    break

            if nbrDimensionX < 2:
                x1 = forecastedSegments[i].start.X
                x2 = forecastedSegments[i].end.X
                y1 = forecastedSegments[i].start.Y
                y2 = forecastedSegments[i].end.Y

                plt.plot([x1,x2],[y1,y2], linestyle="dashed", label=f"Forecasted Segment {i}")
                plt.legend()
            else:
                x1 = forecastedSegments[i].start.X[0]
                x2 = forecastedSegments[i].end.X[0]
                z1 = forecastedSegments[i].start.X[1]
                z2 = forecastedSegments[i].end.X[1]
                y1 = forecastedSegments[i].start.Y
                y2 = forecastedSegments[i].end.Y

                plt.plot([x1,x2],[y1,y2],[z1,z2], linestyle="dashed", label=f"Forecasted Segment {i}")
                plt.legend()

        plt.grid(True)
        plt.show()

    @staticmethod
    def serializeDir():
        return os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'serializedModel'))
    
    def serialize(self):
        serializeFolder = self.serializeDir()
        if not os.path.exists(serializeFolder):
            os.makedirs(serializeFolder)

        fileName = self.name + '.pkl'
        filePath = os.path.join(serializeFolder, fileName)

        with open(filePath, 'wb') as f:
            dump(self, f)
        
    @staticmethod
    def unserialize(fileName) -> None|object:
        serializeFolder = TW_MultipleSegment_Regressor.serializeDir()
        filePath = os.path.join(serializeFolder, fileName)

        try:
            with open(filePath, 'rb') as f:
                instance = load(f)
        except:
            instance = None

        return instance

        
        

                        
