"""Gradio demo for UCE (Universal Cell Embeddings) zero-shot cell-type classification.

Runs the UCE embedding script on an uploaded (or bundled) ``.h5ad`` file,
predicts cell types from the embeddings with a pre-trained logistic-regression
classifier, and shows a UMAP of the embeddings.
"""

import json
import os
import subprocess
import sys
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import umap
from huggingface_hub import hf_hub_download
from sklearn.linear_model import LogisticRegression

# Directory the UCE repository is cloned into; the UCE script is run from here
# (``--dir``) and the embedding/prediction files are written here.
BASE_PATH = '/home/user/app/UCE/'
UCE_REPO_URL = 'https://github.com/minwoosun/UCE.git'
UCE_MODEL_LOC = 'minwoosun/uce-100m'
HF_MISC_REPO = 'minwoosun/uce-misc'
CLASSIFIER_WEIGHTS_FILE = 'tabula_sapiens_v1_logistic_regression_model_weights.json'

# Dropdown label -> file name in HF_MISC_REPO for the bundled example datasets.
DEFAULT_DATASETS = {
    "PBMC 100 cells": "100_pbmcs_proc_subset.h5ad",
    "PBMC 1000 cells": "1k_pbmcs_proc_subset.h5ad",
}

DESCRIPTION_MD = '''
# UCE 100M Demo 🦠

### Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!

GitHub Paper Open In Colab

Upload a `.h5ad` single cell gene expression file and select the species (Human/Mouse). The demo will generate UMAP projections of the embeddings and allow you to download the embeddings for further analysis.

1. Upload your `.h5ad` file or select one of the default datasets (subset of 10x pbmc data)
2. Select the species
3. Click "Run" to view the UMAP scatter plot
4. Download the UCE embeddings and predicted cell-types

Please consider citing the following paper if you use this tool in your research:

Rosen, Y., Roohani, Y., Agarwal, A., Samotorčan, L., Tabula Sapiens Consortium, Quake, S. R., & Leskovec, J. Universal Cell Embeddings: A Foundation Model for Cell Biology. bioRxiv. https://doi.org/10.1101/2023.11.28.568918
'''
# NOTE(review): "GitHub Paper Open In Colab" looks like link labels whose URLs
# were lost; restore the links.


def load_and_predict_with_classifier(x, model_path, output_path, save):
    """Predict labels for ``x`` with a logistic-regression model stored as JSON.

    Args:
        x: Feature matrix passed to ``LogisticRegression.predict`` (here the
            UCE embeddings, one row per cell).
        model_path: Path to a JSON file with ``"coef"``, ``"intercept"`` and
            ``"classes"`` entries.
        output_path: CSV path the predictions are written to when ``save`` is true.
        save: If true, write one prediction per line (no header, no index).

    Returns:
        Array of predicted class labels, one per row of ``x``.
    """
    with open(model_path, 'r', encoding='utf-8') as f:
        model_params = json.load(f)

    # Rebuild a "fitted" estimator by assigning the learned attributes directly.
    # predict() only uses coef_, intercept_ and classes_; the deprecated
    # ``multi_class`` argument is omitted because it does not affect predict().
    model_loaded = LogisticRegression(solver='lbfgs', max_iter=1000)
    model_loaded.coef_ = np.array(model_params["coef"])
    model_loaded.intercept_ = np.array(model_params["intercept"])
    model_loaded.classes_ = np.array(model_params["classes"])

    y_pred = model_loaded.predict(x)

    if save:
        pd.DataFrame(y_pred, columns=["predicted_cell_type"]).to_csv(
            output_path, index=False, header=False
        )
    return y_pred


def plot_umap(adata):
    """Return a UMAP scatter plot of ``adata.obsm["X_uce"]`` as an image array.

    Points are coloured by ``adata.obs["cell_type"]``. The image is the float
    array produced by ``plt.imread`` on the rendered PNG.
    """
    labels = pd.Categorical(adata.obs["cell_type"])
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    embedding = reducer.fit_transform(adata.obsm["X_uce"])

    # One explicit normalization shared by the points and the legend, so the
    # legend swatches are exactly the colours used for the points.
    # NOTE(review): Set1 has only 9 colours; more categories will share colours.
    cmap = plt.cm.Set1
    norm = plt.Normalize(vmin=0, vmax=max(len(labels.categories) - 1, 1))

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(
        embedding[:, 0], embedding[:, 1],
        c=labels.codes, cmap=cmap, norm=norm, s=50, alpha=0.6,
    )
    handles = [
        plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
                   markerfacecolor=cmap(norm(i)), markersize=10)
        for i, cell_type in enumerate(labels.categories)
    ]
    ax.legend(handles=handles, title='Cell Type')
    ax.set_title('UMAP projection of the data')
    ax.set_xlabel('UMAP1')
    ax.set_ylabel('UMAP2')

    buf = BytesIO()
    try:
        fig.savefig(buf, format='png')
    finally:
        # Close the figure so repeated requests do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return plt.imread(buf, format='png')


def toggle_file_input(default_dataset):
    """Enable the upload widget unless a bundled default dataset is selected."""
    # A cleared dropdown (None) or "None" means the upload will be used, so
    # the widget must stay interactive in both cases.
    return gr.update(interactive=default_dataset not in DEFAULT_DATASETS)


def clone_repo():
    """Clone the UCE repository into BASE_PATH unless it is already present.

    Best effort: a failed clone is reported but does not raise.
    """
    if os.path.isdir(os.path.join(BASE_PATH, '.git')):
        return
    result = subprocess.run(['git', 'clone', UCE_REPO_URL, BASE_PATH], check=False)
    if result.returncode != 0:
        print(f"WARNING: git clone of {UCE_REPO_URL} failed (exit code {result.returncode})")


def main(input_file_path, species, default_dataset,
         default_dataset_1_path=None, default_dataset_2_path=None):
    """Embed the selected data with UCE, predict cell types and plot a UMAP.

    Args:
        input_file_path: Uploaded ``.h5ad`` file from ``gr.File`` (a path
            string, or an object with a ``.name`` path depending on the Gradio
            version), or None. Ignored when a default dataset is selected.
        species: Passed to the UCE script as ``--species`` when set.
        default_dataset: Dropdown choice; a key of DEFAULT_DATASETS selects a
            bundled dataset, anything else uses the upload.
        default_dataset_1_path: Optional local path to use instead of
            downloading the "PBMC 100 cells" dataset.
        default_dataset_2_path: Optional local path to use instead of
            downloading the "PBMC 1000 cells" dataset.

    Returns:
        Tuple of (UMAP image array, embeddings ``.h5ad`` path, predictions CSV path).

    Raises:
        gr.Error: If no input is provided or the UCE script fails.
    """
    os.chdir(BASE_PATH)

    input_file_path = _resolve_input_path(
        input_file_path, default_dataset,
        {"PBMC 100 cells": default_dataset_1_path, "PBMC 1000 cells": default_dataset_2_path},
    )

    file_name = os.path.splitext(os.path.basename(input_file_path))[0]
    # File name the UCE script writes its embeddings to inside --dir.
    output_file = os.path.join(BASE_PATH, f"{file_name}_uce_adata.h5ad")
    pred_file = os.path.join(BASE_PATH, f"{file_name}_predictions.csv")

    _run_uce(input_file_path, species)

    # Cell-type classification on the UCE embeddings.
    model_path = hf_hub_download(repo_id=HF_MISC_REPO, filename=CLASSIFIER_WEIGHTS_FILE)
    adata = sc.read_h5ad(output_file)
    y_pred = load_and_predict_with_classifier(adata.obsm['X_uce'], model_path, pred_file, save=True)

    # Uploaded data may lack annotations; colour the UMAP by the predictions then.
    if "cell_type" not in adata.obs.columns:
        adata.obs["cell_type"] = pd.Categorical(y_pred)

    img = plot_umap(adata)
    return img, output_file, pred_file


def _resolve_input_path(input_file_path, default_dataset, overrides):
    """Return the ``.h5ad`` path to process: a bundled dataset or the upload."""
    if default_dataset in DEFAULT_DATASETS:
        return overrides.get(default_dataset) or hf_hub_download(
            repo_id=HF_MISC_REPO, filename=DEFAULT_DATASETS[default_dataset]
        )
    # Older Gradio versions hand over a tempfile wrapper instead of a path.
    path = getattr(input_file_path, "name", input_file_path)
    if not path:
        raise gr.Error("Please upload a .h5ad file or select a default dataset.")
    return path


def _run_uce(input_file_path, species):
    """Run UCE's eval_single_anndata.py on ``input_file_path``; raise gr.Error on failure."""
    command = [
        sys.executable, os.path.join(BASE_PATH, 'eval_single_anndata.py'),
        '--adata_path', input_file_path,
        '--dir', BASE_PATH,
        '--model_loc', UCE_MODEL_LOC,
    ]
    if species:
        command += ['--species', species]

    print("Running command:", command)
    print("---> RUNNING UCE")
    result = subprocess.run(command, capture_output=True, text=True)
    # Print output before checking the exit code so failures remain diagnosable.
    print(result.stdout)
    print(result.stderr)
    if result.returncode != 0:
        raise gr.Error(f"UCE failed (exit code {result.returncode}); see the server log for details.")
    print("---> FINISH UCE")


if __name__ == "__main__":
    clone_repo()

    with gr.Blocks() as demo:
        gr.Markdown(DESCRIPTION_MD)

        # Define Gradio inputs and outputs
        file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")

        with gr.Row():
            species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
            default_dataset_input = gr.Dropdown(choices=["None", *DEFAULT_DATASETS], label="Select default dataset")

        # Disable the upload widget while a default dataset is selected.
        default_dataset_input.change(
            toggle_file_input,
            inputs=[default_dataset_input],
            outputs=[file_input],
        )

        run_button = gr.Button("Run", elem_classes="run-button")

        # Arrange UMAP plot and file outputs side by side.
        with gr.Row():
            image_output = gr.Image(type="numpy", label="UMAP_of_UCE_Embeddings")
            file_output = gr.File(label="Download embeddings")
            pred_output = gr.File(label="Download predictions")

        # The three inputs map to main's first three parameters; the remaining
        # parameters of main have defaults.
        run_button.click(
            fn=main,
            inputs=[file_input, species_input, default_dataset_input],
            outputs=[image_output, file_output, pred_output],
        )

    demo.launch()