fix(factory.py): reduce ollama pt LOC < 50

Krrish Dholakia 2025-03-14 21:10:05 -07:00
parent f18742edbd
commit d818530265


@@ -166,173 +166,108 @@ def convert_to_ollama_image(openai_image_url: str):
         )
 
+def _handle_ollama_system_message(
+    messages: list, prompt: str, msg_i: int
+) -> Tuple[str, int]:
+    system_content_str = ""
+    ## MERGE CONSECUTIVE SYSTEM CONTENT ##
+    while msg_i < len(messages) and messages[msg_i]["role"] == "system":
+        msg_content = convert_content_list_to_str(messages[msg_i])
+        system_content_str += msg_content
+        msg_i += 1
+
+    return system_content_str, msg_i
+
+
 def ollama_pt(
     model, messages
 ) -> Union[
     str, OllamaVisionModelObject
 ]:  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
-    if "instruct" in model:
-        prompt = custom_prompt(
-            role_dict={
-                "system": {"pre_message": "### System:\n", "post_message": "\n"},
-                "user": {
-                    "pre_message": "### User:\n",
-                    "post_message": "\n",
-                },
-                "assistant": {
-                    "pre_message": "### Response:\n",
-                    "post_message": "\n",
-                },
-            },
-            final_prompt_value="### Response:",
-            messages=messages,
-        )
-    else:
-        user_message_types = {"user", "tool", "function"}
-        msg_i = 0
-        images = []
-        prompt = ""
-        while msg_i < len(messages):
-            init_msg_i = msg_i
-            user_content_str = ""
-            ## MERGE CONSECUTIVE USER CONTENT ##
-            while (
-                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
-            ):
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "image_url":
-                                if isinstance(m["image_url"], str):
-                                    images.append(m["image_url"])
-                                elif isinstance(m["image_url"], dict):
-                                    images.append(m["image_url"]["url"])
-                            elif m.get("type", "") == "text":
-                                user_content_str += m["text"]
-                    else:
-                        # Tool message content will always be a string
-                        user_content_str += msg_content
-
-                msg_i += 1
-
-            if user_content_str:
-                prompt += f"### User:\n{user_content_str}\n\n"
-
-            system_content_str = ""
-            ## MERGE CONSECUTIVE SYSTEM CONTENT ##
-            while (
-                msg_i < len(messages) and messages[msg_i]["role"] == "system"
-            ):
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "image_url":
-                                if isinstance(m["image_url"], str):
-                                    images.append(m["image_url"])
-                                elif isinstance(m["image_url"], dict):
-                                    images.append(m["image_url"]["url"])
-                            elif m.get("type", "") == "text":
-                                system_content_str += m["text"]
-                    else:
-                        # Tool message content will always be a string
-                        system_content_str += msg_content
-
-                msg_i += 1
-
-            if system_content_str:
-                prompt += f"### System:\n{system_content_str}\n\n"
-
-            assistant_content_str = ""
-            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
-            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "text":
-                                assistant_content_str += m["text"]
-                    elif isinstance(msg_content, str):
-                        # Tool message content will always be a string
-                        assistant_content_str += msg_content
-
-                tool_calls = messages[msg_i].get("tool_calls")
-                ollama_tool_calls = []
-                if tool_calls:
-                    for call in tool_calls:
-                        call_id: str = call["id"]
-                        function_name: str = call["function"]["name"]
-                        arguments = json.loads(call["function"]["arguments"])
-
-                        ollama_tool_calls.append(
-                            {
-                                "id": call_id,
-                                "type": "function",
-                                "function": {
-                                    "name": function_name,
-                                    "arguments": arguments,
-                                },
-                            }
-                        )
-
-                if ollama_tool_calls:
-                    assistant_content_str += (
-                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
-                    )
-
-                msg_i += 1
-
-            if assistant_content_str:
-                prompt += f"### Assistant:\n{assistant_content_str}\n\n"
-
-            if msg_i == init_msg_i:  # prevent infinite loops
-                raise litellm.BadRequestError(
-                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
-                    model=model,
-                    llm_provider="ollama",
-                )
-        # prompt = ""
-        # images = []
-        # for message in messages:
-        #     if isinstance(message["content"], str):
-        #         prompt += message["content"]
-        #     elif isinstance(message["content"], list):
-        #         # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-        #         for element in message["content"]:
-        #             if isinstance(element, dict):
-        #                 if element["type"] == "text":
-        #                     prompt += element["text"]
-        #                 elif element["type"] == "image_url":
-        #                     base64_image = convert_to_ollama_image(
-        #                         element["image_url"]["url"]
-        #                     )
-        #                     images.append(base64_image)
-        #     if "tool_calls" in message:
-        #         tool_calls = []
-        #         for call in message["tool_calls"]:
-        #             call_id: str = call["id"]
-        #             function_name: str = call["function"]["name"]
-        #             arguments = json.loads(call["function"]["arguments"])
-        #             tool_calls.append(
-        #                 {
-        #                     "id": call_id,
-        #                     "type": "function",
-        #                     "function": {"name": function_name, "arguments": arguments},
-        #                 }
-        #             )
-        #         prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
-        #     elif "tool_call_id" in message:
-        #         prompt += f"### User:\n{message['content']}\n\n"
-        return {"prompt": prompt, "images": images}
-    return prompt
+    user_message_types = {"user", "tool", "function"}
+    msg_i = 0
+    images = []
+    prompt = ""
+    while msg_i < len(messages):
+        init_msg_i = msg_i
+        user_content_str = ""
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+            msg_content = messages[msg_i].get("content")
+            if msg_content:
+                if isinstance(msg_content, list):
+                    for m in msg_content:
+                        if m.get("type", "") == "image_url":
+                            if isinstance(m["image_url"], str):
+                                images.append(m["image_url"])
+                            elif isinstance(m["image_url"], dict):
+                                images.append(m["image_url"]["url"])
+                        elif m.get("type", "") == "text":
+                            user_content_str += m["text"]
+                else:
+                    # Tool message content will always be a string
+                    user_content_str += msg_content
+
+            msg_i += 1
+
+        if user_content_str:
+            prompt += f"### User:\n{user_content_str}\n\n"
+
+        system_content_str, msg_i = _handle_ollama_system_message(
+            messages, prompt, msg_i
+        )
+        if system_content_str:
+            prompt += f"### System:\n{system_content_str}\n\n"
+
+        assistant_content_str = ""
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            assistant_content_str += convert_content_list_to_str(messages[msg_i])
+            msg_i += 1
+
+            tool_calls = messages[msg_i].get("tool_calls")
+            ollama_tool_calls = []
+            if tool_calls:
+                for call in tool_calls:
+                    call_id: str = call["id"]
+                    function_name: str = call["function"]["name"]
+                    arguments = json.loads(call["function"]["arguments"])
+
+                    ollama_tool_calls.append(
+                        {
+                            "id": call_id,
+                            "type": "function",
+                            "function": {
+                                "name": function_name,
+                                "arguments": arguments,
+                            },
+                        }
+                    )
+
+            if ollama_tool_calls:
+                assistant_content_str += (
+                    f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
+                )
+
+            msg_i += 1
+
+        if assistant_content_str:
+            prompt += f"### Assistant:\n{assistant_content_str}\n\n"
+
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise litellm.BadRequestError(
+                message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                model=model,
+                llm_provider="ollama",
+            )
+
+    response_dict: OllamaVisionModelObject = {
+        "prompt": prompt,
+        "images": images,
+    }
+
+    return response_dict
 
 
 def mistral_instruct_pt(messages):
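Usage sketch: the refactored ollama_pt merges consecutive user/tool, system, and assistant messages into "### User:", "### System:", and "### Assistant:" prompt sections, collecting image URLs alongside. A minimal, hypothetical example follows; the import path is an assumption, since factory.py has lived under different packages across litellm versions.

    from litellm.litellm_core_utils.prompt_templates.factory import (
        _handle_ollama_system_message,
        ollama_pt,
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        },
    ]

    # The extracted helper merges consecutive system messages and returns the
    # merged string plus the index of the first non-system message.
    system_str, next_i = _handle_ollama_system_message(messages, prompt="", msg_i=0)
    # system_str == "You are a helpful assistant.", next_i == 1

    # ollama_pt returns the OllamaVisionModelObject dict built at the end of the diff.
    result = ollama_pt(model="llama3", messages=messages)
    # result["prompt"] == "### System:\nYou are a helpful assistant.\n\n"
    #                     "### User:\nWhat is in this image?\n\n"
    # result["images"] == ["https://example.com/cat.png"]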