import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pytest

from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

from classes.TW_Segment import TW_Point
from classes.TW_Segment import TW_Segment
from classes.TW_Utility import TW_Utility
from classes.TW_Dataset_Helper import TW_Dataset_Helper
from classes.TW_Segment_Regressor import TW_Segment_Regressor
from classes.TW_MultipleSegment_Regressor import TW_MultipleSegment_Regressor
from classes.TW_Linear_regression import TW_Linear_regression

class Test_DatasetHelper:
    """Tests for the outlier inspection and removal helpers of ``TW_Dataset_Helper``."""

    TYPE_2D = "TYPE_2D"
    TYPE_3D = "TYPE_3D"

    def data_fixture_generator(self, type: str, withRegression: bool = False):
        """Build a small ``[X, Y]`` dataset with one Y value far above the rest.

        Args:
            type: ``TYPE_2D`` for scalar X samples (Y contains 300) or
                ``TYPE_3D`` for two-feature X samples (Y contains 500).
            withRegression: Currently unused; kept so existing callers keep
                working.

        Returns:
            A two-element list ``[X, Y]``.

        Raises:
            ValueError: If ``type`` is not one of the supported constants.
        """
        if type == Test_DatasetHelper.TYPE_2D:
            X = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            Y = [10, 30, 45, 55, 300, 20, 50, 60, 35, 45]
            return [X, Y]

        if type == Test_DatasetHelper.TYPE_3D:
            X = [[10, 20], [30, 40], [60, 20], [92, 120], [140, 80], [180, 100]]
            Y = [10, 20, 30, 500, 50, 60]
            return [X, Y]

        # Fail loudly instead of returning None, which would only surface
        # later as an opaque TypeError when the fixture is indexed.
        raise ValueError(f"Unknown dataset type: {type!r}")

    @pytest.fixture
    def data_2D_fixture(self):
        """Provide the 2D ``[X, Y]`` dataset."""
        return self.data_fixture_generator(Test_DatasetHelper.TYPE_2D)

    @pytest.fixture
    def data_3D_fixture(self):
        """Provide the 3D ``[X, Y]`` dataset."""
        return self.data_fixture_generator(Test_DatasetHelper.TYPE_3D)

    def fn_remove_outlier(self, fixture):
        """Check outlier detection and removal on an ``[X, Y]`` dataset.

        Verifies that:
          * no Y sample lies beyond a value flagged as a max/min outlier,
          * ``removeOutlier`` drops exactly as many samples as were flagged,
          * the cleaned X and Y have the same length,
          * none of the flagged values is still present in the cleaned Y.

        Args:
            fixture: Two-element sequence ``[X, Y]``.
        """
        X = fixture[0]
        Y = fixture[1]

        # The inspect helpers are used here as mappings whose *values* are the
        # flagged Y values (keys are presumably positions in Y -- TODO confirm).
        maxValueOutliers = TW_Dataset_Helper.inspectMaxValueForOutlier(Y)
        # Every flagged max outlier must have no sample above it. Note this
        # only holds when each flagged value equals the maximum of Y.
        for outlier in maxValueOutliers.values():
            assert not any(sample > outlier for sample in Y)

        minValueOutliers = TW_Dataset_Helper.inspectMinValueForOutlier(Y)
        # Mirror check: no sample may lie below a flagged min outlier.
        for outlier in minValueOutliers.values():
            assert not any(sample < outlier for sample in Y)

        cleanedDatasetX, cleanedDatasetY = TW_Dataset_Helper.removeOutlier(X, Y)
        assert len(Y) - len(cleanedDatasetY) == len(minValueOutliers) + len(maxValueOutliers)

        # X and Y samples are paired, so cleaning must shrink both equally.
        # (The previous X check compared an enumerate() counter against X and
        # therefore verified nothing about the removed samples.)
        assert len(cleanedDatasetX) == len(cleanedDatasetY)

        # Iterate over .values(): iterating the mapping itself (or wrapping it
        # in enumerate) yields keys, not the flagged Y values.
        for outliers in (maxValueOutliers, minValueOutliers):
            leakedValues = [
                value for value in outliers.values()
                if TW_Utility.in_array(value, cleanedDatasetY)
            ]
            assert leakedValues == []

    def test_removeoutlier_2D(self, data_2D_fixture):
        """Run the outlier checks on the 2D dataset."""
        self.fn_remove_outlier(data_2D_fixture)

    def test_removeoutlier_3D(self, data_3D_fixture):
        """Run the outlier checks on the 3D dataset."""
        self.fn_remove_outlier(data_3D_fixture)




            
