Understanding and Implementing Medprompt


We now perform choice shuffling ensembling by shuffling the order of the answer choices for each test question, creating multiple variants of the same question. The LLM is then prompted with these variants, along with the corresponding few-shot exemplars, to generate reasoning steps and an answer for each variant. Finally, we perform a majority vote over the predictions from all variants and select the final prediction.

The code related to this implementation can be found at this GitHub repo link.

We use the MedQA [6] dataset for implementing and evaluating Medprompt. We first define helper functions for parsing the JSONL files.
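
The snippets in this section assume a handful of standard imports and an initialized OpenAI client. The original setup code is not reproduced here, so the block below is a minimal sketch; reading the API key from the environment is an assumption.

import json
import os
import random
import re
from collections import Counter

from openai import OpenAI
from tqdm import tqdm

# Assumes the OPENAI_API_KEY environment variable is set; adjust to your own setup.
client = OpenAI()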

def write_jsonl_file(file_path, dict_list):
    """
    Write a list of dictionaries to a JSON Lines file.

    Args:
        file_path (str): The path to the file where the data will be written.
        dict_list (list): A list of dictionaries to write to the file.
    """
    with open(file_path, 'w') as file:
        for dictionary in dict_list:
            json_line = json.dumps(dictionary)
            file.write(json_line + '\n')

def read_jsonl_file(file_path):
    """
    Parses a JSONL (JSON Lines) file and returns a list of dictionaries.

    Args:
        file_path (str): The path to the JSONL file to be read.

    Returns:
        list of dict: A list where each element is a dictionary representing
        a JSON object from the file.
    """
    jsonl_lines = []
    with open(file_path, 'r', encoding="utf-8") as file:
        for line in file:
            json_object = json.loads(line)
            jsonl_lines.append(json_object)

    return jsonl_lines
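
A quick round-trip check of these helpers (the file name and record are purely illustrative):

records = [{"question": "Example question?", "answer": "Example answer"}]
write_jsonl_file("data/example.jsonl", records)
assert read_jsonl_file("data/example.jsonl") == records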

Implementing Self-Generated CoT

For our implementation, we utilize the training set from MedQA. We implement a zero-shot CoT prompt and process all the training questions. We use GPT-4o in our implementation. For each question, we generate the CoT and the corresponding answer. We define a prompt based on the template provided in the Medprompt paper.

system_prompt = """You might be an authority medical skilled. You might be supplied with a medical query with multiple answer selections.
Your goal is to think through the query rigorously and explain your reasoning step-by-step before choosing the ultimate answer.
Respond only with the reasoning steps and answer as specified below.
Below is the format for every query and answer:

Input:
## Query: {{query}}
{{answer_choices}}

Output:
## Answer
(model generated chain of thought explanation)
Due to this fact, the reply is [final model answer (e.g. A,B,C,D)]"""
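
The prompt builders below rely on create_query, format_answer, and build_zero_shot_prompt, which are not reproduced in this excerpt. The following is a minimal sketch of what they might look like, inferred from the MedQA record fields and the prompt template used in this section:

def create_query(item):
    """Formats a MedQA record as the question block of the prompt template."""
    options_str = "\n".join(f"{key}. {value}" for key, value in item["options"].items())
    return f"## Question: {item['question']}\n{options_str}"

def format_answer(cot, answer_idx):
    """Formats a CoT explanation and an answer letter in the expected output format."""
    return f"## Answer\n{cot}\nTherefore, the answer is {answer_idx}"

def build_zero_shot_prompt(system_prompt, question):
    """Builds a zero-shot prompt: the system instruction plus a single user question."""
    return [{"role": "system", "content": system_prompt},
            {"role": "user", "content": create_query(question)}]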

def build_few_shot_prompt(system_prompt, question, examples, include_cot=True):
    """
    Builds the few-shot prompt.

    Args:
        system_prompt (str): Task instruction for the LLM.
        question (dict): The question for which to create a query, formatted as
        required by `create_query`.
        examples (list of dict): Few-shot exemplars with their CoT reasoning and answer indices.
        include_cot (bool): Whether to include the CoT reasoning in the exemplar answers.

    Returns:
        list of dict: A list of messages, including a system message defining
        the task, the few-shot exemplars, and a user message with the input question.
    """
    messages = [{"role": "system", "content": system_prompt}]

    for elem in examples:
        messages.append({"role": "user", "content": create_query(elem)})
        if include_cot:
            messages.append({"role": "assistant", "content": format_answer(elem["cot"], elem["answer_idx"])})
        else:
            answer_string = f"""## Answer\nTherefore, the answer is {elem["answer_idx"]}"""
            messages.append({"role": "assistant", "content": answer_string})

    messages.append({"role": "user", "content": create_query(question)})
    return messages

def get_response(messages, model_name, temperature=0.0, max_tokens=10):
    """
    Obtains the responses/answers of the model through the chat completions API.

    Args:
        messages (list of dict): The built messages provided to the API.
        model_name (str): Name of the model to access through the API.
        temperature (float): A value between 0 and 1 that controls the randomness of the output.
        A temperature value of 0 ideally makes the model pick the most likely token, making the outputs deterministic.
        max_tokens (int): Maximum number of tokens that the model should generate.

    Returns:
        str: The response message content from the model.
    """
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content

We also define helper functions for parsing the reasoning and the final answer option from the LLM response.

def matches_ans_option(s):
    """
    Checks if the string starts with the specific pattern 'Therefore, the answer is [A-Z]'.

    Args:
        s (str): The string to be checked.

    Returns:
        bool: True if the string matches the pattern, False otherwise.
    """
    return bool(re.match(r'^Therefore, the answer is [A-Z]', s))

def extract_ans_option(s):
    """
    Extracts the answer option (a single capital letter) from the start of the string.

    Args:
        s (str): The string containing the answer pattern.

    Returns:
        str or None: The captured answer option if the pattern is found, otherwise None.
    """
    match = re.search(r'^Therefore, the answer is ([A-Z])', s)
    if match:
        return match.group(1)  # Returns the captured letter
    return None

def matches_answer_start(s):
    """
    Checks if the string starts with the markdown header '## Answer'.

    Args:
        s (str): The string to be checked.

    Returns:
        bool: True if the string starts with '## Answer', False otherwise.
    """
    return s.startswith("## Answer")

def validate_response(s):
    """
    Validates that a multi-line string response starts with '## Answer' and ends with the answer pattern.

    Args:
        s (str): The multi-line string response to be validated.

    Returns:
        bool: True if the response is valid, False otherwise.
    """
    file_content = s.split("\n")

    return matches_ans_option(file_content[-1]) and matches_answer_start(s)

def parse_answer(response):
    """
    Parses a response that starts with '## Answer', extracting the reasoning and the answer choice.

    Args:
        response (str): The multi-line string response containing the answer and reasoning.

    Returns:
        tuple: A tuple containing the extracted CoT reasoning and the answer choice.
    """
    split_response = response.split("\n")
    assert split_response[0] == "## Answer"
    cot_reasoning = "\n".join(split_response[1:-1]).strip()
    ans_choice = extract_ans_option(split_response[-1])
    return cot_reasoning, ans_choice
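
A quick check of these parsing helpers on a toy response (the response text is purely illustrative):

sample_response = "## Answer\nThe presentation is most consistent with option C.\nTherefore, the answer is C"
assert validate_response(sample_response)
cot, ans = parse_answer(sample_response)
print(ans)  # C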

We now process the questions in the training set of MedQA. We obtain CoT responses and answers for all questions and store them in a folder.

train_data = read_jsonl_file("data/phrases_no_exclude_train.jsonl")

cot_responses = []
os.makedirs("cot_responses", exist_ok=True)
existing_files = os.listdir("cot_responses/")

for idx, item in enumerate(tqdm(train_data)):
    # Skip questions that have already been processed
    if str(idx) + ".txt" in existing_files:
        continue

    prompt = build_zero_shot_prompt(system_prompt, item)
    try:
        response = get_response(prompt, model_name="gpt-4o", max_tokens=500)
        cot_responses.append(response)
        with open(os.path.join("cot_responses", str(idx) + ".txt"), "w", encoding="utf-8") as f:
            f.write(response)
    except Exception as e:
        print(str(e))
        cot_responses.append("")

We now iterate over all the generated responses to check whether they are valid and adhere to the prediction format defined in the prompt. We discard responses that do not conform to the required format. After that, we check the predicted answers against the ground truth for each question and only retain questions for which the predicted answers match the ground truth.

questions_dict = []
ctr = 0
for idx, question in enumerate(tqdm(train_data)):
    with open(os.path.join("cot_responses/", str(idx) + ".txt"), encoding="utf-8") as f:
        response = f.read()
    # Discard responses that do not follow the required format
    if not validate_response(response):
        continue

    cot, pred_ans = parse_answer(response)

    dict_elem = {}
    dict_elem["idx"] = idx
    dict_elem["question"] = question["question"]
    dict_elem["answer"] = question["answer"]
    dict_elem["options"] = question["options"]
    dict_elem["cot"] = cot
    dict_elem["pred_ans"] = pred_ans
    questions_dict.append(dict_elem)

# Retain only questions where the predicted answer matches the ground truth
filtered_questions_dict = []
for item in tqdm(questions_dict):
    pred_ans = item["options"][item["pred_ans"]]
    if pred_ans == item["answer"]:
        filtered_questions_dict.append(item)
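
Optionally, the filtered training pool can be persisted with the write_jsonl_file helper defined earlier (the file name here is illustrative):

write_jsonl_file("data/filtered_train_cot.jsonl", filtered_questions_dict)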

Implementing the KNN model

Having processed the training set and obtained the CoT responses for all these questions, we now embed all questions using the text-embedding-ada-002 model from OpenAI.

def get_embedding(text, model="text-embedding-ada-002"):
    return client.embeddings.create(input=[text], model=model).data[0].embedding

for item in tqdm(filtered_questions_dict):
    item["embedding"] = get_embedding(item["question"])
    inv_options_map = {v: k for k, v in item["options"].items()}
    item["answer_idx"] = inv_options_map[item["answer"]]

We now train a KNN model using these question embeddings. This acts as a retriever at inference time, as it helps us retrieve the datapoints from the training set that are most similar to a question from the test set.

import numpy as np
from sklearn.neighbors import NearestNeighbors

embeddings = np.array([d["embedding"] for d in filtered_questions_dict])
indices = list(range(len(filtered_questions_dict)))

knn = NearestNeighbors(n_neighbors=5, algorithm='auto', metric='cosine').fit(embeddings)

Implementing the Dynamic Few-Shot and Choice Shuffling Ensemble Logic

We can now run inference. We subsample 500 questions from the MedQA test set for our evaluation. For each question, we retrieve the 5 most similar questions from the train set using the KNN module, along with their respective CoT reasoning steps and predicted answers. We construct a few-shot prompt using these examples.

For each question, we also shuffle the order of the options 5 times to create different variants. We then use the constructed few-shot prompt to get the predicted answer for each of the variants with shuffled options.

def shuffle_option_labels(answer_options):
    """
    Shuffles the options of the question.

    Parameters:
        answer_options (dict): A dictionary with the options.

    Returns:
        dict: A new dictionary with the shuffled options.
    """
    options = list(answer_options.values())
    random.shuffle(options)
    labels = [chr(i) for i in range(ord('A'), ord('A') + len(options))]
    shuffled_options_dict = {label: option for label, option in zip(labels, options)}

    return shuffled_options_dict

test_samples = read_jsonl_file("final_processed_test_set_responses_medprompt.jsonl")

for question in tqdm(test_samples, colour="green"):
    question_variants = []
    prompt_variants = []
    cot_responses = []
    question_embedding = get_embedding(question["question"])
    distances, top_k_indices = knn.kneighbors([question_embedding], n_neighbors=5)
    top_k_dicts = [filtered_questions_dict[i] for i in top_k_indices[0]]
    question["outputs"] = []

    # Create five variants of the question with shuffled answer options
    for idx in range(5):
        question_copy = question.copy()
        shuffled_options = shuffle_option_labels(question["options"])
        inv_map = {v: k for k, v in shuffled_options.items()}

        question_copy["options"] = shuffled_options
        question_copy["answer_idx"] = inv_map[question_copy["answer"]]
        question_variants.append(question_copy)
        prompt = build_few_shot_prompt(system_prompt, question_copy, top_k_dicts)
        prompt_variants.append(prompt)

    # Obtain a CoT response for each shuffled variant
    for prompt in tqdm(prompt_variants):
        response = get_response(prompt, model_name="gpt-4o", max_tokens=500)
        cot_responses.append(response)

    for question_sample, answer in zip(question_variants, cot_responses):
        if validate_response(answer):
            cot, pred_ans = parse_answer(answer)
        else:
            cot = ""
            pred_ans = ""

        # Map the predicted letter back to the option text so that votes are comparable across shuffles
        question["outputs"].append({"question": question_sample["question"], "options": question_sample["options"], "cot": cot, "pred_ans": question_sample["options"].get(pred_ans, "")})

We now evaluate the results of Medprompt over the test set. For each question, we have five predictions generated through the ensemble logic. We take the mode, or most frequently occurring prediction, for each question as the final prediction and evaluate the performance. Two edge cases are possible here:

  1. Two different answer options are predicted two times each, with no clear winner.
  2. There is an error in the generated response, meaning that we don't have a predicted answer option.

For both of these edge cases, we consider the question to be wrongly answered by the LLM.

def find_mode_string_list(string_list):
    """
    Finds the most frequently occurring strings.

    Parameters:
        string_list (list of str): A list of strings.

    Returns:
        list of str or None: A list containing the most frequent string(s) from the input list.
        Returns None if the input list is empty.
    """
    if not string_list:
        return None

    string_counts = Counter(string_list)
    max_freq = max(string_counts.values())
    mode_strings = [string for string, count in string_counts.items() if count == max_freq]
    return mode_strings
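
For example, a clear majority returns a single-element list, while a tie returns multiple candidates, which the evaluation below treats as a wrong answer:

print(find_mode_string_list(["A", "B", "A"]))       # ['A']
print(find_mode_string_list(["A", "B", "A", "B"]))  # ['A', 'B']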

ctr = 0
for item in test_samples:
    pred_ans = [x["pred_ans"] for x in item["outputs"]]
    freq_ans = find_mode_string_list(pred_ans)

    if len(freq_ans) > 1:
        final_prediction = ""
    else:
        final_prediction = freq_ans[0]

    if final_prediction == item["answer"]:
        ctr += 1

print(ctr / len(test_samples))

We evaluate the performance of Medprompt with GPT-4o in terms of accuracy on the MedQA test subset. Moreover, we benchmark the performance of Zero-shot prompting, Random Few-Shot prompting, and Random Few-Shot with CoT prompting.

Results of our evaluation (Image by Author)

We observe that Medprompt and Random Few-Shot CoT prompting outperform the Zero-shot and Few-Shot prompting baselines. However, surprisingly, we notice that Random Few-Shot CoT outperforms our Medprompt performance. This could be attributable to a few reasons:

  1. The original Medprompt paper benchmarked the performance of GPT-4. We observe that GPT-4o significantly outperforms GPT-4T and GPT-4 on various text benchmarks (https://openai.com/index/hello-gpt-4o/), indicating that Medprompt may have a smaller effect on a stronger model like GPT-4o.
  2. We restrict our evaluation to 500 questions subsampled from MedQA. The Medprompt paper evaluates other medical MCQA datasets and the full version of MedQA. Evaluating GPT-4o on the complete versions of these datasets could give a better picture of the overall performance.

Medprompt is an interesting framework for creating sophisticated prompting pipelines, particularly for adapting a generalist LLM to a specific domain without the need for fine-tuning. It also highlights the considerations involved in deciding between prompting and fine-tuning for various use cases. Exploring how far prompting can be pushed to boost LLM performance is important, as it offers a resource- and cost-efficient alternative to fine-tuning.

[1] Nori, H., Lee, Y. T., Zhang, S., Carignan, D., Edgar, R., Fusi, N., … & Horvitz, E. (2023). Can generalist foundation models outcompete special-purpose tuning? Case study in medicine. arXiv preprint arXiv:2311.16452. (https://arxiv.org/abs/2311.16452)

[2] Wei, J., Wang, X., Schuurmans, D., Bosma, M., Xia, F., Chi, E., … & Zhou, D. (2022). Chain-of-thought prompting elicits reasoning in large language models. Advances in Neural Information Processing Systems, 35, 24824–24837. (https://openreview.net/pdf?id=_VjQlMeSB_J)

[3] Gekhman, Z., Yona, G., Aharoni, R., Eyal, M., Feder, A., Reichart, R., & Herzig, J. (2024). Does Fine-Tuning LLMs on New Knowledge Encourage Hallucinations? arXiv preprint arXiv:2405.05904. (https://arxiv.org/abs/2405.05904)

[4] Singhal, K., Azizi, S., Tu, T., Mahdavi, S. S., Wei, J., Chung, H. W., … & Natarajan, V. (2023). Large language models encode clinical knowledge. Nature, 620(7972), 172–180. (https://www.nature.com/articles/s41586-023-06291-2)

[5] Singhal, K., Tu, T., Gottweis, J., Sayres, R., Wulczyn, E., Hou, L., … & Natarajan, V. (2023). Towards expert-level medical question answering with large language models. arXiv preprint arXiv:2305.09617. (https://arxiv.org/abs/2305.09617)

[6] Jin, D., Pan, E., Oufattole, N., Weng, W. H., Fang, H., & Szolovits, P. (2021). What disease does this patient have? A large-scale open domain question answering dataset from medical exams. Applied Sciences, 11(14), 6421. (https://arxiv.org/abs/2009.13081) (Original dataset is released under an MIT License)
