import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pytest

from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

from classes.TW_Segment import TW_Point
from classes.TW_Segment import TW_Segment
from classes.TW_Utility import TW_Utility
from classes.TW_Dataset_Helper import TW_Dataset_Helper
from classes.TW_Segment_Regressor import TW_Segment_Regressor
from classes.TW_MultipleSegment_Regressor import TW_MultipleSegment_Regressor
from classes.TW_Linear_regression import TW_Linear_regression

class Test_SegmentRegressor:
    """Unit tests for ``TW_Segment_Regressor`` on 2D (scalar X) and 3D (list-valued X) segments."""

    TYPE_2D = "TYPE_2D"
    TYPE_3D = "TYPE_3D"

    @staticmethod
    def _build_segment(name: str, points: list, withRegression: bool):
        """Build a two-point ``TW_Segment``, optionally attaching a processed linear regression.

        Args:
            name: label passed to the ``TW_Segment`` constructor.
            points: exactly two ``TW_Point`` objects; the first becomes ``start``,
                the second ``end``.
            withRegression: when True, fit a ``TW_Linear_regression`` on the
                segment's X/Y values and call ``process()`` on it.
        """
        segment = TW_Segment(name)
        segment.points = points

        if withRegression:
            segment.regression = TW_Linear_regression(np.array(segment.getX()), np.array(segment.getY()))
            segment.regression.process()

        segment.start = segment.points[0]
        segment.end = segment.points[1]
        return segment

    def segment_fixture_generator(self, type: str, withRegression: bool = False):
        """Return two consecutive test segments of the requested dimensionality.

        Args:
            type: ``TYPE_2D`` for scalar X coordinates or ``TYPE_3D`` for
                2-element X coordinates. (The name shadows the builtin ``type``;
                it is kept so existing keyword callers keep working.)
            withRegression: when True, each segment carries a processed
                ``TW_Linear_regression``.

        Returns:
            ``[segment1, segment2]``. In 2D the start/end X values are 10/30 and
            50/70; in 3D they are [10,20]/[30,40] and [50,60]/[70,80].

        Raises:
            ValueError: if ``type`` is not one of the supported constants
                (previously this silently returned ``None``).
        """
        if type == Test_SegmentRegressor.TYPE_2D:
            points1 = [TW_Point(10, 20), TW_Point(30, 40)]
            points2 = [TW_Point(50, 80), TW_Point(70, 90)]
        elif type == Test_SegmentRegressor.TYPE_3D:
            points1 = [TW_Point([10, 20], 20), TW_Point([30, 40], 40)]
            points2 = [TW_Point([50, 60], 80), TW_Point([70, 80], 90)]
        else:
            raise ValueError(f"Unknown segment fixture type: {type!r}")

        return [
            self._build_segment('seg1', points1, withRegression),
            self._build_segment('seg2', points2, withRegression),
        ]

    @pytest.fixture
    def segment_2D_fixture(self):
        return self.segment_fixture_generator(Test_SegmentRegressor.TYPE_2D, False)

    @pytest.fixture
    def segment_2D_fixture_with_regression(self):
        return self.segment_fixture_generator(Test_SegmentRegressor.TYPE_2D, True)

    @pytest.fixture
    def segment_3D_fixture(self):
        return self.segment_fixture_generator(Test_SegmentRegressor.TYPE_3D, False)

    @pytest.fixture
    def segment_3D_fixture_with_regression(self):
        return self.segment_fixture_generator(Test_SegmentRegressor.TYPE_3D, True)

    def test_segment_fitting_coordinates_2D(self, segment_2D_fixture):
        # 20 lies within the first segment (10..30), 60 within the second (50..70),
        # and 90 is beyond both, so no segment should fit it.
        matchedSegment1 = TW_Segment_Regressor.getSegmentFittingCoordinates(segment_2D_fixture, 20)
        matchedSegment2 = TW_Segment_Regressor.getSegmentFittingCoordinates(segment_2D_fixture, 60)
        matchedSegment3 = TW_Segment_Regressor.getSegmentFittingCoordinates(segment_2D_fixture, 90)

        assert matchedSegment1 is not None and matchedSegment1.points[0].getX() == 10
        assert matchedSegment2 is not None and matchedSegment2.points[0].getX() == 50
        assert matchedSegment3 is None

    def test_segment_fitting_coordinates_3D(self, segment_3D_fixture):
        matchedSegment1 = TW_Segment_Regressor.getSegmentFittingCoordinates(segment_3D_fixture, [[10, 20]])
        matchedSegment2 = TW_Segment_Regressor.getSegmentFittingCoordinates(segment_3D_fixture, [[60, 75]])
        matchedSegment3 = TW_Segment_Regressor.getSegmentFittingCoordinates(segment_3D_fixture, [[60, 100]])

        assert matchedSegment1 is not None and matchedSegment1.points[0].getX() == [10, 20]
        assert matchedSegment2 is not None and matchedSegment2.points[0].getX() == [50, 60]
        assert matchedSegment3 is None

    def test_is_coordinate_before_first_segment_2D(self, segment_2D_fixture):
        x = 2
        y = 7
        segmentRegressor = TW_Segment_Regressor(x, y, 5)

        # 2 precedes the first segment's start (10); 11 does not.
        assert segmentRegressor._isCoordinateBeforeFirstSegment(segment_2D_fixture[0], x)
        assert not segmentRegressor._isCoordinateBeforeFirstSegment(segment_2D_fixture[0], 11)

    def test_is_coordinate_before_first_segment_3D(self, segment_3D_fixture):
        x = [[5, 8]]
        y = 8
        segmentRegressor = TW_Segment_Regressor(x, y, 5)

        assert segmentRegressor._isCoordinateBeforeFirstSegment(segment_3D_fixture[0], x)
        assert not segmentRegressor._isCoordinateBeforeFirstSegment(segment_3D_fixture[0], [[55, 65]])

    def test_is_coordinate_after_last_segment_2D(self, segment_2D_fixture):
        x = 35
        y = 45
        segmentRegressor = TW_Segment_Regressor(x, y, 5)

        # The check is run against the first fixture segment (ending at X=30),
        # treating it as the "last" segment: 35 is past its end, 11 is not.
        assert segmentRegressor._isCoordinateAfterLastSegment(segment_2D_fixture[0], x)
        assert not segmentRegressor._isCoordinateAfterLastSegment(segment_2D_fixture[0], 11)

    def test_is_coordinate_after_last_segment_3D(self, segment_3D_fixture):
        x = [35, 35]
        y = 45
        segmentRegressor = TW_Segment_Regressor(x, y, 5)

        # Against the first fixture segment (end X = [30,40]).
        assert segmentRegressor._isCoordinateAfterLastSegment(segment_3D_fixture[0], [[40, 50]])
        assert not segmentRegressor._isCoordinateAfterLastSegment(segment_3D_fixture[0], [[35, 35]])

    def test_predict_match_closest_segment_when_not_fitting_not_before_and_not_after_2D(self, segment_2D_fixture_with_regression):
        # 35 lies in the gap between the segments (30..50); expected to match the first.
        segmentRegressor = TW_Segment_Regressor(35, 35, 5)
        segmentRegressor.segments = segment_2D_fixture_with_regression
        matchedSegment, _ = segmentRegressor.predict(35)
        assert matchedSegment.points[0].getX() == 10

        # 60 lies within the second segment (50..70).
        segmentRegressor = TW_Segment_Regressor(60, 82, 5)
        segmentRegressor.segments = segment_2D_fixture_with_regression
        matchedSegment, _ = segmentRegressor.predict(60)
        assert matchedSegment.points[0].getX() == 50

        # 110 is beyond the last segment; expected to match the second.
        segmentRegressor = TW_Segment_Regressor(110, 140, 5)
        segmentRegressor.segments = segment_2D_fixture_with_regression
        matchedSegment, _ = segmentRegressor.predict(110)
        assert matchedSegment.points[0].getX() == 50

    def test_predict_match_closest_segment_when_not_fitting_not_before_and_not_after_3D(self, segment_3D_fixture_with_regression):
        segmentRegressor = TW_Segment_Regressor([45, 45], 45, 5)
        segmentRegressor.segments = segment_3D_fixture_with_regression
        matchedSegment, _ = segmentRegressor.predict([[45, 45]])
        assert matchedSegment.start.X == [10, 20]

        segmentRegressor = TW_Segment_Regressor([100, 100], 100, 5)
        segmentRegressor.segments = segment_3D_fixture_with_regression
        matchedSegment, _ = segmentRegressor.predict([[100, 100]])
        assert matchedSegment.start.X == [50, 60]