# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT

import glob
import os
from pathlib import Path
import copy
from typing import List

from tqdm import tqdm
import multiprocess as mp
from concurrent.futures import ProcessPoolExecutor

from cli.args import get_sample_range_args, DefaultFileIds, ValidMultiFileSpec
from cli.output_handler import OutputHandler
from cli.console_output import CliConsoleOutput
from cli.timer import timer
from cli.tps import TPSViewGenerator
from cli.writers.csv import CSVWriterFactory
from cli.writers.excel.excel_writer_rust import RustExcelWriter
from mpp import DataParserFactory, ViewAggregationLevel, MppApi
from mpp.core.api_args import ApiArgs, SystemInformation
from mpp.core.configuration_path_generator import ConfigurationPathGenerator
from mpp.core.types import (EventInfoDataFrameColumns as eidc, ConfigurationPaths, VerboseLevel, RawDataFrameColumns
as rdc)
from cli.compare import MppComparer


class MppRunner:
    """Run the MPP processing pipeline for one input data file.

    The runner keeps a private deep copy of the command line arguments, so
    the adjustments it makes to them (core filters, metric and chart file
    paths) do not leak into the argument object owned by the caller.
    """

    def __init__(self, cli_args):
        """Remember the arguments; the real setup happens in ``initialize``.

        :param cli_args: parsed command line arguments. They are deep copied,
                         and ``cli_args.input_data_file_path`` is stored as the
                         file this runner processes.
        """
        self.__cli_args = copy.deepcopy(cli_args)
        self.__input_data_file_path = cli_args.input_data_file_path
        # Collaborators and derived state, populated by initialize().
        self.__mpp_api = None
        self.__parser = None
        self.__system_info = None
        self.__output_handler = None
        self.__configuration_file_paths = None
        self.__tps_views = []
        self.__partitions = []
        self.__first_sample = 0
        self.__last_sample = 0

    def initialize(self):
        """Create the parser and output handler, then prepare the MPP API."""
        args = self.__cli_args
        self.__parser = DataParserFactory.create(self.__input_data_file_path, args.frequency)
        self.__system_info = self.__parser.system_info
        self.__output_handler = OutputHandler(args.output_file_specifier, args.output_format)
        # The API arguments read the resolved metric file paths, the updated
        # core filters and the partition count, so the API is set up last.
        self.__validate_core_types()
        self.__get_configuration_file_paths()
        self.__update_core_filters()
        self.__get_partition_info()
        self.__initialize_mpp_api()

    @property
    def input_data_file_path(self):
        """Input data file path this runner was created with."""
        return self.__input_data_file_path

    @property
    def output_file_specifier(self):
        """Output file specifier taken from the runner's copy of the arguments."""
        return self.__cli_args.output_file_specifier

    @property
    def parser(self):
        """Data parser created by ``initialize`` (``None`` before that)."""
        return self.__parser

    @property
    def num_unique_events(self):
        """Number of distinct event names in the parser's event info."""
        event_names = self.__parser.event_info[eidc.NAME]
        return len(set(event_names))

    @property
    def first_sample(self):
        """First sample of the first partition (0 until partitions are known)."""
        return self.__first_sample

    @property
    def last_sample(self):
        """Last sample of the last partition (0 until partitions are known)."""
        return self.__last_sample

    @property
    def partitions(self):
        """Partitions returned by the parser (empty until ``initialize`` runs)."""
        return self.__partitions

    @property
    def is_parallel(self):
        """Whether the conditions for parallel processing are met."""
        return self._conditions_met_for_parallel_processing()

    @property
    def num_partitions(self):
        """Number of partitions currently known to the runner."""
        return len(self.__partitions)

    def generate_mpp_views(self, progress=None):
        """Initialize the runner, process all partitions and write the outputs.

        :param progress: optional object with a ``put`` method; ``1`` is put on
                         it once the Excel output step has completed.
        """
        self.initialize()
        api = self.__mpp_api
        args = self.__cli_args
        if self.__is_regular_verbosity():
            CliConsoleOutput.show_run_information(api.ref_tsc)
        api.process_partitions(self.partitions, args.parallel_cores, args.no_detail_views)
        api.generate_summary_views(self.first_sample, self.last_sample, timer)
        self.generate_tps_views(api.view_writer, api.summary_views)
        if self.__is_regular_verbosity():
            CliConsoleOutput.show_sample_count(self.first_sample, self.last_sample, self.num_unique_events,
                                               api.total_metrics_derived)
        self.write_excel_output(api.view_collection)
        if progress:
            progress.put(1)
        if not self.__is_regular_verbosity():
            # Reduced verbosity: emit a single dot per processed file.
            print('.', end='', flush=True)

    def get_api_args(self):
        """Build the ``ApiArgs`` bundle from parser, system info and CLI arguments."""
        info = self.__system_info
        system_information = SystemInformation(
            processor_features=info.processor_features,
            system_features=info.system_features,
            uncore_units=info.uncore_units,
            ref_tsc=info.ref_tsc,
            unique_core_types=info.unique_core_types,
            has_modules=info.has_modules,
            has_die=info.has_die,
            qpi_link_speed=info.qpi_link_speed
        )
        args = self.__cli_args
        return ApiArgs(
            system_information=system_information,
            event_info=self.__parser.event_info,
            event_reader=self.__parser.event_reader,
            collector=self.__parser.collector,
            retire_latency=args.retire_latency,
            aggregation_levels=self.__get_requested_aggregation_levels(),
            metric_file_map=args.metric_file_path,
            is_parallel=self.is_parallel,
            no_detail_views=args.no_detail_views,
            percentile=args.percentile,
            output_directory=self.__output_handler.output_directory,
            output_prefix=self.__output_handler.output_file_prefix,
            output_writers=self.__get_output_writers(),
            unit_filters=args.core_filter,
            verbose=args.verbose
        )

    def generate_tps_views(self, view_writer, summary_views):
        """Generate and write TPS summary views when a TPS value was requested.

        :param view_writer: writer used to persist the generated TPS views
        :param summary_views: summary views the TPS views are derived from
        """
        transactions_per_second = self.__cli_args.transactions_per_second
        if not transactions_per_second:
            return
        with timer() as number_of_seconds:
            generator = TPSViewGenerator(transactions_per_second)
            tps_summaries = generator.generate_summaries(summary_views)
            view_writer.write(list(tps_summaries.values()), self.first_sample, self.last_sample)
            # Only the view attributes are kept; they are appended to the
            # view collection before the Excel output is written.
            self.__tps_views = [summary.attributes for summary in list(tps_summaries.values())]
            if self.__is_regular_verbosity():
                CliConsoleOutput.show_tps_views_generated(number_of_seconds)

    def write_excel_output(self, view_collection):
        """Write the Excel workbook if the output handler has an Excel file name.

        :param view_collection: collection of views to place in the workbook
        """
        if not self.__output_handler.excel_file_name:
            return
        include_details = not self.__cli_args.no_detail_views
        include_charts = bool(self.__cli_args.chart_format_file_path)
        self.__append_tps_views(view_collection)
        if self.__is_regular_verbosity():
            CliConsoleOutput.show_excel_file_creation()
        rust_writer_available = self.__can_write_excel_with_rust()
        # The Rust writer is only used for chart-less workbooks.
        if include_charts or not rust_writer_available:
            self.__write_excel_python(view_collection)
        else:
            self.__write_excel_rust(include_charts, include_details, view_collection)

    def __is_regular_verbosity(self):
        # Small wrapper so the verbosity check reads the same everywhere.
        return CliConsoleOutput.is_regular_verbosity(self.__cli_args.verbose)

    def __can_write_excel_with_rust(self):
        # True only when pyrust_xlsxwriter is importable and Rust output
        # has not been disabled on the command line.
        try:
            import pyrust_xlsxwriter  # noqa: F401 -- imported only to probe availability
        except ImportError:
            return False
        return not self.__cli_args.disable_rust

    def __append_tps_views(self, view_collection):
        # TPS views exist only when a TPS value was given on the command line.
        if not self.__cli_args.transactions_per_second:
            return
        view_collection.append_views(self.__tps_views)

    def __write_excel_python(self, view_collection):
        # Imported here rather than at module level, as in the original code.
        from cli.writers.excel import excel_writer_python
        excel_writer_python.write_csv_data_to_excel(self.__cli_args, view_collection, self.__output_handler)
        if self.__is_regular_verbosity():
            CliConsoleOutput.show_excel_output_destination(self.__cli_args.output_file_specifier)

    def __write_excel_rust(self, include_charts, include_details, view_collection):
        handler = self.__output_handler
        workbook_path = handler.excel_file_with_path
        writer = RustExcelWriter(handler.output_directory, workbook_path, include_details, include_charts)
        writer.write_csv_to_excel(view_collection)

    def __validate_core_types(self):
        ValidMultiFileSpec.validate_multi_file_core_types(
            self.__cli_args.metric_file_path,
            self.__system_info.unique_core_types,
        )

    def __get_requested_aggregation_levels(self):
        # The system level is always generated; the other levels are added
        # only when the corresponding CLI argument holds a truthy value.
        levels = [ViewAggregationLevel.SYSTEM]
        args = self.__cli_args
        optional_levels = (args.socket_view, args.die_view, args.core_view, args.thread_view, args.uncore_view)
        levels.extend(level for level in optional_levels if level)
        return levels

    def __get_configuration_file_paths(self):
        args = self.__cli_args
        info = self.__system_info
        user_supplied_paths = [args.metric_file_path, args.chart_format_file_path]
        path_generator = ConfigurationPathGenerator(info.configuration_file_paths,
                                                    user_supplied_paths,
                                                    info.unique_core_types,
                                                    args.verbose)
        self.__configuration_file_paths = path_generator.generate()
        self.__set_chart_metric_file_paths()

    def __update_core_filters(self):
        # TODO: consider refactor to a separate class
        core_filter = self.__cli_args.core_filter
        if not (core_filter and self.__is_hybrid()):
            return
        # Snapshot taken before the filter is modified below.
        requested_units = set(core_filter.copy()[rdc.CORE])
        for core_type in self.__system_info.unique_core_types:
            units_of_type = {unit for unit, unit_type in self.__system_info.core_type_map.items()
                             if unit_type == core_type}
            core_filter.update({core_type: list(requested_units & units_of_type)})
        # The generic core entry is replaced by the per-core-type entries.
        core_filter.pop(rdc.CORE)

    def __is_hybrid(self):
        # More than one unique core type is treated as a hybrid system.
        core_types = self.__system_info.unique_core_types
        return len(core_types) > 1

    def __set_chart_metric_file_paths(self):
        # Replace the CLI path arguments with per-core-type mappings,
        # skipping core types whose resolved path is falsy.
        args = self.__cli_args
        resolved = self.__configuration_file_paths
        if args.chart_format_file_path:
            chart_paths = {}
            for core_type in resolved:
                chart_path = resolved[core_type][ConfigurationPaths.CHART_PATH]
                if chart_path:
                    chart_paths[core_type] = chart_path
            args.chart_format_file_path = chart_paths
        metric_paths = {}
        for core_type in resolved:
            metric_path = resolved[core_type][ConfigurationPaths.METRIC_PATH]
            if metric_path:
                metric_paths[core_type] = metric_path
        args.metric_file_path = metric_paths

    def __get_output_writers(self):
        use_polars = self.is_parallel and not self.__cli_args.disable_polars
        output_directory = Path(self.__output_handler.output_directory)
        return [CSVWriterFactory.create(use_polars, output_directory)]

    def __get_partition_info(self):
        args = self.__cli_args
        self.__partitions = self.parser.partition(chunk_size=args.chunk_size, **get_sample_range_args(args))
        first_partition = self.__partitions[0]
        self.__first_sample = first_partition.first_sample
        last_partition = self.__partitions[-1]
        self.__last_sample = last_partition.last_sample

    def __initialize_mpp_api(self):
        api = MppApi(self.get_api_args())
        self.__mpp_api = api
        api.initialize()

    def _conditions_met_for_parallel_processing(self):
        """Return a truthy value when partitions should be processed in parallel.

        Parallel processing is used when it is forced, or when the input file
        is large enough, there is more than one partition and the number of
        parallel cores is not limited to 1.
        """
        args = self.__cli_args
        if args.force_parallel:
            return args.force_parallel
        return (self.__file_is_large_enough_for_parallel_processing()
                and self.num_partitions > 1
                and args.parallel_cores != 1)

    def __file_is_large_enough_for_parallel_processing(self):
        # Threshold: more than 6 units of 1,000,000 bytes.
        size_in_megabytes = os.path.getsize(self.__input_data_file_path) / 1000000
        return size_in_megabytes > 6


class MppApplication:
    """Create one ``MppRunner`` per input data file and run them.

    The input is either a single file or a directory that is searched with a
    glob pattern. For multi-file runs each runner gets its own output file
    specifier derived from the file's path relative to the input directory.
    """

    def __init__(self, cli_args):
        """Collect the input files and create a runner for each of them.

        :param cli_args: parsed command line arguments. This object is modified
                         while the runners are being set up and run.
        :raises NoFilesFoundError: when no input file could be found
        """
        self.__cli_args = cli_args
        self.__mpp_runners: List[MppRunner] = []
        self.__output_file_specifier = None
        self.__output_file_specifiers = []
        self.__set_mpp_runners()

    @property
    def mpp_runners(self):
        """List of runners, one per input data file."""
        return self.__mpp_runners

    @property
    def output_file_specifiers(self):
        """Per-file output specifiers (populated for multi-file runs only)."""
        return self.__output_file_specifiers

    def run(self):
        """Run all runners and, if requested, compare the generated outputs."""
        if len(self.__mpp_runners) == 1:
            self.__run_single_mpp_runner()
        else:
            self.__run_multiple_mpp_runners()
        if self.__cli_args.compare:
            self.__compare_files()

    def __compare_files(self):
        if self.__cli_args.baseline:
            # The baseline is passed on without its file extension.
            self.__cli_args.baseline = os.path.splitext(self.__cli_args.baseline)[0]
        mpp_comparer = MppComparer(self.__output_file_specifiers, self.__cli_args.delta, self.__cli_args.baseline,
                                   self.__cli_args.compare)
        mpp_comparer.compare_files()

    @staticmethod
    def generate_mpp_views(mpp_runner, progress=None):
        """Run a single runner, reporting instead of raising on failure.

        :param mpp_runner: the runner to execute
        :param progress: optional progress object forwarded to the runner
        :return: 1 when the runner completed, 0 when it raised an exception
        """
        try:
            mpp_runner.generate_mpp_views(progress)
            return 1
        except Exception as e:
            # Top-level boundary for one file: report and let other files continue.
            print(f'\nError processing file "{mpp_runner.input_data_file_path}": {e}')
            return 0

    def __adjust_verbosity_for_multi_file_runs(self, mpp_input_data_files):
        if self.__is_multi_file_run(mpp_input_data_files) and self.__is_default_logging_level():
            self.__cli_args.verbose = VerboseLevel.BATCH_INFO

    def __is_default_logging_level(self):
        return self.__cli_args.verbose == VerboseLevel.INFO

    @staticmethod
    def __is_multi_file_run(mpp_input_data_files):
        return len(mpp_input_data_files) > 1

    def __run_single_mpp_runner(self):
        self.__mpp_runners[0].generate_mpp_views()

    def __run_multiple_mpp_runners(self):
        if not self.__cli_args.parallel_cores:
            self.__cli_args.parallel_cores = mp.cpu_count()
        if self.__cli_args.parallel_cores_per_runner != 1:
            self.__run_runners_in_parallel()
        else:
            self.__run_runners_sequentially()

    def __run_runners_in_parallel(self):
        # Process the files in a process pool; failures are counted, not raised.
        if not self.__cli_args.parallel_cores_per_runner:
            self.__cli_args.parallel_cores_per_runner = mp.cpu_count()
        parallel_cores = min(self.__cli_args.parallel_cores, self.__cli_args.parallel_cores_per_runner)
        print(f'Processing {len(self.__mpp_runners)} files in parallel with a maximum of {parallel_cores} '
              f"core{'s' if parallel_cores > 1 else ''}...", end='', flush=True)
        with ProcessPoolExecutor(max_workers=parallel_cores) as executor:
            results = list(executor.map(self.generate_mpp_views, self.__mpp_runners))
        print(f'\nSuccessfully processed {sum(results)} out of {len(self.__mpp_runners)} files.')
        if self.__at_least_one_file_is_successfully_processed(results):
            print(f'Output written to: {self.__cli_args.output_file_specifier}')

    def __run_runners_sequentially(self):
        # Process the files one after another behind a progress bar.
        # The progress bar is a local wrapper, so the runner list itself
        # stays a plain list.
        # NOTE(review): unlike the parallel path, an exception raised for one
        # file propagates here and stops the remaining files - confirm intended.
        progress_bar = tqdm(self.__mpp_runners)
        for idx, mpp_runner in enumerate(progress_bar):
            progress_bar.set_description(f'Processing file {idx + 1} out of {len(self.__mpp_runners)}: '
                                         f'\'{mpp_runner.input_data_file_path}\'')
            mpp_runner.generate_mpp_views()

    @staticmethod
    def __at_least_one_file_is_successfully_processed(results):
        # Each result is 1 (success) or 0 (failure).
        return sum(results) > 0

    def __set_mpp_runners(self):
        self.__output_file_specifier = self.__cli_args.output_file_specifier
        self.__input_data_file_path = self.__cli_args.input_data_file_path
        mpp_input_data_files = self.__get_input_data_files()
        file_count = len(mpp_input_data_files)
        print(f'Found {file_count} file{"s" if file_count != 1 else ""} to process.')
        self.__check_preconditions(mpp_input_data_files, self.__input_data_file_path)
        self.__adjust_verbosity_for_multi_file_runs(mpp_input_data_files)
        for input_data_file in mpp_input_data_files:
            self.__initialize_mpp_runner(input_data_file, mpp_input_data_files)
            # Setting up a runner may rewrite the output specifier; restore the
            # user's original value before the next file is handled.
            self.__cli_args.output_file_specifier = self.__output_file_specifier

    @staticmethod
    def __check_preconditions(mpp_input_data_files, input_data_file_path):
        if mpp_input_data_files:
            return
        # Report the location that was searched. A single-file input has no
        # directory entry, so fall back to the file path instead of failing
        # with a KeyError that hides the real problem.
        searched_location = input_data_file_path
        for file_id in (DefaultFileIds.DIRECTORY, DefaultFileIds.INPUT_DATA_FILE_PATH):
            if file_id in input_data_file_path:
                searched_location = input_data_file_path[file_id]
                break
        raise NoFilesFoundError(searched_location)

    def __initialize_mpp_runner(self, input_data_file, mpp_input_data_files):
        # The runner deep copies the arguments, so the per-file values set
        # here are captured at construction time.
        self.__cli_args.input_data_file_path = input_data_file
        self.__set_output_file_specifier(input_data_file, mpp_input_data_files)
        mpp_runner = MppRunner(self.__cli_args)
        self.__mpp_runners.append(mpp_runner)

    def __set_output_file_specifier(self, input_data_file, mpp_input_data_files):
        if not self.__is_multi_file_run(mpp_input_data_files):
            return
        relative_path = os.path.relpath(Path(input_data_file),
                                        self.__input_data_file_path[DefaultFileIds.DIRECTORY])
        # Flatten sub-directories into the file name so every output is unique.
        relative_path = relative_path.replace(os.path.sep, '_')
        prefix = self.__handle_prefixes(relative_path)
        self.__cli_args.output_file_specifier = Path(self.__cli_args.output_file_specifier) / prefix
        self.__output_file_specifiers.append(self.__cli_args.output_file_specifier)

    def __handle_prefixes(self, relative_path):
        # Note: __get_user_prefix may strip the prefix from the output specifier.
        user_prefix = self.__get_user_prefix()
        default_prefix = os.path.splitext(relative_path)[0]
        return f'{user_prefix}_{default_prefix}' if user_prefix else default_prefix

    def __get_user_prefix(self):
        user_prefix = ''
        if self.__user_prefix_is_specified():
            specifier = Path(self.__cli_args.output_file_specifier)
            user_prefix = specifier.name
            # Keep only the directory part as the output specifier.
            self.__cli_args.output_file_specifier = specifier.parent
        return user_prefix

    def __user_prefix_is_specified(self):
        # A specifier that does not exist but whose parent does is treated as
        # "<existing directory>/<prefix>".
        specifier = self.__cli_args.output_file_specifier
        return not os.path.exists(specifier) and os.path.exists(Path(specifier).parent)

    def __get_input_data_files(self):
        input_paths = self.__cli_args.input_data_file_path
        candidates = []
        if DefaultFileIds.DIRECTORY in input_paths:
            initial_directory = input_paths[DefaultFileIds.DIRECTORY]
            input_pattern = self.__adjust_input_pattern_for_recursive_arg()
            candidates = glob.glob(os.path.join(initial_directory, input_pattern),
                                   recursive=self.__cli_args.recursive)
        elif DefaultFileIds.INPUT_DATA_FILE_PATH in input_paths:
            candidates = [input_paths[DefaultFileIds.INPUT_DATA_FILE_PATH]]
        # Drop directories and paths that do not exist.
        return [candidate for candidate in candidates if os.path.isfile(candidate)]

    def __adjust_input_pattern_for_recursive_arg(self):
        input_pattern = self.__cli_args.input_pattern
        if self.__cli_args.recursive:
            input_pattern = '**/' + input_pattern
        return input_pattern


class NoFilesFoundError(Exception):
    """Raised when no input files could be found for processing."""

    def __init__(self, directory):
        """Build the error message.

        :param directory: location that was searched; it is included in the
                          message together with the current working directory.
        """
        super().__init__(
            f"No input files found in directory: {directory}\n"
            f"(Current working directory is {os.getcwd()})"
        )

