Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)
fix(factory.py): reduce ollama pt LOC < 50
This commit is contained in:
parent f18742edbd
commit d818530265

1 changed file with 85 additions and 150 deletions
@@ -166,173 +166,108 @@ def convert_to_ollama_image(openai_image_url: str):
     )


+def _handle_ollama_system_message(
+    messages: list, prompt: str, msg_i: int
+) -> Tuple[str, int]:
+    system_content_str = ""
+    ## MERGE CONSECUTIVE SYSTEM CONTENT ##
+    while msg_i < len(messages) and messages[msg_i]["role"] == "system":
+        msg_content = convert_content_list_to_str(messages[msg_i])
+        system_content_str += msg_content
+
+        msg_i += 1
+
+    return system_content_str, msg_i
+
+
 def ollama_pt(
     model, messages
 ) -> Union[
     str, OllamaVisionModelObject
 ]:  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
-    if "instruct" in model:
-        prompt = custom_prompt(
-            role_dict={
-                "system": {"pre_message": "### System:\n", "post_message": "\n"},
-                "user": {
-                    "pre_message": "### User:\n",
-                    "post_message": "\n",
-                },
-                "assistant": {
-                    "pre_message": "### Response:\n",
-                    "post_message": "\n",
-                },
-            },
-            final_prompt_value="### Response:",
-            messages=messages,
-        )
-    else:
-        user_message_types = {"user", "tool", "function"}
-        msg_i = 0
-        images = []
-        prompt = ""
-        while msg_i < len(messages):
-            init_msg_i = msg_i
-            user_content_str = ""
-            ## MERGE CONSECUTIVE USER CONTENT ##
-            while (
-                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
-            ):
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "image_url":
-                                if isinstance(m["image_url"], str):
-                                    images.append(m["image_url"])
-                                elif isinstance(m["image_url"], dict):
-                                    images.append(m["image_url"]["url"])
-                            elif m.get("type", "") == "text":
-                                user_content_str += m["text"]
-                    else:
-                        # Tool message content will always be a string
-                        user_content_str += msg_content
-
-                msg_i += 1
-
-            if user_content_str:
-                prompt += f"### User:\n{user_content_str}\n\n"
-
-            system_content_str = ""
-            ## MERGE CONSECUTIVE SYSTEM CONTENT ##
-            while (
-                msg_i < len(messages) and messages[msg_i]["role"] == "system"
-            ):
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "image_url":
-                                if isinstance(m["image_url"], str):
-                                    images.append(m["image_url"])
-                                elif isinstance(m["image_url"], dict):
-                                    images.append(m["image_url"]["url"])
-                            elif m.get("type", "") == "text":
-                                system_content_str += m["text"]
-                    else:
-                        # Tool message content will always be a string
-                        system_content_str += msg_content
-
-                msg_i += 1
-
-            if system_content_str:
-                prompt += f"### System:\n{system_content_str}\n\n"
-
-            assistant_content_str = ""
-            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
-            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "text":
-                                assistant_content_str += m["text"]
-                    elif isinstance(msg_content, str):
-                        # Tool message content will always be a string
-                        assistant_content_str += msg_content
-
-                tool_calls = messages[msg_i].get("tool_calls")
-                ollama_tool_calls = []
-                if tool_calls:
-                    for call in tool_calls:
-                        call_id: str = call["id"]
-                        function_name: str = call["function"]["name"]
-                        arguments = json.loads(call["function"]["arguments"])
-
-                        ollama_tool_calls.append(
-                            {
-                                "id": call_id,
-                                "type": "function",
-                                "function": {
-                                    "name": function_name,
-                                    "arguments": arguments,
-                                },
-                            }
-                        )
-
-                if ollama_tool_calls:
-                    assistant_content_str += (
-                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
-                    )
-
-                msg_i += 1
-
-            if assistant_content_str:
-                prompt += f"### Assistant:\n{assistant_content_str}\n\n"
-
-            if msg_i == init_msg_i:  # prevent infinite loops
-                raise litellm.BadRequestError(
-                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
-                    model=model,
-                    llm_provider="ollama",
-                )
-        # prompt = ""
-        # images = []
-        # for message in messages:
-        #     if isinstance(message["content"], str):
-        #         prompt += message["content"]
-        #     elif isinstance(message["content"], list):
-        #         # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-        #         for element in message["content"]:
-        #             if isinstance(element, dict):
-        #                 if element["type"] == "text":
-        #                     prompt += element["text"]
-        #                 elif element["type"] == "image_url":
-        #                     base64_image = convert_to_ollama_image(
-        #                         element["image_url"]["url"]
-        #                     )
-        #                     images.append(base64_image)
-
-        #     if "tool_calls" in message:
-        #         tool_calls = []
-
-        #         for call in message["tool_calls"]:
-        #             call_id: str = call["id"]
-        #             function_name: str = call["function"]["name"]
-        #             arguments = json.loads(call["function"]["arguments"])
-
-        #             tool_calls.append(
-        #                 {
-        #                     "id": call_id,
-        #                     "type": "function",
-        #                     "function": {"name": function_name, "arguments": arguments},
-        #                 }
-        #             )
-
-        #         prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
-
-        #     elif "tool_call_id" in message:
-        #         prompt += f"### User:\n{message['content']}\n\n"
-
-        return {"prompt": prompt, "images": images}
-
-    return prompt
+    user_message_types = {"user", "tool", "function"}
+    msg_i = 0
+    images = []
+    prompt = ""
+    while msg_i < len(messages):
+        init_msg_i = msg_i
+        user_content_str = ""
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+            msg_content = messages[msg_i].get("content")
+            if msg_content:
+                if isinstance(msg_content, list):
+                    for m in msg_content:
+                        if m.get("type", "") == "image_url":
+                            if isinstance(m["image_url"], str):
+                                images.append(m["image_url"])
+                            elif isinstance(m["image_url"], dict):
+                                images.append(m["image_url"]["url"])
+                        elif m.get("type", "") == "text":
+                            user_content_str += m["text"]
+                else:
+                    # Tool message content will always be a string
+                    user_content_str += msg_content
+
+            msg_i += 1
+
+        if user_content_str:
+            prompt += f"### User:\n{user_content_str}\n\n"
+
+        system_content_str, msg_i = _handle_ollama_system_message(
+            messages, prompt, msg_i
+        )
+        if system_content_str:
+            prompt += f"### System:\n{system_content_str}\n\n"
+
+        assistant_content_str = ""
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            assistant_content_str += convert_content_list_to_str(messages[msg_i])
+            msg_i += 1
+
+            tool_calls = messages[msg_i].get("tool_calls")
+            ollama_tool_calls = []
+            if tool_calls:
+                for call in tool_calls:
+                    call_id: str = call["id"]
+                    function_name: str = call["function"]["name"]
+                    arguments = json.loads(call["function"]["arguments"])
+
+                    ollama_tool_calls.append(
+                        {
+                            "id": call_id,
+                            "type": "function",
+                            "function": {
+                                "name": function_name,
+                                "arguments": arguments,
+                            },
+                        }
+                    )
+
+            if ollama_tool_calls:
+                assistant_content_str += (
+                    f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
+                )
+
+            msg_i += 1
+
+        if assistant_content_str:
+            prompt += f"### Assistant:\n{assistant_content_str}\n\n"
+
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise litellm.BadRequestError(
+                message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                model=model,
+                llm_provider="ollama",
+            )
+
+    response_dict: OllamaVisionModelObject = {
+        "prompt": prompt,
+        "images": images,
+    }
+
+    return response_dict


 def mistral_instruct_pt(messages):
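
For reference, a minimal sketch of what the newly extracted helper does, assuming it is in scope as defined in the diff above and that convert_content_list_to_str simply flattens a message's content to plain text; the message values below are made up for illustration:

    # Consecutive "system" messages starting at msg_i are concatenated;
    # the helper returns the merged text and the advanced index.
    # (The prompt argument is accepted but not used by the helper body.)
    messages = [
        {"role": "system", "content": "Be brief. "},
        {"role": "system", "content": "Answer in English."},
        {"role": "user", "content": "Hi"},
    ]

    system_content_str, msg_i = _handle_ollama_system_message(
        messages, prompt="", msg_i=0
    )
    # system_content_str == "Be brief. Answer in English."
    # msg_i == 2  (now pointing at the first non-system message)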
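And an end-to-end usage sketch of the refactored ollama_pt. The import path and model name are illustrative assumptions (the factory module's location varies across litellm versions); the expected output is read directly off the prompt-building f-strings in the diff:

    # Hypothetical import path -- adjust to your litellm version.
    from litellm.llms.prompt_templates.factory import ollama_pt

    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Name the capital of France."},
    ]

    result = ollama_pt(model="llama2", messages=messages)

    # Expected, per the f-strings above (the system block is emitted first
    # because the user-merge loop finds no user message before it):
    # result["prompt"] ==
    #     "### System:\nYou are a concise assistant.\n\n"
    #     "### User:\nName the capital of France.\n\n"
    # result["images"] == []
    print(result["prompt"])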