from dataclasses import dataclass, field
from os.path import abspath, isfile, join, dirname
import os
import json
import re

from classes.TW_Cli_Command import TW_Cli_Command, TW_Cli_Args, TW_Cli_Output
from classes.TW_Utility import TW_Utility
from classes.TW_MultipleSegment_Regressor import TW_MultipleSegment_Regressor

from classes.TW_Object_Type_Validator import TW_Object_Type_Validator, TW_Object_Type_Validation

@dataclass
class TW_Predict_Segment_Regression_Command(TW_Cli_Command):
    """CLI command ``predict_segment_regression``.

    Looks up two serialized regressors in the directory returned by
    ``TW_MultipleSegment_Regressor.serializeDir()``: ``<name>__<year>.pkl``
    and ``<name>__<year - 1>.pkl``.  It then emits a payload with the keys
    ``predictor_values`` and ``predicted_values`` through ``outputSuccess``,
    or reports a failure through ``outputError``.
    """
    name:str = "predict_segment_regression"
    args:list[TW_Cli_Args] = field(default_factory=lambda:[
        TW_Cli_Args(name="name", nullable=False),
        TW_Cli_Args(name="year", nullable=False),
        TW_Cli_Args(name="predict_from", nullable=False),
        TW_Cli_Args(name="predict_to", nullable=True, defaultValue=None)
    ])
    output:TW_Cli_Output = field(default_factory=lambda:TW_Cli_Output.JSON)

    @staticmethod
    def __parsePredictArg(predictArg:str):
        """Convert a raw ``predict_*`` argument into a float or a list of floats.

        Args:
            predictArg: Raw argument text, either a single number (``"3.5"``)
                or a bracketed, comma-separated list (``"[1, 2, 3]"``).

        Returns:
            ``None`` when ``TW_Utility.empty_or_none`` reports the argument as
            empty, a ``list[float]`` for the bracketed form, otherwise a
            single ``float``.

        Raises:
            ValueError: If any item cannot be converted with ``float()``.
        """
        if TW_Utility.empty_or_none(predictArg):
            return None

        text = predictArg.strip()
        if text.startswith('['):
            # Drop the brackets, then convert each comma-separated item.
            items = text.replace('[', '').replace(']', '').split(',')
            return [float(item) for item in items]

        return float(text)

    def execute(self, args:object):
        """Run the prediction and emit the result.

        Args:
            args: Parsed CLI arguments providing ``name``, ``year``,
                ``predict_from`` and optionally ``predict_to``.

        Returns:
            None. The outcome is reported through ``outputSuccess`` or
            ``outputError``.
        """
        predictedData = {"predictor_values":[], "predicted_values":[]}

        modelName = TW_Utility.getattr(args, 'name')
        year = TW_Utility.getattr(args, 'year')
        try:
            previousYear = int(year) - 1
        except (TypeError, ValueError):
            self.outputError("Invalid year argument")
            return

        # Build the file names with string formatting: concatenating the
        # integer previous year onto a str raises TypeError.
        modelFilename = f"{modelName}__{year}.pkl"
        previousYearModelFilename = f"{modelName}__{previousYear}.pkl"

        serializedFolder = TW_MultipleSegment_Regressor.serializeDir()

        # NOTE(review): both files are required here, so the single-model
        # fallback further down is only reached when the previous-year model
        # exists on disk but fails to unserialize. Confirm whether the
        # previous-year file is meant to be optional.
        if not isfile(join(serializedFolder, previousYearModelFilename)) or not isfile(join(serializedFolder, modelFilename)):
            self.outputError("Model not found")
            return

        model = TW_MultipleSegment_Regressor.unserialize(modelFilename)
        if model is None:
            self.outputError("Model unserialize error")
            return

        previousYearModel = TW_MultipleSegment_Regressor.unserialize(previousYearModelFilename)

        try:
            predictFrom = self.__parsePredictArg(TW_Utility.getattr(args, 'predict_from'))
            predictTo = self.__parsePredictArg(TW_Utility.getattr(args, 'predict_to')) if TW_Utility.hasattr(args, 'predict_to') else None
        except ValueError:
            # Non-numeric input is reported like the other failures instead
            # of surfacing as an unhandled exception.
            self.outputError("Invalid predict argument")
            return

        if previousYearModel is not None:
            # The first element of the returned tuple is not used here.
            _, forecastedValuesX, forecastedValuesY = model.predictUsingHistoricalData(previousYearModel, predictFrom, predictTo)
            predictedData['predictor_values'] = forecastedValuesX
            predictedData['predicted_values'] = forecastedValuesY
        else:
            _, forecastedValueY = model.predict(predictFrom)
            predictedData['predictor_values'] = predictFrom
            predictedData['predicted_values'] = forecastedValueY

        self.outputSuccess(predictedData)
        return