Source code for micom.viz.prediction

"""Visualization for phenotype prediction."""

from datetime import datetime
from micom.viz import Visualization
from micom.logger import logger
import json
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype
from sklearn.model_selection import (
    cross_val_predict,
    cross_val_score,
    LeaveOneOut,
)
from sklearn.linear_model import (
    LogisticRegressionCV,
    LassoCV,
    LogisticRegression,
    Lasso,
)
from sklearn.preprocessing import StandardScaler


[docs]def plot_fit( results, phenotype, variable_type="binary", variable_name="phenotype", filename="fit_%s.html" % datetime.now().strftime("%Y%m%d"), flux_type="production", min_coef=0.001, atol=1e-6 ): """Test for differential metabolite production. This will fit the `phenotype` response using L1-regularized linear models with log-fluxes as features. Will use LASSO regression for a continuous response and L1-regularized Logistic regression for a binary response. Parameters ---------- results : micom.workflows.GrowthResults The results returned by the `grow` workflow. phenotype : pandas.Series The data to be fitted. Its index must correspond to `sample_id` in `exchanges`. variable_type : str of ["binary", "continuous"] The type of the variable. variable_name : str A short description of the phenotype for instance "disease_status". filename : str The HTML file where the visualization will be saved. flux_type : str of ["import", "production"] Whether to fit using import or production fluxes. min_coef : float in [0.0, Inf] Only report coefficient that are at least that large. atol : float Tolerance to consider a flux different from zero. Should be roughly equivalent to the solver tolerance. Returns ------- Visualization A MICOM visualization. Can be served with `viz.view`. """ exchanges = results.exchanges anns = results.annotations anns.index = anns.metabolite if flux_type == "import": exchanges = exchanges[ (exchanges.taxon == "medium") & (exchanges.direction == "import") ] exchanges["flux"] = exchanges.flux.abs() else: exchanges = exchanges[ (exchanges.taxon != "medium") & (exchanges.direction == "export") ] exchanges = ( exchanges.groupby(["reaction", "metabolite", "sample_id"]) .apply( lambda df: pd.Series( {"flux": sum(df.abundance * df.flux.abs())} ) ) .reset_index() ) exchanges = exchanges.loc[exchanges.flux > atol] if exchanges.shape[1] < 1: raise ValueError("None of the fluxes passed the tolerance threshold :(") if variable_type == "binary" and phenotype.nunique() != 2: raise ValueError( "Binary variables must have exactly two unique values, yours " "has: %s." % ", ".join(phenotype.unique()) ) elif variable_type == "continuous" and not is_numeric_dtype(phenotype): raise ValueError( "Continuous variables must have a numeric type, but yours is" " of type `%s`." % phenotype.dtype ) elif variable_type not in ["binary", "continuous"]: raise ValueError( "Unsupported variable type. Must be either `binary` or " "`continuous`." ) fluxes = exchanges.pivot_table( index="sample_id", columns="metabolite", values="flux", fill_value=atol ) fluxes = fluxes.applymap(np.log) meta = phenotype[fluxes.index] stds = fluxes.std(axis=1) bad = stds < 1e-6 if bad.any(): logger.warning("Removing %d fluxes due to zero variance." % bad.sum()) fluxes = fluxes.loc[:, ~bad] scaled = StandardScaler().fit_transform(fluxes) if variable_type == "binary": model = LogisticRegressionCV( penalty="l1", scoring="accuracy", solver="liblinear", cv=2, Cs=np.power(10.0, np.arange(-6, 6, 0.5)), max_iter=50000, ) fit = model.fit(scaled, meta) model = LogisticRegression( penalty="l1", solver="liblinear", C=fit.C_[0], max_iter=10000, ) fit = model.fit(scaled, meta) score = cross_val_score(model, X=scaled, y=meta, cv=LeaveOneOut()) coefs = pd.DataFrame( {"coef": fit.coef_[0, :], "metabolite": fluxes.columns} ) else: model = LassoCV(cv=2, max_iter=50000) fit = model.fit(scaled, meta) model = Lasso(alpha=fit.alpha_, max_iter=50000) fit = model.fit(scaled, meta) score = cross_val_score(model, X=scaled, y=meta, cv=3) coefs = pd.DataFrame({"coef": fit.coef_, "metabolite": fluxes.columns}) coefs["description"] = anns.loc[coefs.metabolite, "name"].values score = [np.mean(score), np.std(score)] score.append(model.score(scaled, meta)) if all(coefs.coef.abs() < min_coef): raise RuntimeError( "Unfortunately no metabolite flux was predictive for the " "chosen phenotype and a cutoff of %g :(" % min_coef ) data = {"fluxes": exchanges, "coefficients": coefs} coefs = coefs[coefs.coef.abs() >= min_coef].sort_values(by="coef") predicted = cross_val_predict(model, scaled, meta, cv=LeaveOneOut()) fitted = pd.DataFrame( {"real": meta, "predicted": predicted}, index=meta.index ) exchanges = exchanges.loc[ exchanges.metabolite.isin(coefs.metabolite.values) ].copy() exchanges["meta"] = meta[exchanges.sample_id].values exchanges["description"] = anns.loc[exchanges.metabolite, "name"].values var_type = "nominal" if variable_type == "binary" else "quantitative" viz = Visualization(filename, data, "tests.html") viz.save( fitted=fitted.to_json(orient="records"), coefs=coefs.to_json(orient="records"), exchanges=exchanges.to_json(orient="records"), metabolites=json.dumps(coefs.metabolite.tolist()), variable=variable_name, type=var_type, score=score, width=400, height=300, cheight=max(2 * coefs.shape[0], 40), cwidth=max(8 * coefs.shape[0], 160), ) return viz