import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold

from classes.TW_Utility import TW_Utility

class TW_Dataset_Helper:
    """Helpers to split a dataset into folds and to strip outliers from it.

    ``X`` holds the samples and ``Y`` the matching target values; both are
    accessed positionally and are expected to have the same length.
    """

    METHOD_STATIFIED = 'STRATIFIED'
    # Correctly spelled alias; the historical (misspelled) name above is kept
    # because existing callers reference it.
    METHOD_STRATIFIED = METHOD_STATIFIED
    METHOD_KFOLD = 'KFOLD'

    def _applySplitMethod(self, n, X, Y, method):
        """Split X/Y into ``n`` train/test folds.

        :param n: number of folds.
        :param X: indexable samples.
        :param Y: indexable targets, same length as X.
        :param method: ``METHOD_STATIFIED`` selects ``StratifiedKFold``; any
            other value selects ``KFold``.
        :return: ``(canSplit, X_TRAIN, X_TEST, Y_TRAIN, Y_TEST)`` where each
            list holds one list of rows per fold.

        Exceptions raised by scikit-learn (e.g. ``ValueError`` when ``n``
        exceeds the number of samples) propagate unchanged.
        """
        if method == self.METHOD_STATIFIED:
            splitter = StratifiedKFold(n_splits=n)
        else:
            splitter = KFold(n_splits=n)

        canSplit = False
        X_TRAIN, X_TEST, Y_TRAIN, Y_TEST = [], [], [], []

        # For (Stratified)KFold, get_n_splits() returns n_splits, so this guard
        # is effectively always true; it is kept as a cheap sanity check.
        if splitter.get_n_splits(X, Y) >= n:
            canSplit = True
            for train_index, test_index in splitter.split(X, Y):
                # X/Y may be plain lists, so gather rows by position instead of
                # relying on numpy fancy indexing.
                X_TRAIN.append([X[k] for k in train_index])
                X_TEST.append([X[k] for k in test_index])
                Y_TRAIN.append([Y[k] for k in train_index])
                Y_TEST.append([Y[k] for k in test_index])

        return canSplit, X_TRAIN, X_TEST, Y_TRAIN, Y_TEST

    @staticmethod
    def _relativeDifferencePercent(a, b):
        """Return how much larger the bigger of ``a``/``b`` is than the smaller, in percent.

        A zero divisor is handled explicitly instead of raising: two zeros
        differ by 0%, a zero and a non-zero value by infinity.
        """
        low, high = min(a, b), max(a, b)
        if low == 0:
            return 0.0 if high == 0 else float('inf')
        return ((high - low) / low) * 100

    def split(self, n, X, Y, method=METHOD_KFOLD, maxdiff = 25):
        """Split X/Y into ``n`` folds and optionally reject unbalanced folds.

        :param n: number of folds.
        :param method: see ``_applySplitMethod``.
        :param maxdiff: maximum relative difference, in percent, allowed between
            the training score of fold 0 and that of every other fold. A fold's
            score is ``TW_Utility.dataset_total_score`` of its X rows plus that
            of its Y rows. ``False`` (or ``None``) disables the check.
        :return: ``(canSplit, X_Train, X_Test, Y_Train, Y_Test)``; ``canSplit``
            is False when a fold exceeds ``maxdiff``. The folds are returned
            in either case.
        """
        # Exceptions from the splitter propagate unchanged (the former bare
        # ``except: ...; raise`` wrapper had no observable effect).
        canSplit, splittedX_Train, splittedX_Test, splittedY_Train, splittedY_Test = \
            self._applySplitMethod(n, X, Y, method)

        if canSplit and maxdiff is not False and maxdiff is not None:
            firstSplitScore = TW_Utility.dataset_total_score(splittedX_Train[0]) + TW_Utility.dataset_total_score(splittedY_Train[0])

            for i in range(1, len(splittedX_Train)):
                splitScore = TW_Utility.dataset_total_score(splittedX_Train[i]) + TW_Utility.dataset_total_score(splittedY_Train[i])
                if self._relativeDifferencePercent(firstSplitScore, splitScore) > maxdiff:
                    canSplit = False
                    break

        return canSplit, splittedX_Train, splittedX_Test, splittedY_Train, splittedY_Test

    @staticmethod
    def _closestIndexWithinPopulation(listToEnumerate, value, mode = 'MIN'):
        """Return the index of the element closest to ``value`` from above.

        Only elements ``>= value`` are candidates, and in both modes the
        candidate with the smallest gap wins. NOTE(review): despite their names,
        'MIN' and 'MAX' select the same element; they differ only in the
        fallback below. In 'MAX' mode, when no element is ``>= value``, the
        index of the largest element is returned instead.

        :return: the matching index, or None when there is no candidate or when
            several elements tie for the smallest gap.
        :raises ValueError: if ``mode`` is neither 'MIN' nor 'MAX'.
        """
        if mode not in ('MIN', 'MAX'):
            raise ValueError("mode must be 'MIN' or 'MAX', got %r" % (mode,))

        # (signed gap, index) pairs: MIN stores element - value (>= 0),
        # MAX stores value - element (<= 0).
        candidates = []
        for index, element in enumerate(listToEnumerate):
            if mode == 'MIN':
                diff = element - value
                if diff >= 0:
                    candidates.append((diff, index))
            else:
                diff = value - element
                if diff <= 0:
                    candidates.append((diff, index))

        if mode == 'MAX' and not candidates:
            maxValueIndex = np.argmax(listToEnumerate)
            candidates.append((value - listToEnumerate[maxValueIndex], maxValueIndex))

        if not candidates:
            return None

        gaps = [diff for diff, _ in candidates]
        # Computed once (it used to be recomputed for every candidate).
        bestGap = min(gaps) if mode == 'MIN' else max(gaps)
        matched = [index for diff, index in candidates if diff == bestGap]

        return matched[0] if len(matched) == 1 else None

    @staticmethod
    def _populationIndexWithinRange(yEdges, valueStart, valueEnd):
        """Locate ``[valueStart, valueEnd]`` within ``yEdges``.

        ``yEdges`` is presumably a sorted sequence of histogram bin edges —
        confirm with callers.

        Note (translated from the original French comment): the indices are
        computed against ``yEdges``, but they are then applied to a histogram
        that always has one value fewer than ``yEdges``. That is harmless when
        the indices are consumed through ``range(start, end)``, but not
        otherwise — hence ``nbrIntervals`` is reduced by one.

        :return: ``(rangeStartIndex, rangeEndIndex, nbrIntervals)``, or
            ``(None, None, None)`` when either bound cannot be located.
        """
        rangeStartIndex = TW_Dataset_Helper._closestIndexWithinPopulation(yEdges, valueStart, 'MIN')
        rangeEndIndex = TW_Dataset_Helper._closestIndexWithinPopulation(yEdges, valueEnd, 'MAX')

        if rangeStartIndex is None or rangeEndIndex is None:
            return None, None, None

        # The histogram holds len(yEdges) - 1 bins.
        nbrIntervals = rangeEndIndex - rangeStartIndex
        if nbrIntervals >= 1:
            nbrIntervals -= 1

        # Both indices come from enumerate()/argmax over yEdges and are thus
        # always < len(yEdges); the former clamps against len(yEdges) were dead.
        return rangeStartIndex, rangeEndIndex, nbrIntervals

    @staticmethod
    def removeFromDataset(X, Y, valuesToRemove):
        """Return ``(newX, newY)`` lists without the rows whose Y value is in ``valuesToRemove``.

        Membership is decided by ``TW_Utility.in_array``. NOTE(review):
        ``removeOutlier`` passes a dict ``{index: value}`` here; whether
        ``in_array`` checks the dict's keys or its values depends on
        ``TW_Utility`` — confirm.
        """
        newX = []
        newY = []

        for index, examinedValue in enumerate(Y):
            if not TW_Utility.in_array(examinedValue, valuesToRemove):
                newX.append(X[index])
                newY.append(examinedValue)

        return newX, newY

    @staticmethod
    def getMinAndMaxForDimension(listToExtract, n):
        """Return ``(min, max)`` of column ``n`` of a 2-D array-like (lists accepted)."""
        column = np.asarray(listToExtract)[:, n]
        return np.min(column), np.max(column)

    @staticmethod
    def removeOutlier(X, Y, outlierDistanceThresold = 60, outlierFreqThreshold = 8.5):
        """Drop the rows whose Y value is flagged as an outlier at either end of Y.

        Outliers are detected by ``inspectMinValueForOutlier`` and
        ``inspectMaxValueForOutlier`` with their default thresholds, merged with
        ``TW_Utility.mergeDictionnaries`` and removed via ``removeFromDataset``.

        NOTE(review): ``outlierDistanceThresold`` and ``outlierFreqThreshold``
        are currently ignored. They belonged to a former histogram-based
        algorithm and are kept only for signature compatibility.

        :return: ``(X, Y)`` as numpy arrays without the outlier rows.
        """
        outlierInMaxRange = TW_Dataset_Helper.inspectMaxValueForOutlier(Y)
        outlierInMinRange = TW_Dataset_Helper.inspectMinValueForOutlier(Y)
        totalOutlier = TW_Utility.mergeDictionnaries(outlierInMinRange, outlierInMaxRange)

        cleanedDatasetX, cleanedDatasetY = TW_Dataset_Helper.removeFromDataset(X, Y, totalOutlier)
        return np.array(cleanedDatasetX), np.array(cleanedDatasetY)

    @staticmethod
    def inspectMaxValueForOutlier(Y, outlierDistanceThreshold = 100, outlierFreqThreshold = 8.5):
        """Flag values of Y that sit far above the next lower value.

        Y is walked from largest to smallest. A value ``v`` is flagged when:
        - its gap to the next lower value, relative to that lower value,
          exceeds ``outlierDistanceThreshold`` percent (the smallest value
          reuses the gap to its upper neighbour); and
        - fewer than ``outlierFreqThreshold`` percent of Y lie in
          ``(v, v * (1 + outlierDistanceThreshold / 100)]``.

        Zero divisors yield inf/nan (numpy warnings suppressed), as before.

        :return: dict ``{index of the first occurrence in Y: value}``.
        """
        outliers = {}
        values = np.asarray(Y)
        totalPopulation = len(values)
        descending = np.sort(values)[::-1]

        with np.errstate(divide='ignore', invalid='ignore'):
            for i, examinedValue in enumerate(descending):
                maxValue = examinedValue + (examinedValue * (outlierDistanceThreshold / 100))

                if i < totalPopulation - 1:
                    gap = examinedValue - descending[i + 1]
                    distThreshold = (gap / descending[i + 1]) * 100
                else:
                    gap = descending[i - 1] - examinedValue
                    distThreshold = (gap / examinedValue) * 100

                nbrSuperiorValues = np.count_nonzero((values > examinedValue) & (values <= maxValue))
                freqSuperiorValues = (nbrSuperiorValues / totalPopulation) * 100

                if distThreshold > outlierDistanceThreshold and freqSuperiorValues < outlierFreqThreshold:
                    outliers[np.where(values == examinedValue)[0][0]] = examinedValue

        return outliers

    @staticmethod
    def inspectMinValueForOutlier(Y, outlierDistanceThreshold = 100, outlierFreqThreshold = 8.5):
        """Flag values of Y that sit far below the next higher value.

        Y is walked from smallest to largest. A value ``v`` is flagged when:
        - its gap to the next higher value, relative to ``v``, exceeds
          ``outlierDistanceThreshold`` percent (the largest value reuses the
          gap to its lower neighbour); and
        - fewer than ``outlierFreqThreshold`` percent of Y lie in
          ``[v - v * outlierDistanceThreshold / 100, v)``.

        Zero divisors yield inf/nan (numpy warnings suppressed), as before.

        :return: dict ``{index of the first occurrence in Y: value}``.
        """
        outliers = {}
        values = np.asarray(Y)
        totalPopulation = len(values)
        ascending = np.sort(values)

        with np.errstate(divide='ignore', invalid='ignore'):
            for i, examinedValue in enumerate(ascending):
                minValue = examinedValue - (examinedValue * (outlierDistanceThreshold / 100))

                if i < totalPopulation - 1:
                    gap = ascending[i + 1] - examinedValue
                    distThreshold = (gap / examinedValue) * 100
                else:
                    gap = examinedValue - ascending[i - 1]
                    distThreshold = (gap / ascending[i - 1]) * 100

                nbrInferiorValues = np.count_nonzero((values < examinedValue) & (values >= minValue))
                freqInferiorValues = (nbrInferiorValues / totalPopulation) * 100

                if distThreshold > outlierDistanceThreshold and freqInferiorValues < outlierFreqThreshold:
                    outliers[np.where(values == examinedValue)[0][0]] = examinedValue

        return outliers

    @staticmethod
    def mergeAlongFirstAxis(X, Y):
        """Append Y to X as extra column(s), i.e. concatenate along axis 1.

        A 1-D Y becomes a single column. A 2-D Y whose second dimension equals
        the number of rows of X (e.g. a row vector) is transposed first, as
        before. A 2-D Y already shaped ``(rows, k)`` is appended as-is; it
        previously failed because it was always transposed.
        """
        X = np.asanyarray(X)
        Y = np.asanyarray(Y)

        if Y.ndim < 2:
            Y = Y.reshape(-1, 1)
        elif Y.shape[1] == X.shape[0]:
            Y = Y.T

        return np.concatenate((X, Y), axis=1)