diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py
index 182f3618b8..ff366b2396 100644
--- a/litellm/litellm_core_utils/prompt_templates/factory.py
+++ b/litellm/litellm_core_utils/prompt_templates/factory.py
@@ -166,173 +166,108 @@ def convert_to_ollama_image(openai_image_url: str):
     )
 
 
+def _handle_ollama_system_message(
+    messages: list, prompt: str, msg_i: int
+) -> Tuple[str, int]:
+    system_content_str = ""
+    ## MERGE CONSECUTIVE SYSTEM CONTENT ##
+    while msg_i < len(messages) and messages[msg_i]["role"] == "system":
+        msg_content = convert_content_list_to_str(messages[msg_i])
+        system_content_str += msg_content
+
+        msg_i += 1
+
+    return system_content_str, msg_i
+
+
 def ollama_pt(
     model, messages
 ) -> Union[
     str, OllamaVisionModelObject
 ]:  # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
-    if "instruct" in model:
-        prompt = custom_prompt(
-            role_dict={
-                "system": {"pre_message": "### System:\n", "post_message": "\n"},
-                "user": {
-                    "pre_message": "### User:\n",
-                    "post_message": "\n",
-                },
-                "assistant": {
-                    "pre_message": "### Response:\n",
-                    "post_message": "\n",
-                },
-            },
-            final_prompt_value="### Response:",
-            messages=messages,
-        )
-    else:
-        user_message_types = {"user", "tool", "function"}
-        msg_i = 0
-        images = []
-        prompt = ""
-        while msg_i < len(messages):
-            init_msg_i = msg_i
-            user_content_str = ""
-            ## MERGE CONSECUTIVE USER CONTENT ##
-            while (
-                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
-            ):
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "image_url":
-                                if isinstance(m["image_url"], str):
-                                    images.append(m["image_url"])
-                                elif isinstance(m["image_url"], dict):
-                                    images.append(m["image_url"]["url"])
-                            elif m.get("type", "") == "text":
-                                user_content_str += m["text"]
-                    else:
-                        # Tool message content will always be a string
-                        user_content_str += msg_content
-
-                msg_i += 1
-
-            if user_content_str:
-                prompt += f"### User:\n{user_content_str}\n\n"
-
-            system_content_str = ""
-            ## MERGE CONSECUTIVE SYSTEM CONTENT ##
-            while (
-                msg_i < len(messages) and messages[msg_i]["role"] == "system"
-            ):
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "image_url":
-                                if isinstance(m["image_url"], str):
-                                    images.append(m["image_url"])
-                                elif isinstance(m["image_url"], dict):
-                                    images.append(m["image_url"]["url"])
-                            elif m.get("type", "") == "text":
-                                system_content_str += m["text"]
-                    else:
-                        # Tool message content will always be a string
-                        system_content_str += msg_content
-
-                msg_i += 1
-
-            if system_content_str:
-                prompt += f"### System:\n{system_content_str}\n\n"
-
-            assistant_content_str = ""
-            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
-            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-                msg_content = messages[msg_i].get("content")
-                if msg_content:
-                    if isinstance(msg_content, list):
-                        for m in msg_content:
-                            if m.get("type", "") == "text":
-                                assistant_content_str += m["text"]
-                    elif isinstance(msg_content, str):
-                        # Tool message content will always be a string
-                        assistant_content_str += msg_content
-
-                tool_calls = messages[msg_i].get("tool_calls")
-                ollama_tool_calls = []
-                if tool_calls:
-                    for call in tool_calls:
-                        call_id: str = call["id"]
-                        function_name: str = call["function"]["name"]
-                        arguments = json.loads(call["function"]["arguments"])
-
-                        ollama_tool_calls.append(
-                            {
-                                "id": call_id,
-                                "type": "function",
-                                "function": {
-                                    "name": function_name,
-                                    "arguments": arguments,
-                                },
-                            }
-                        )
-
-                if ollama_tool_calls:
-                    assistant_content_str += (
-                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
-                    )
-
-                msg_i += 1
-
-            if assistant_content_str:
-                prompt += f"### Assistant:\n{assistant_content_str}\n\n"
-
-            if msg_i == init_msg_i:  # prevent infinite loops
-                raise litellm.BadRequestError(
-                    message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
-                    model=model,
-                    llm_provider="ollama",
-                )
-
-        # prompt = ""
-        # images = []
-        # for message in messages:
-        #     if isinstance(message["content"], str):
-        #         prompt += message["content"]
-        #     elif isinstance(message["content"], list):
-        #         # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-        #         for element in message["content"]:
-        #             if isinstance(element, dict):
-        #                 if element["type"] == "text":
-        #                     prompt += element["text"]
-        #                 elif element["type"] == "image_url":
-        #                     base64_image = convert_to_ollama_image(
-        #                         element["image_url"]["url"]
-        #                     )
-        #                     images.append(base64_image)
-        #         if "tool_calls" in message:
-        #             tool_calls = []
-
-        #             for call in message["tool_calls"]:
-        #                 call_id: str = call["id"]
-        #                 function_name: str = call["function"]["name"]
-        #                 arguments = json.loads(call["function"]["arguments"])
-
-        #             tool_calls.append(
-        #                 {
-        #                     "id": call_id,
-        #                     "type": "function",
-        #                     "function": {"name": function_name, "arguments": arguments},
-        #                 }
-        #             )
-
-        #             prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
-
-        #         elif "tool_call_id" in message:
-        #             prompt += f"### User:\n{message['content']}\n\n"
-
-        return {"prompt": prompt, "images": images}
-
-    return prompt
+    user_message_types = {"user", "tool", "function"}
+    msg_i = 0
+    images = []
+    prompt = ""
+    while msg_i < len(messages):
+        init_msg_i = msg_i
+        user_content_str = ""
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+            msg_content = messages[msg_i].get("content")
+            if msg_content:
+                if isinstance(msg_content, list):
+                    for m in msg_content:
+                        if m.get("type", "") == "image_url":
+                            if isinstance(m["image_url"], str):
+                                images.append(m["image_url"])
+                            elif isinstance(m["image_url"], dict):
+                                images.append(m["image_url"]["url"])
+                        elif m.get("type", "") == "text":
+                            user_content_str += m["text"]
+                else:
+                    # Tool message content will always be a string
+                    user_content_str += msg_content
+
+            msg_i += 1
+
+        if user_content_str:
+            prompt += f"### User:\n{user_content_str}\n\n"
+
+        system_content_str, msg_i = _handle_ollama_system_message(
+            messages, prompt, msg_i
+        )
+        if system_content_str:
+            prompt += f"### System:\n{system_content_str}\n\n"
+
+        assistant_content_str = ""
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            assistant_content_str += convert_content_list_to_str(messages[msg_i])
+
+            # Read tool calls from the current assistant message *before*
+            # advancing msg_i, so a trailing assistant message cannot index
+            # past the end of the list.
+            tool_calls = messages[msg_i].get("tool_calls")
+            ollama_tool_calls = []
+            if tool_calls:
+                for call in tool_calls:
+                    call_id: str = call["id"]
+                    function_name: str = call["function"]["name"]
+                    arguments = json.loads(call["function"]["arguments"])
+
+                    ollama_tool_calls.append(
+                        {
+                            "id": call_id,
+                            "type": "function",
+                            "function": {
+                                "name": function_name,
+                                "arguments": arguments,
+                            },
+                        }
+                    )
+
+            if ollama_tool_calls:
+                assistant_content_str += (
+                    f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
+                )
+
+            msg_i += 1
+
+        if assistant_content_str:
+            prompt += f"### Assistant:\n{assistant_content_str}\n\n"
+
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise litellm.BadRequestError(
+                message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                model=model,
+                llm_provider="ollama",
+            )
+
+    response_dict: OllamaVisionModelObject = {
+        "prompt": prompt,
+        "images": images,
+    }
+
+    return response_dict
 
 
 def mistral_instruct_pt(messages):
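
A quick way to sanity-check the rewrite is to run the refactored ollama_pt
against a small conversation. This is a minimal sketch, assuming the patched
module is importable at the path shown in the diff; the model name, the sample
messages, and the rendering shown in the comments are illustrative, not
captured test output.

    from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
        {"role": "user", "content": "And of Italy?"},
    ]

    # "llama3" is an arbitrary model name here; the rewrite drops the old
    # special case for "instruct" models, so every model takes this path.
    result = ollama_pt(model="llama3", messages=messages)

    # result is now always the {"prompt", "images"} dict (the old code
    # returned a bare string for "instruct" models). With no image parts,
    # result["images"] stays [] and result["prompt"] renders roughly as:
    #
    #   ### System:
    #   You are a helpful assistant.
    #
    #   ### User:
    #   What is the capital of France?
    #
    #   ### Assistant:
    #   The capital of France is Paris.
    #
    #   ### User:
    #   And of Italy?
    print(result["prompt"])

Assistant messages that carry tool_calls are folded into the prompt as a
"Tool Calls: [...]" JSON block, and image_url parts in user content are
collected into result["images"] rather than the prompt text.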