# mypy: allow-untyped-defs
"""This module converts objects into numpy array."""

import numpy as np

import torch


def make_np(x):
    """
    Convert ``x`` into a numpy array.

    Args:
      x: One of
        - a ``numpy.ndarray``: returned as-is (same object, no copy);
        - a scalar as recognized by ``np.isscalar``: wrapped in a
          one-element 1-D array;
        - a ``torch.Tensor``: handed to ``_prepare_pytorch``, which detaches
          it, copies it to CPU and converts it with ``Tensor.numpy()``.

    Returns:
        numpy.ndarray: The converted array.

    Raises:
        NotImplementedError: If ``x`` is none of the supported types.
    """
    if isinstance(x, np.ndarray):
        converted = x
    elif np.isscalar(x):
        # Scalars become 1-element arrays rather than 0-d arrays.
        converted = np.array([x])
    elif isinstance(x, torch.Tensor):
        converted = _prepare_pytorch(x)
    else:
        raise NotImplementedError(
            f"Got {type(x)}, but numpy array or torch tensor are expected."
        )
    return converted


def _prepare_pytorch(x):
    if x.dtype == torch.bfloat16:
        x = x.to(torch.float16)
    x = x.detach().cpu().numpy()
    return x
