# mypy: allow-untyped-defs
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)


logger = logging.getLogger(__name__)


class SimpleProfiler:
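    """
    Minimal aggregating profiler used by FSDP internals: wall-clock time is
    accumulated per ``profile_type`` key across ``profile()`` calls until
    ``reset()`` or ``dump_and_reset()`` is called.

    A minimal usage sketch (the message string is illustrative, and
    ``dump_and_reset()`` assumes the default process group is initialized):

        with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
            ...  # timed region
        SimpleProfiler.dump_and_reset("FSDP resharding profile: ")
    """
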
    class Type(str, Enum):
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    results: Dict[str, float] = defaultdict(float)
    profiling: Set[str] = set()

    @classmethod
    def reset(cls) -> None:
        cls.results.clear()
        cls.profiling.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. "
            "SimpleProfiler does not support nested or concurrent profiling "
            "of the same profile type."
        )

        cls.profiling.add(profile_type)
        begin = time.monotonic()
        try:
            yield
        finally:
            end = time.monotonic()
            cls.results[profile_type] += end - begin
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        # Do not combine this with the DETAIL distributed debug level, as the
        # extra logging makes the profiling results very inaccurate.
        if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO:
            logger.info("%s %s", msg, cls.results)
        cls.reset()


def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
    """
    Used by the composable ``fully_shard()`` code path. It returns
      1. sharded module tree info: each line represents a submodule name that contains the
    submodule's FQN and its submodule class name. If the submodule is sharded by ``fully_shard``,
    the submodule name is suffixed with ' FULLY SHARDED'. Each increased tree
    level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model
    looks like this:
        [CompositeModel] FULLY SHARDED
            l1[Linear]
            u1[UnitModule] FULLY SHARDED
                u1.l1[Linear]
                u1.seq[Sequential]
                    u1.seq.0[ReLU]
                    u1.seq.1[Linear]
                    u1.seq.2[ReLU]
                u1.l2[Linear]
            u2[UnitModule] FULLY SHARDED
                u2.l1[Linear]
                u2.seq[Sequential]
                    u2.seq.0[ReLU]
                    u2.seq.1[Linear]
                    u2.seq.2[ReLU]
                u2.l2[Linear]
            l2[Linear]
      2. a dict mapping from the concatenated module FQN and class name to a list of its managed
    original parameters' FQNs. An example of the dict for the above toy sharded model:
            {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
             'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
             'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
            }
    All FQNs are prefixed starting from ``model``.
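
    A minimal usage sketch (``model`` is any root module, optionally wrapped
    with composable ``fully_shard``; the variable names are illustrative):

        tree_str, module_name_to_fqns = (
            _get_sharded_module_tree_with_module_name_to_fqns(model)
        )
        print(tree_str)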

    Args:
        model (torch.nn.Module): Root module (which may or may not be passed to
                                 composable `fully_shard()`).
    """

    def module_fn(
        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
    ):
        num_spaces = tree_level * 4
        # Strip the trailing "." from the module prefix (if any) so the printed
        # name reads, e.g., "u1.seq[Sequential]".
        trimmed_prefix = prefix[:-1] if prefix.endswith(".") else prefix
        prefixed_module_name = trimmed_prefix + "[" + module.__class__.__name__ + "]"
        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name

        state = _get_module_fsdp_state(module)
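        # Modules without an FSDP state attached are only recorded in the tree
        # string; they contribute nothing to the module-name-to-FQNs mapping.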
        if state is None:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
            return

        # A non-None handle means this exact module is managed by `fully_shard`
        # and owns its own flat parameter.
        handle = state._fully_sharded_module_to_handle.get(module, None)

        if handle:
            sharded_tree_info[0] += (
                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
            )
            param = handle.flat_param
            assert isinstance(param, flat_param_file.FlatParameter)
            global_fqns = [
                clean_tensor_name(prefix + name) for name in param._fqns
            ]  # prefixed from the top-level `model` (i.e. including `prefix`)

            if prefixed_module_name in sharded_module_name_to_fqns:
                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
            else:
                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns
        else:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"

    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
        return sharded_tree_info[0], sharded_module_name_to_fqns

    # Use a list so that `module_fn` can mutate the accumulated tree string
    # in place while recursing over submodules
    sharded_tree_info: List[str] = [
        "",
    ]
    sharded_module_name_to_fqns: Dict[str, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        sharded_tree_info,
        sharded_module_name_to_fqns,
    )
