import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pytest

from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

from classes.TW_Segment import TW_Point
from classes.TW_Segment import TW_Segment
from classes.TW_Utility import TW_Utility
from classes.TW_Dataset_Helper import TW_Dataset_Helper
from classes.TW_Segment_Regressor import TW_Segment_Regressor
from classes.TW_MultipleSegment_Regressor import TW_MultipleSegment_Regressor
from classes.TW_Linear_regression import TW_Linear_regression

class Test_MultipleSegmentRegressor:
    """Check that TW_MultipleSegment_Regressor picks the lowest-error threshold.

    Each test fits one TW_Segment_Regressor per candidate threshold, picks the
    one with the smallest ``totalError()``, and asserts that
    ``TW_MultipleSegment_Regressor.processDataset`` returns a regressor with
    the same threshold.
    """

    TYPE_2D = "TYPE_2D"
    TYPE_3D = "TYPE_3D"

    # Thresholds tried by the reference search in the tests below.
    # NOTE(review): presumably these must mirror the candidates that
    # TW_MultipleSegment_Regressor.processDataset evaluates internally —
    # confirm against that class.
    CANDIDATE_THRESHOLDS = range(5, 25, 5)

    def data_fixture_generator(self, type: str, withRegression: bool = False):
        """Return an ``[X, Y]`` sample dataset for the requested layout.

        Args:
            type: ``TYPE_2D`` for scalar X values, or ``TYPE_3D`` for X rows
                with two features each.
            withRegression: Currently unused; kept for interface compatibility.

        Returns:
            A two-element list ``[X, Y]`` of plain Python lists.

        Raises:
            ValueError: If ``type`` is not one of the known dataset types.
        """
        if type == Test_MultipleSegmentRegressor.TYPE_2D:
            X = [1, 2, 3, 4, 5]
            Y = [10, 20, 65, 45, 55]
        elif type == Test_MultipleSegmentRegressor.TYPE_3D:
            X = [[10, 20], [30, 40], [60, 20], [92, 120], [140, 80], [180, 100]]
            Y = [10, 20, 30, 40, 50, 60]
        else:
            # Previously this fell through and returned None, which only
            # surfaced later as a confusing TypeError when a test indexed it.
            raise ValueError(f"Unknown dataset type: {type!r}")
        return [X, Y]

    @pytest.fixture
    def data_2D_fixture(self):
        """Provide the 2D sample dataset as ``[X, Y]``."""
        return self.data_fixture_generator(Test_MultipleSegmentRegressor.TYPE_2D)

    @pytest.fixture
    def data_3D_fixture(self):
        """Provide the 3D sample dataset as ``[X, Y]``."""
        return self.data_fixture_generator(Test_MultipleSegmentRegressor.TYPE_3D)

    def _assert_best_threshold_matches(self, data):
        """Assert processDataset picks the threshold with the lowest totalError().

        Args:
            data: An ``[X, Y]`` pair as produced by ``data_fixture_generator``.
        """
        X = np.array(data[0])
        Y = np.array(data[1])

        regressors = []
        for threshold in self.CANDIDATE_THRESHOLDS:
            regressor = TW_Segment_Regressor(X, Y, threshold)
            regressor.process()
            regressors.append(regressor)

        # min() returns the first regressor with the smallest error, exactly
        # like sorted(...)[0] did, without sorting the whole list.
        expected = min(regressors, key=lambda item: item.totalError())

        bestRegressor = TW_MultipleSegment_Regressor().processDataset(X, Y)

        # `thresold` (sic) is the attribute name the project classes expose.
        assert expected.thresold == bestRegressor.thresold

    def test_provide_best_segment_regressor_2D(self, data_2D_fixture):
        """The multi-segment regressor picks the best threshold on 2D data."""
        self._assert_best_threshold_matches(data_2D_fixture)

    def test_provide_best_segment_regressor_3D(self, data_3D_fixture):
        """The multi-segment regressor picks the best threshold on 3D data."""
        self._assert_best_threshold_matches(data_3D_fixture)

