from dataclasses import dataclass, field
import os
from os.path import abspath, isfile, join, dirname
import json

from classes.TW_Cli_Command import TW_Cli_Command, TW_Cli_Args, TW_Cli_Output
from classes.TW_Utility import TW_Utility
from classes.TW_MultipleSegment_Regressor import TW_MultipleSegment_Regressor

from classes.TW_Object_Type_Validator import TW_Object_Type_Validator, TW_Object_Type_Validation

@dataclass
class TW_Process_Segment_Regression_Command(TW_Cli_Command):
    """CLI command ``process_segment_regression``.

    Reads ``importedJson/<name>.json`` (located one directory above this
    module), type-checks its content, fits a ``TW_MultipleSegment_Regressor``
    on it and serializes the fitted model.

    CLI arguments:
        name: base name of the JSON file to import (required).
        year: appended to ``name`` as ``<name>__<year>`` to form the model
            name (required).
        persistent: when ``True`` and the serialized model file already
            exists, the command reports success without reprocessing
            (optional, default ``False``).
    """

    name: str = "process_segment_regression"
    args: list[TW_Cli_Args] = field(default_factory=lambda: [
        TW_Cli_Args(name="name", nullable=False),
        TW_Cli_Args(name="year", nullable=False),
        TW_Cli_Args(name="persistent", nullable=True, defaultValue=False),
    ])
    output: TW_Cli_Output = field(default_factory=lambda: TW_Cli_Output.JSON)

    def execute(self, args: object) -> None:
        """Import, validate, fit and serialize the model described by ``args``.

        Every outcome is reported through ``outputSuccess`` or
        ``outputError``; the method itself returns ``None``.

        Args:
            args: parsed CLI arguments exposing ``name``, ``year`` and
                ``persistent`` attributes.
        """
        modelName = TW_Utility.getattr(args, 'name')
        if hasattr(args, 'year') and not TW_Utility.empty_or_none(getattr(args, 'year')):
            # str() guards against a year that was parsed as a number.
            modelName = modelName + '__' + str(getattr(args, 'year'))

        serializedFolder = TW_MultipleSegment_Regressor.serializeDir()
        depositFolder = abspath(join(dirname(__file__), '..', 'importedJson'))
        fileName = join(depositFolder, TW_Utility.getattr(args, 'name') + '.json')
        existingModel = join(serializedFolder, modelName + '.pkl')

        # Explicit `== True` keeps the strict check: a non-boolean truthy
        # value (e.g. the string "False") must not skip processing.
        if isfile(existingModel) and TW_Utility.getattr(args, 'persistent') == True:
            self.outputSuccess("Model allready processed")
            return

        if not isfile(fileName):
            self.outputError("File not found")
            return

        try:
            # json.load() takes the file object itself. JSON text is UTF-8
            # (RFC 8259), so do not rely on the platform default encoding.
            # The file is closed before the (potentially long) fit below.
            with open(fileName, encoding='utf-8') as f:
                decodedJsonData = json.load(f)

            # NOTE(review): "predictor_label" and "predictor_values" are each
            # declared twice with different types; presumably the validator
            # accepts either form — confirm against TW_Object_Type_Validator.
            validator = TW_Object_Type_Validator(
                propertiesType=[
                    TW_Object_Type_Validation(name="name", nullable=False, mainType=str),
                    TW_Object_Type_Validation(name="year", nullable=False, mainType=int),
                    TW_Object_Type_Validation(name="predicted_label", nullable=False, mainType=str),
                    TW_Object_Type_Validation(name="predicted_values", nullable=False, mainType=list, subType=float),
                    TW_Object_Type_Validation(name="predictor_label", nullable=False, mainType=str),
                    TW_Object_Type_Validation(name="predictor_label", nullable=False, mainType=list, subType=str),
                    TW_Object_Type_Validation(name="predictor_values", nullable=False, mainType=list, subType=float),
                    TW_Object_Type_Validation(name="predictor_values", nullable=False, mainType=list, subType=list, subSubType=float),
                ]
            )

            typeValidationErrors = validator.validateType(decodedJsonData)
            # Any validation error at all is fatal.
            if typeValidationErrors:
                self.outputError("\n".join(typeValidationErrors.values()))
                return

            predictedValues = TW_Utility.getattr(decodedJsonData, 'predicted_values')
            predictorValues = TW_Utility.getattr(decodedJsonData, 'predictor_values')
            if len(predictedValues) != len(predictorValues):
                self.outputError("Predicted_values and predictor_values should have the same length")
                return

            # "year" is validated as an int, so convert it before concatenating.
            modelName = (
                TW_Utility.getattr(decodedJsonData, 'name')
                + '__'
                + str(TW_Utility.getattr(decodedJsonData, 'year'))
            )
            regressor = TW_MultipleSegment_Regressor(name=modelName)
            regressor.processDataset(X=predictorValues, Y=predictedValues)

            # NOTE(review): existingModel is derived from the CLI arguments,
            # while modelName now comes from the JSON content; if they differ,
            # this removes a different model than the one being serialized.
            if isfile(existingModel):
                os.remove(existingModel)

            regressor.serialize()
        except Exception as ex:
            # Report the actual failure rather than this command's repr.
            self.outputError(f"{type(ex).__name__}: {ex}")
            return

        self.outputSuccess(f"{modelName} successfully processed and serialized")