In [1]:
import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from safetensors import safe_open
from transformers import pipeline, AutoTokenizer

# Load trial spaces data: one row per trial space. The matcher below reads the
# this_space, nct_id, title, brief_summary and eligibility_criteria columns.
trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')

# One device for every model/tensor so the query embedding, the precomputed
# embeddings and the checker always agree (previously 'cuda' vs index 0).
DEVICE = 'cuda:0'

# Load embedding model. trust_remote_code=True lets the checkpoint directory run
# its own Python code -- only point this at a trusted model.
embedding_model = SentenceTransformer('reranker_round2.model', trust_remote_code=True, device=DEVICE)

# Load precomputed trial space embeddings (stored under the "space_embeddings" key).
with safe_open("trial_space_embeddings.safetensors", framework="pt", device=DEVICE) as f:
    trial_space_embeddings = f.get_tensor("space_embeddings")

# Matching maps embedding row i to trial_spaces.iloc[i]; a stale CSV or embedding
# file would otherwise silently return the wrong trials (or index out of range).
assert trial_space_embeddings.shape[0] == len(trial_spaces), (
    f"{trial_space_embeddings.shape[0]} embeddings vs {len(trial_spaces)} trial-space rows"
)

# Load checker pipeline: a RoBERTa text classifier scoring (trial space, patient) text pairs.
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
checker_pipe = pipeline('text-classification', './roberta-checker', tokenizer=tokenizer,
                        truncation=True, padding='max_length', max_length=512, device=DEVICE)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.62s/it]


In [11]:
import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from safetensors import safe_open
from transformers import pipeline, AutoTokenizer

# Relies on objects loaded by the first cell (run it first on a fresh kernel):
# trial_spaces (pandas DataFrame), embedding_model (SentenceTransformer),
# trial_space_embeddings (torch.Tensor), checker_pipe (transformers pipeline)

def match_clinical_trials(patient_summary: str, top_k: int = 10) -> pd.DataFrame:
    """Rank trial spaces against a patient summary and score the best candidates.

    The summary is embedded with the module-level ``embedding_model`` and compared
    by cosine similarity against every row of ``trial_space_embeddings``; row i of
    the embeddings is taken to describe ``trial_spaces.iloc[i]``. The ``top_k`` most
    similar trial spaces are paired with the summary and scored by ``checker_pipe``.

    Args:
        patient_summary: Free-text patient description (from the Gradio textbox).
        top_k: Maximum number of candidate trials to return. Defaults to 10 (the
            previously hard-coded value) and is clamped to the number of trials.

    Returns:
        DataFrame ordered from most to least similar, with columns nct_id,
        trial_title, trial_brief_summary, trial_eligibility_criteria,
        trial_checker_result and trial_checker_score. Empty (same columns) when
        the summary is blank or no candidates are requested.
    """
    output_columns = [
        'nct_id',
        'trial_title',
        'trial_brief_summary',
        'trial_eligibility_criteria',
        'trial_checker_result',
        'trial_checker_score',
    ]

    # A blank textbox would otherwise still be embedded and "matched".
    if not patient_summary or not patient_summary.strip():
        return pd.DataFrame(columns=output_columns)

    # Encode as a batch of one so it broadcasts against every trial embedding row.
    patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
    patient_embedding = patient_embedding.to(trial_space_embeddings.device)
    similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)

    k = min(max(int(top_k), 0), similarities.numel())
    if k == 0:
        return pd.DataFrame(columns=output_columns)

    # topk returns the k best in descending order without sorting every trial.
    _, top_indices = torch.topk(similarities, k)
    top_rows = trial_spaces.iloc[top_indices.cpu().tolist()].reset_index(drop=True)

    analysis = pd.DataFrame({
        'nct_id': top_rows['nct_id'],
        'trial_title': top_rows['title'],
        'trial_brief_summary': top_rows['brief_summary'],
        'trial_eligibility_criteria': top_rows['eligibility_criteria'],
    })

    # Keep this text template byte-identical; presumably the checker model was
    # trained on exactly this format -- confirm before changing it.
    pt_trial_pairs = (
        top_rows['this_space'] + "\nNow here is the patient summary:" + patient_summary
    ).tolist()

    classifier_results = checker_pipe(pt_trial_pairs)
    analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
    analysis['trial_checker_score'] = [x['score'] for x in classifier_results]

    return analysis[output_columns]

# Layout tweaks passed to gr.Blocks(css=...), targeting components by elem_id:
#  - #input_box: the patient-summary textarea is fixed at 600x250 px.
#  - #output_df: the results table spans the full width and has bordered cells;
#    long cell text is clipped to a single line with an ellipsis
#    (overflow hidden + white-space nowrap) instead of wrapping.
custom_css = """
#input_box textarea {
    width: 600px !important;
    height: 250px !important;
}

#output_df table {
    width: 100% !important;
    table-layout: auto !important;
    border-collapse: collapse !important;
}

#output_df table td, #output_df table th {
    min-width: 100px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    border: 1px solid #ccc;
    padding: 4px;
}
"""

# JavaScript for enabling colResizable
# JavaScript that makes the results table's columns resizable with the jQuery
# colResizable plugin (both loaded from public CDNs).
# init() runs immediately when the document has already finished parsing: a
# DOMContentLoaded listener registered after that event never fires, which is
# what happens when this snippet is injected client-side after page load.
# init() then polls every 500 ms until jQuery, the plugin and the #output_df
# table all exist, applies colResizable once, and stops polling.
js_script = """
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/colresizable@1.6.0/colResizable-1.6.min.js"></script>
<script>
(function() {
    function init() {
        var interval = setInterval(function() {
            var table = document.querySelector('#output_df table');
            if (table && typeof jQuery !== 'undefined' && typeof jQuery(table).colResizable === 'function') {
                jQuery('#output_df table').colResizable({liveDrag:true});
                clearInterval(interval);
            }
        }, 500);
    }
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', init);
    } else {
        init();
    }
})();
</script>
"""

# Build and launch the Gradio UI.
# js_script is passed via `head=` so its <script> tags are actually executed;
# the previous gr.HTML(js_script) inserted them as component HTML, where browsers
# do not run scripts, so the column-resizing code never loaded.
# NOTE(review): `head` needs a Gradio release that supports it (4.x+) -- confirm version.
with gr.Blocks(css=custom_css, head=js_script) as demo:
    gr.HTML("<h3>Clinical Trial Matcher</h3>")
    patient_summary_input = gr.Textbox(label="Enter Patient Summary", elem_id="input_box")
    submit_btn = gr.Button("Find Matches")
    # Headers mirror the columns returned by match_clinical_trials.
    output_df = gr.DataFrame(
        headers=[
            "nct_id",
            "trial_title",
            "trial_brief_summary",
            "trial_eligibility_criteria",
            "trial_checker_result",
            "trial_checker_score"
        ],
        elem_id="output_df"
    )

    submit_btn.click(fn=match_clinical_trials,
                     inputs=patient_summary_input,
                     outputs=output_df)

demo.launch()


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




In [10]:
# Shut down the Gradio server started by demo.launch() (frees port 7860).
# NOTE(review): this cell's execution count (10) precedes the launch cell's (11);
# for a reproducible notebook, run the cells in order.
demo.close()

Closing server running on port: 7860
