# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

import pandas as pd

from nsys_recipe import log
from nsys_recipe.data_service import DataService
from nsys_recipe.lib import heatmap, helpers, recipe
from nsys_recipe.lib.args import ArgumentParser, Option
from nsys_recipe.lib.collective_loader import ProfileInfo
from nsys_recipe.log import logger

DEFAULT_DOMAIN = 0  # Storage metrics use default domain ID.
THROUGHPUT_SCHEMA_NAME = "Throughput Counter Group Schema"
TRAFFIC_QUANTITY_SCHEMA_NAME = "Traffic Quantity Counter Group Schema"
LATENCY_SCHEMA_NAME = "Latency Counter Group Schema"


class StorageUtilMap(recipe.Recipe):
    """Recipe that bins storage throughput and latency NVTX counters over time.

    For every report, the counter groups in NVTX domain ``DEFAULT_DOMAIN`` whose
    payload schema is a throughput / traffic-quantity schema or a latency
    schema are selected. Their samples are averaged per heatmap time bin and
    grouped by hostname, volume and counter name. The reducer writes
    ``files.parquet``, ``throughput_analysis.parquet`` and
    ``latency_analysis.parquet``, and the recipe then emits the
    ``heatmap.ipynb`` notebook.
    """

    @staticmethod
    def _get_scopes_to_volume(scopes_df: pd.DataFrame) -> pd.DataFrame:
        """Map each default-domain ``Mounts/...`` scope to its volume name.

        Args:
            scopes_df: ``NVTX_SCOPES`` rows with ``domainId``, ``scopeId`` and
                ``path`` columns.

        Returns:
            A DataFrame with columns ``scopeId`` and ``volume``, where ``volume``
            is the path component directly after ``Mounts/``.
        """
        # Combine both conditions into one mask on the same frame. Chained
        # indexing with a mask built from the unfiltered frame triggers pandas
        # reindexing warnings. `na=False` keeps rows with a missing path out of
        # the mask instead of making the mask non-boolean.
        is_mount = (scopes_df["domainId"] == DEFAULT_DOMAIN) & scopes_df[
            "path"
        ].str.startswith("Mounts/", na=False)
        mounts_df = scopes_df.loc[is_mount, ["scopeId", "path"]]
        return mounts_df.assign(
            volume=mounts_df["path"].apply(lambda path: path.split("/")[1])
        ).drop(columns="path")

    @staticmethod
    def _select_counter_groups(
        counter_groups_df: pd.DataFrame,
        payload_schemas_df: pd.DataFrame,
        schema_names: list[str],
        scopes_to_volume_df: pd.DataFrame,
    ) -> pd.DataFrame:
        """Select default-domain counter groups whose schema is in `schema_names`.

        Args:
            counter_groups_df: ``NVTX_COUNTER_GROUPS`` rows.
            payload_schemas_df: ``NVTX_PAYLOAD_SCHEMAS`` rows.
            schema_names: Payload schema names to keep.
            scopes_to_volume_df: Output of `_get_scopes_to_volume`.

        Returns:
            A DataFrame with columns ``counterGroupId``, ``name`` and ``volume``.
            ``volume`` is NaN for counter groups whose scope is not a
            ``Mounts/`` scope, because the scope join is a left join.
        """
        schema_mask = (payload_schemas_df["domainId"] == DEFAULT_DOMAIN) & (
            payload_schemas_df["name"].isin(schema_names)
        )
        schemas_df = payload_schemas_df.loc[schema_mask, ["domainId", "schemaId"]]
        return (
            counter_groups_df.merge(schemas_df, on=["domainId", "schemaId"])
            .drop(columns=["domainId", "schemaId"])
            .merge(scopes_to_volume_df, on="scopeId", how="left")
            .drop(columns="scopeId")
        )

    @staticmethod
    def _mapper_func(
        report_path: str,
        max_duration: int,
        session_offset: int,
        bin_list: list[int],
        parsed_args: argparse.Namespace,
    ) -> Optional[tuple[str, Optional[pd.DataFrame], Optional[pd.DataFrame]]]:
        """Compute binned storage throughput and latency for a single report.

        Args:
            report_path: Path of the report to read.
            max_duration: Samples with a timestamp above this are discarded.
            session_offset: Passed to `filter_and_adjust_time`, and used as the
                lower timestamp bound.
            bin_list: Heatmap bin edges. All but the last edge are used as the
                ``Duration`` labels.
            parsed_args: Recipe arguments forwarded to `DataService`.

        Returns:
            ``(report stem, throughput DataFrame or None, latency DataFrame or
            None)``. Each DataFrame has columns ``Hostname``, ``Volume``,
            ``Name``, ``Duration`` plus that family's metric columns.
            Returns ``None`` when the tables cannot be read, time filtering
            fails, or the report has no storage counter groups.
        """
        service = DataService(report_path, parsed_args)

        service.queue_table(
            "NVTX_COUNTER_GROUPS",
            ["domainId", "counterGroupId", "schemaId", "name", "scopeId"],
        )
        service.queue_table("NVTX_PAYLOAD_SCHEMAS", ["domainId", "schemaId", "name"])
        service.queue_table("NVTX_SCOPES", ["domainId", "scopeId", "path"])
        service.queue_table("TARGET_INFO_SYSTEM_ENV", ["name", "value"])

        df_dict = service.read_queued_tables()
        if df_dict is None:
            return None

        counter_groups_df = df_dict["NVTX_COUNTER_GROUPS"]
        payload_schemas_df = df_dict["NVTX_PAYLOAD_SCHEMAS"]
        scopes_df = df_dict["NVTX_SCOPES"]
        hostname_df = df_dict["TARGET_INFO_SYSTEM_ENV"]

        scopes_to_volume_df = StorageUtilMap._get_scopes_to_volume(scopes_df)
        throughput_counters_df = StorageUtilMap._select_counter_groups(
            counter_groups_df,
            payload_schemas_df,
            [THROUGHPUT_SCHEMA_NAME, TRAFFIC_QUANTITY_SCHEMA_NAME],
            scopes_to_volume_df,
        )
        latency_counters_df = StorageUtilMap._select_counter_groups(
            counter_groups_df,
            payload_schemas_df,
            [LATENCY_SCHEMA_NAME],
            scopes_to_volume_df,
        )

        # Exit gracefully if neither throughput nor latency data is available.
        if throughput_counters_df.empty and latency_counters_df.empty:
            logger.info(
                f"{report_path}: Report was successfully processed, but no data was found."
            )

            return None

        # Sample columns read for each family, mapped to the dtype they are
        # cast to. The key order is the column order of the results.
        throughput_dtypes = {"Read": "uint64", "Write": "uint64"}
        latency_dtypes = {
            "Read RPC queue": "float64",
            "Read RPC RTT": "float64",
            "Read RPC exe": "float64",
            "Write RPC queue": "float64",
            "Write RPC RTT": "float64",
            "Write RPC exe": "float64",
        }
        # (counter groups, sample column dtypes, dtype of the binned means)
        families = (
            (throughput_counters_df, throughput_dtypes, "uint64"),
            (latency_counters_df, latency_dtypes, "float64"),
        )

        # Looked up once. Indexing `[0]` still raises only if a sample row
        # actually needs the hostname, as before.
        hostname_values = hostname_df.loc[
            hostname_df["name"] == "Hostname", "value"
        ].values

        family_results: list[list[pd.DataFrame]] = []
        for counters_df, metric_dtypes, mean_dtype in families:
            results: list[pd.DataFrame] = []
            family_results.append(results)

            # Nothing to queue for this family. Skipping avoids reading an
            # empty queue, which could otherwise discard the other family.
            if counters_df.empty:
                continue

            metrics = list(metric_dtypes)
            for counter_id in counters_df["counterGroupId"]:
                service.queue_table(
                    f"NVTX_COUNTER_SAMPLES_{DEFAULT_DOMAIN}_{counter_id}",
                    ["timestamp"] + metrics,
                )

            df_dict = service.read_queued_tables()
            if df_dict is None:
                return None

            for table_name, samples_df in df_dict.items():
                assert isinstance(table_name, str)
                # The counter group ID is the last "_"-separated component of
                # the table name queued above.
                counter_id = int(table_name.split("_")[-1])

                err_msg = service.filter_and_adjust_time(samples_df, session_offset)
                if err_msg is not None:
                    logger.error(f"{report_path}: {err_msg}")
                    return None

                samples_df = samples_df.astype(metric_dtypes)
                # NOTE(review): `filter_and_adjust_time` is not visible here; if
                # it shifts timestamps by `session_offset`, this lower bound may
                # drop early samples. Confirm.
                in_range = (samples_df["timestamp"] >= session_offset) & (
                    samples_df["timestamp"] <= max_duration
                )
                samples_df = samples_df[in_range]
                if samples_df.empty:
                    continue

                counter_row = counters_df["counterGroupId"] == counter_id
                name = counters_df.loc[counter_row, "name"].iloc[0]
                # Physical drives do not have a `Name`, assign "Driver".
                if pd.isna(name) or name == "":
                    name = "Driver"
                volume = counters_df.loc[counter_row, "volume"].iloc[0]

                # `assign` builds a new frame, so no columns are written into a
                # view of the filtered slice (avoids SettingWithCopyWarning).
                samples_df = samples_df.assign(
                    Duration=pd.cut(
                        samples_df["timestamp"], bin_list, labels=bin_list[:-1]
                    ),
                    Name=name,
                    Volume=volume,
                    Hostname=hostname_values[0],
                )

                # NOTE(review): groupby's default dropna=True drops rows whose
                # Volume is NaN (counter groups without a Mounts/ scope).
                # Confirm this is intended.
                binned_df = (
                    samples_df.groupby(
                        ["Hostname", "Volume", "Name", "Duration"], observed=True
                    )[metrics]
                    .mean()
                    .astype(mean_dtype)
                    .reset_index()
                )
                results.append(binned_df)

        throughput_results, latency_results = family_results
        return (
            Path(report_path).stem,
            pd.concat(throughput_results) if throughput_results else None,
            pd.concat(latency_results) if latency_results else None,
        )

    @log.time("Mapper")
    def mapper_func(
        self,
        context: recipe.Context,
        profile_info: tuple[list[str], list[int], list[int]],
    ) -> list[Optional[tuple[str, Optional[pd.DataFrame], Optional[pd.DataFrame]]]]:
        """Run `_mapper_func` over all reports using shared heatmap bins.

        Args:
            context: Recipe execution context used to map over reports.
            profile_info: ``(report paths, max durations, session offsets)``,
                aligned by index.

        Returns:
            One `_mapper_func` result per report.
        """
        report_paths, max_durations, session_offsets = profile_info
        # The bins are sized from the longest report, so all reports share
        # the same bin edges.
        bin_size = heatmap.get_bin_size(self._parsed_args.bins, max(max_durations))
        bin_list = heatmap.generate_bin_list(
            self._parsed_args.bins, bin_size, include_last=True
        )

        return context.wait(
            context.map(
                self._mapper_func,
                report_paths,
                max_durations,
                session_offsets,
                bin_list=bin_list,
                parsed_args=self._parsed_args,
            )
        )

    @staticmethod
    def _concat_with_rank(dfs: list[Optional[pd.DataFrame]]) -> pd.DataFrame:
        """Concatenate per-report DataFrames, tagging each with its ``Rank``.

        ``Rank`` is the position in `dfs`. ``None`` entries are skipped but
        still consume a rank. Returns an empty DataFrame when nothing remains.
        """
        ranked_dfs = [
            df.assign(Rank=rank) for rank, df in enumerate(dfs) if df is not None
        ]
        return pd.concat(ranked_dfs) if ranked_dfs else pd.DataFrame()

    @log.time("Reducer")
    def reducer_func(
        self,
        mapper_res: list[
            Optional[tuple[str, Optional[pd.DataFrame], Optional[pd.DataFrame]]]
        ],
    ) -> None:
        """Combine mapper results and write the recipe's parquet outputs.

        Results are sorted by file name, and a result's position in that order
        becomes its ``Rank`` in every output file. Empty outputs are still
        written when no report produced data.
        """
        # Sort by file name so ranks are deterministic.
        filtered_res = sorted(
            helpers.filter_none_or_empty(mapper_res), key=lambda res: res[0]
        )
        # Comprehensions (rather than `zip(*filtered_res)` unpacking) also
        # handle the case where no report produced a result.
        filenames: list[str] = [res[0] for res in filtered_res]
        throughput_analysis_dfs: list[Optional[pd.DataFrame]] = [
            res[1] for res in filtered_res
        ]
        latency_analysis_dfs: list[Optional[pd.DataFrame]] = [
            res[2] for res in filtered_res
        ]

        files_df = pd.DataFrame({"File": filenames}).rename_axis("Rank")
        files_df.to_parquet(self.add_output_file("files.parquet"))

        throughput_analysis_df = self._concat_with_rank(throughput_analysis_dfs)
        throughput_analysis_df.to_parquet(
            self.add_output_file("throughput_analysis.parquet")
        )

        latency_analysis_df = self._concat_with_rank(latency_analysis_dfs)
        latency_analysis_df.to_parquet(self.add_output_file("latency_analysis.parquet"))

    def save_notebook(self) -> None:
        """Emit the heatmap notebook configured with the requested bin count."""
        self.create_notebook(
            "heatmap.ipynb", replace_dict={"REPLACE_BIN": self._parsed_args.bins}
        )
        self.add_notebook_helper_file("nsys_display.py")

    def save_analysis_file(self) -> None:
        """Record the end time and output files, then write the analysis file."""
        self._analysis_dict.update(
            {
                "EndTime": str(datetime.now()),
                "Outputs": self._output_files,
            }
        )
        self.create_analysis_file()

    def run(self, context: recipe.Context) -> None:
        """Execute the recipe: map over reports, reduce, then save artifacts."""
        super().run(context)

        profile_info = ProfileInfo.get_profile_info(context, self._parsed_args)
        mapper_res = self.mapper_func(context, profile_info)
        self.reducer_func(mapper_res)

        self.save_notebook()
        self.save_analysis_file()

    @classmethod
    def get_argument_parser(cls) -> ArgumentParser:
        """Build the recipe's argument parser.

        Time filtering and NVTX filtering are mutually exclusive.
        """
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.INPUT, required=True)
        parser.add_recipe_argument(Option.START)
        parser.add_recipe_argument(Option.END)
        parser.add_recipe_argument(Option.BINS)
        parser.add_recipe_argument(Option.DISABLE_ALIGNMENT)

        filter_group = parser.recipe_group.add_mutually_exclusive_group()
        parser.add_argument_to_group(filter_group, Option.FILTER_TIME)
        parser.add_argument_to_group(filter_group, Option.FILTER_NVTX)

        return parser
