"""MyGene & BioMart REST APIs & gene standardization."""

import io
from typing import Iterable

import pandas as pd
from biothings_client import get_client

from .._normalize import NormalizeColumns


def format_into_dataframe(f):
    """Decorate `f` so that its data argument is always a `pd.DataFrame`.

    The data argument is the first positional argument, or the second one
    when the first is `self` (as decided by `_is_function`).

    - If the data argument already is a DataFrame it is passed through
      unchanged and `f` is called with `_reformat=False`.
    - Otherwise the argument is iterated and its items become the index of a
      new, column-less DataFrame, and `f` is called with `_reformat=True`.

    Args:
        f: The callable to wrap. It must accept a `_reformat` keyword.

    Returns:
        The wrapping function, carrying the metadata of `f`.
    """
    # Imported locally: `wraps` is not among the module-level imports.
    from functools import wraps

    @wraps(f)
    def dataframe(*args, **kwargs) -> pd.DataFrame:
        # Position of the data argument: 0 for plain functions, 1 when the
        # first positional argument is `self`.
        # NOTE(review): `_is_function` is not defined in this part of the
        # file; its exact semantics should be confirmed where it is declared.
        idx = 0 if _is_function(args[0]) else 1

        if isinstance(args[idx], pd.DataFrame):
            df = args[idx]
            reformat = False
        else:
            # Any other iterable becomes the index of an empty DataFrame.
            df = pd.DataFrame(index=list(args[idx]))
            reformat = True

        args_new = list(args)
        args_new[idx] = df
        return f(*args_new, _reformat=reformat, **kwargs)

    return dataframe


class Mygene:
    """Wrapper of MyGene.info.

    See: https://docs.mygene.info/en/latest/index.html

    NOTE(review): `_cleanup_mygene_returns`, `standardize` and
    `_standardize_symbol` read `self.df` (a reference table indexed by gene
    symbol with `hgnc_id` / `hgnc_symbol` columns, judging by how it is used)
    and `self._species`. Neither attribute is set in `__init__`; they must be
    provided by a subclass or assigned by the caller — confirm.
    """

    def __init__(self) -> None:
        self._server = get_client("gene")

    @property
    def server(self):
        """MyGene.info."""
        return self._server

    def query(
        self,
        genes: Iterable[str],
        scopes="symbol",
        fields="HGNC,symbol",
        species="human",
        as_dataframe=True,
        verbose=False,
        **kwargs,
    ):
        """Get HGNC IDs from Mygene.

        Args:
            genes: Input list
            scopes: ID types of the input
            fields: ID type of the output
            species: species
            as_dataframe: Whether to return a data frame
            verbose: Whether to print logging
            **kwargs: see **kwargs of `biothings_client.MyGeneInfo().querymany()`

        Returns:
            If the client returns a dataframe: that dataframe, with string
            values of the 'HGNC' column prefixed as 'HGNC:<id>' and then
            passed through `NormalizeColumns.gene` (which is expected to
            rename 'HGNC' to 'hgnc_id').
            Otherwise the client result is returned untouched.
        """
        # query via mygene
        res = self.server.querymany(
            qterms=genes,
            scopes=scopes,
            fields=fields,
            species=species,
            as_dataframe=as_dataframe,
            verbose=verbose,
            **kwargs,
        )

        # The column post-processing below only applies to dataframes;
        # e.g. with `as_dataframe=False` the result has no `.columns`.
        if not isinstance(res, pd.DataFrame):
            return res

        # format HGNC IDs to match `hgnc_id` format ('HGNC:int')
        if "HGNC" in res.columns:
            res["HGNC"] = [
                f"HGNC:{i}" if isinstance(i, str) else i for i in res["HGNC"]
            ]
        NormalizeColumns.gene(res)

        return res

    def _cleanup_mygene_returns(self, res: pd.DataFrame, unique_col="hgnc_id"):
        """Clean up duplicates and NAs from the mygene returns.

        Args:
            res: Returned dataframe from `.query`, indexed by the query term
            unique_col: Unique identifier column

        Returns:
            a dict mapping each query term to a single symbol taken from the
            `hgnc_symbol` column of `self.df`
        """
        mapped_dict = {}

        # drop rows without a mapped unique ID (HGNC)
        df = res.dropna(subset=[unique_col])

        # hgnc_id -> hgnc_symbol lookup table, built once for both passes
        ref = self.df.reset_index().set_index("hgnc_id")

        # True for every row whose query term occurs more than once
        is_multi = df.index.duplicated(keep=False)

        # for unique results, use returned HGNC IDs to get symbols from .hgnc
        udf = df[~is_multi]
        symbol_by_id = ref.loc[
            ref.index.isin(udf[unique_col]), "hgnc_symbol"
        ].to_dict()
        mapped_dict.update(udf[unique_col].map(symbol_by_id).to_dict())

        # TODO: if the same HGNC ID is mapped to multiple inputs?

        # if a query is mapped to multiple HGNC IDs, do the reverse mapping from .hgnc
        # keep the shortest symbol as readthrough transcripts or pseudogenes are longer
        dups = df[is_multi]
        for dup in dups.index.unique():
            hids = dups.loc[dups.index == dup, unique_col].tolist()
            candidates = [
                s
                for s in ref.loc[ref.index.isin(hids), "hgnc_symbol"]
                if isinstance(s, str)
            ]
            # No candidate symbol in the reference table: leave unmapped.
            if candidates:
                # Shortest symbol wins; ties are broken alphabetically so the
                # result is deterministic.
                mapped_dict[dup] = min(candidates, key=lambda s: (len(s), s))

        return mapped_dict

    @format_into_dataframe
    def standardize(
        self,
        data,
        # Kept as a string so the annotation is not evaluated at definition
        # time; `Optional` / `_IDs` are not imported at the top of this file.
        id_type: "Optional[_IDs]" = None,
        new_index: bool = True,
        _reformat: bool = False,
    ):
        """Index a dataframe with the official gene symbols from HGNC.

        Args:
            data: A list of gene symbols to be standardized
                If dataframe, will take the index
            id_type: Default is to consider input as gene symbols and alias
                Any other value is not supported yet
            new_index:
                If True, set the standardized symbols as the index
                    - unmapped will remain the original index
                    - original index stored in the `index_orig` column

                If False, only the `std_id` column is written
            _reformat: Set by the `format_into_dataframe` decorator; True when
                `data` was not passed as a dataframe

        Returns:
            The dataframe when the input was not a dataframe (`_reformat`),
            otherwise None: the passed dataframe is modified in place.
            Adds a `std_id` column
            With `new_index`, replaces the mappable index with the
            standardized symbols and stores the original index in the
            `index_orig` column

        Raises:
            NotImplementedError: if `id_type` is not None
        """
        if id_type is not None:
            raise NotImplementedError(
                "Only `id_type=None` (gene symbols and aliases) is supported."
            )

        mapped_dict = self._standardize_symbol(df=data)
        data["std_id"] = data.index.map(mapped_dict)
        if new_index:
            data["index_orig"] = data.index
            data.index = data["std_id"].fillna(data["index_orig"])
            data.index.name = None

        if _reformat:
            return data
        return None

    def _standardize_symbol(
        self,
        df: pd.DataFrame,
    ):
        """Standardize gene symbols/aliases to symbol from `.reference` table.

        Args:
            df: A dataframe with index being the column to be standardized

        Returns:
            a dict with the standardized symbols
        """
        # 1. Symbols already present in the index of the reference table are
        #    official symbols and map to themselves
        known = self.df.index[self.df.index.isin(df.index)]
        mapped_dict = {symbol: symbol for symbol in known}

        # 2. For not mapped symbols, map through alias
        notmapped = df.index[~df.index.isin(list(mapped_dict))]
        if len(notmapped) > 0:
            mg = Mygene()
            res = mg.query(notmapped, scopes="symbol,alias", species=self._species)
            mapped_dict.update(self._cleanup_mygene_returns(res))

        return mapped_dict


class Biomart:
    """Wrapper of Biomart python APIs, good for accessing Ensembl data.

    See: https://github.com/sebriois/biomart
    """

    def __init__(self, host: str = "http://uswest.ensembl.org/biomart") -> None:
        """Connect to a BioMart server.

        Args:
            host: URL of the BioMart service to connect to

        Raises:
            ModuleNotFoundError: if the `biomart` package is not installed
        """
        # Only the import is guarded, so that unrelated errors raised while
        # creating the server are not reported as a missing package.
        try:
            import biomart
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError("Run `pip install biomart`") from err

        self._server = biomart.BiomartServer(host)
        self._dataset = None

    @property
    def server(self):
        """biomart.BiomartServer."""
        return self._server

    @property
    def databases(self):
        """Listing all databases."""
        return self._server.databases

    @property
    def datasets(self):
        """Listing all datasets."""
        return self._server.datasets

    @property
    def dataset(self):
        """A biomart.BiomartDataset.

        None until `get_gene_ensembl` has selected one.
        """
        return self._dataset

    def get_gene_ensembl(
        self,
        species="human",
        attributes=None,
        filters=None,
        **kwargs,
    ):
        """Fetch the reference table of gene ensembl from biomart.

        Also stores the selected dataset, available as `.dataset` afterwards.

        Args:
            species: common name of species
            attributes: gene attributes from gene_ensembl datasets
                Defaults to the `fields` of `bionty.gene.Gene(species)`
            filters: see biomart.search(); defaults to no filters
            **kwargs: see biomart.search()

        Returns:
            a dataframe with one column per requested attribute
        """
        # database name
        from bionty.gene import Gene

        # A fresh dict per call instead of a shared mutable default.
        filters = {} if filters is None else filters

        gn = Gene(species=species)
        sname = gn.species.search("short_name")
        self._dataset = self.datasets[f"{sname}_gene_ensembl"]

        # default is to get all the attributes
        attributes = gn.fields if attributes is None else attributes

        # Get the mapping between the attributes
        response = self.dataset.search(
            {"filters": filters, "attributes": attributes},
            **kwargs,
        )
        data = response.raw.data.decode("utf-8")

        # The response is tab-separated text without a header row; the
        # columns come back in the order the attributes were requested.
        df = pd.read_csv(io.StringIO(data), sep="\t", header=None)
        df.columns = attributes

        return df