From cceb7b59db42d0723dfb929883ed698f09f793d0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 4 Jul 2024 11:13:07 -0700
Subject: [PATCH] fix(cohere.py): fix message parsing to handle tool calling
 correctly

---
 litellm/llms/cohere_chat.py                   |  36 ++-
 litellm/llms/prompt_templates/factory.py      | 277 +++++++++++++++++-
 .../tests/test_amazing_vertex_completion.py   |  11 +-
 litellm/tests/test_completion.py              |  91 ++++++
 litellm/types/llms/cohere.py                  |  46 +++
 5 files changed, 426 insertions(+), 35 deletions(-)
 create mode 100644 litellm/types/llms/cohere.py

diff --git a/litellm/llms/cohere_chat.py b/litellm/llms/cohere_chat.py
index 8ae8392438..1b3aa8405d 100644
--- a/litellm/llms/cohere_chat.py
+++ b/litellm/llms/cohere_chat.py
@@ -1,13 +1,19 @@
-import os, types
 import json
+import os
+import time
+import traceback
+import types
 from enum import Enum
-import requests  # type: ignore
-import time, traceback
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Choices, Message, Usage
-import litellm
+
 import httpx  # type: ignore
-from .prompt_templates.factory import cohere_message_pt
+import requests  # type: ignore
+
+import litellm
+from litellm.types.llms.cohere import ToolResultObject
+from litellm.utils import Choices, Message, ModelResponse, Usage
+
+from .prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
 
 
 class CohereError(Exception):
@@ -112,7 +118,7 @@ class CohereChatConfig:
 
 def validate_environment(api_key):
     headers = {
-        "Request-Source":"unspecified:litellm",
+        "Request-Source": "unspecified:litellm",
         "accept": "application/json",
         "content-type": "application/json",
     }
@@ -196,17 +202,17 @@ def completion(
     api_base: str,
     model_response: ModelResponse,
     print_verbose: Callable,
+    optional_params: dict,
     encoding,
     api_key,
     logging_obj,
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
     completion_url = api_base
     model = model
-    prompt, tool_results = cohere_message_pt(messages=messages)
+    most_recent_message, chat_history = cohere_messages_pt_v2(messages=messages)
 
     ## Load Config
     config = litellm.CohereConfig.get_config()
@@ -221,18 +227,18 @@ def completion(
         _is_function_call = True
         cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
         optional_params["tools"] = cohere_tools
-    if len(tool_results) > 0:
-        optional_params["tool_results"] = tool_results
-
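+    # cohere_messages_pt_v2 returns either a Cohere tool result (dict) when the
+    # most recent message is a tool call result, or the latest message text (str);
+    # route it to the matching Cohere request param accordingly.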
+    if isinstance(most_recent_message, dict):
+        optional_params["tool_results"] = [most_recent_message]
+    elif isinstance(most_recent_message, str):
+        optional_params["message"] = most_recent_message
     data = {
         "model": model,
-        "message": prompt,
         **optional_params,
     }
 
     ## LOGGING
     logging_obj.pre_call(
-        input=prompt,
+        input=most_recent_message,
         api_key=api_key,
         additional_args={
             "complete_input_dict": data,
@@ -256,7 +262,7 @@ def completion(
     else:
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=most_recent_message,
             api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},

diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 1557e715f1..e3f0ff4e82 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -1415,16 +1415,37 @@ def convert_to_documents(
     return documents
 
 
-def convert_openai_message_to_cohere_tool_result(message):
+from litellm.types.llms.cohere import (
+    CallObject,
+    ChatHistory,
+    ChatHistoryChatBot,
+    ChatHistorySystem,
+    ChatHistoryToolResult,
+    ChatHistoryUser,
+    ToolCallObject,
+    ToolResultObject,
+)
+
+
+def convert_openai_message_to_cohere_tool_result(
+    message, tool_calls: List
+) -> ToolResultObject:
     """
     OpenAI message with a tool result looks like:
     {
         "tool_call_id": "tool_1",
         "role": "tool",
-        "name": "get_current_weather",
         "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
     },
     """
+    """
+    OpenAI message with a function call looks like:
+    {
+        "role": "function",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    }
+    """
 
     """
     Cohere tool_results look like:
@@ -1434,7 +1455,6 @@ def convert_openai_message_to_cohere_tool_result(message):
         "parameters": {
             "day": "2023-09-29"
         },
-        "generation_id": "4807c924-9003-4d6b-8069-eda03962c465"
     },
     "outputs": [
         {
@@ -1444,30 +1464,255 @@ def convert_openai_message_to_cohere_tool_result(message):
         ]
     },
     """
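+    # Worked example (hypothetical values): an OpenAI tool-result message
+    #   {"tool_call_id": "call_1", "role": "tool", "content": '{"temperature": "72"}'}
+    # paired with an earlier tool call whose id is "call_1" and whose function is
+    #   {"name": "get_current_weather", "arguments": '{"location": "Boston, MA"}'}
+    # converts to:
+    #   {"call": {"name": "get_current_weather", "parameters": {"location": "Boston, MA"}},
+    #    "outputs": [{"temperature": "72"}]}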
+    content_str: str = message.get("content", "")
+    if len(content_str) > 0:
+        try:
+            content = json.loads(content_str)
+        except json.JSONDecodeError:
+            content = {"result": content_str}
+    else:
+        content = {}
+    name = ""
+    arguments = {}
+    # Recover the tool name and arguments from the previously extracted tool calls
+    if len(tool_calls) > 0:
+        tools = tool_calls
+        msg_tool_call_id = message.get("tool_call_id", None)
+        for tool in tools:
+            prev_tool_call_id = tool.get("id", None)
+            if (
+                msg_tool_call_id
+                and prev_tool_call_id
+                and msg_tool_call_id == prev_tool_call_id
+            ):
+                name = tool.get("function", {}).get("name", "")
+                arguments_str = tool.get("function", {}).get("arguments", "")
+                if arguments_str is not None and len(arguments_str) > 0:
+                    arguments = json.loads(arguments_str)
-    tool_call_id = message.get("tool_call_id")
-    name = message.get("name")
-    content = message.get("content")
+    if message["role"] == "function":
+        name = message.get("name")
+        cohere_tool_result: ToolResultObject = {
+            "call": CallObject(name=name, parameters=arguments),
+            "outputs": [content],
+        }
+        return cohere_tool_result
+    else:
+        # We can't tell from the OpenAI message format whether the call succeeded
+        # or errored, so default to the successful result template
-    # Create the Cohere tool_result dictionary
-    cohere_tool_result = {
-        "call": {
-            "name": name,
-            "parameters": {"location": "San Francisco, CA"},
-            "generation_id": tool_call_id,
-        },
-        "outputs": convert_to_documents(content),
+        cohere_tool_result = {
+            "call": CallObject(name=name, parameters=arguments),
+            "outputs": [content],
+        }
+        return cohere_tool_result
+
+
+def get_all_tool_calls(messages: List) -> List:
+    """
+    Returns the extracted list of `tool_calls`.
+
+    Done to handle OpenAI no longer returning the tool call 'name' in tool results.
+    """
+    tool_calls: List = []
+    for m in messages:
+        if m.get("tool_calls", None) is not None:
+            if isinstance(m["tool_calls"], list):
+                tool_calls.extend(m["tool_calls"])
+
+    return tool_calls
+
+
+def convert_to_cohere_tool_invoke(tool_calls: list) -> List[ToolCallObject]:
+    """
+    OpenAI tool invokes:
+    {
+        "role": "assistant",
+        "content": null,
+        "tool_calls": [
+            {
+                "id": "call_abc123",
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+                }
+            }
+        ]
+    },
+    """
+
+    """
+    Cohere tool invokes:
+    {
+        "role": "CHATBOT",
+        "tool_calls": [{"name": "get_weather", "parameters": {"location": "San Francisco, CA"}}]
     }
-    return cohere_tool_result
+    """
+
+    cohere_tool_invoke: List[ToolCallObject] = [
+        {
+            "name": get_attribute_or_key(
+                get_attribute_or_key(tool, "function"), "name"
+            ),
+            "parameters": json.loads(
+                get_attribute_or_key(
+                    get_attribute_or_key(tool, "function"), "arguments"
+                )
+            ),
+        }
+        for tool in tool_calls
+        if get_attribute_or_key(tool, "type") == "function"
+    ]
+
+    return cohere_tool_invoke
+
+
+def cohere_messages_pt_v2(
+    messages: List,
+) -> Tuple[Union[str, ToolResultObject], ChatHistory]:
+    """
+    Returns a tuple of (Union[tool_result, message], chat_history)
+
+    - if the last message is a tool result -> return the 'tool_result'
+    - if the last message is text -> return the message (str)
+
+    - return the preceding messages as 'chat_history'
+
+    Note:
+    - cannot specify a message if the last entry in chat history contains tool results
+    - the message must be at least 1 token long, or tool results must be specified.
+    """
+    tool_calls: List = get_all_tool_calls(messages=messages)
+
+    ## GET MOST RECENT MESSAGE
+    most_recent_message = messages.pop(-1)
+    returned_message: Union[ToolResultObject, str] = ""
+    if (
+        most_recent_message.get("role", "") is not None
+        and most_recent_message["role"] == "tool"
+    ):
+        # tool result
+        returned_message = convert_openai_message_to_cohere_tool_result(
+            most_recent_message, tool_calls
+        )
+    else:
+        content: Union[str, List] = most_recent_message.get("content")
+        if isinstance(content, str):
+            returned_message = content
+        else:
+            for chunk in content:
+                if chunk.get("type") == "text":
+                    returned_message += chunk.get("text")
+
+    ## CREATE CHAT HISTORY
+    user_message_types = {"user"}
+    tool_message_types = {"tool", "function"}
+    # Reformat messages to ensure user/assistant roles alternate: if there are two
+    # consecutive 'user' messages or two consecutive 'assistant' messages, merge them.
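+    # e.g. (hypothetical turns) two consecutive user messages
+    #   [{"role": "user", "content": "hi"}, {"role": "user", "content": " there"}]
+    # collapse into a single ChatHistoryUser(role="USER", message="hi there").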
+    new_messages: ChatHistory = []
+    msg_i = 0
+
+    while msg_i < len(messages):
+        user_content: str = ""
+        init_msg_i = msg_i
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+            if isinstance(messages[msg_i]["content"], list):
+                for m in messages[msg_i]["content"]:
+                    if m.get("type", "") == "text":
+                        user_content += m["text"]
+            else:
+                user_content += messages[msg_i]["content"]
+            msg_i += 1
+
+        if len(user_content) > 0:
+            new_messages.append(ChatHistoryUser(role="USER", message=user_content))
+
+        system_content: str = ""
+        ## MERGE CONSECUTIVE SYSTEM CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "system":
+            if isinstance(messages[msg_i]["content"], list):
+                for m in messages[msg_i]["content"]:
+                    if m.get("type", "") == "text":
+                        system_content += m["text"]
+            else:
+                system_content += messages[msg_i]["content"]
+            msg_i += 1
+
+        if len(system_content) > 0:
+            new_messages.append(
+                ChatHistorySystem(role="SYSTEM", message=system_content)
+            )
+
+        assistant_content: str = ""
+        assistant_tool_calls: List[ToolCallObject] = []
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            assistant_text = (
+                messages[msg_i].get("content") or ""
+            )  # either string or none
+            if assistant_text:
+                assistant_content += assistant_text
+
+            if messages[msg_i].get(
+                "tool_calls", []
+            ):  # support assistant tool invoke conversion
+                assistant_tool_calls.extend(
+                    convert_to_cohere_tool_invoke(messages[msg_i]["tool_calls"])
+                )
+
+            if messages[msg_i].get("function_call"):
+                assistant_tool_calls.extend(
+                    convert_to_cohere_tool_invoke(messages[msg_i]["function_call"])
+                )
+
+            msg_i += 1
+
+        if len(assistant_content) > 0:
+            new_messages.append(
+                ChatHistoryChatBot(
+                    role="CHATBOT",
+                    message=assistant_content,
+                    tool_calls=assistant_tool_calls,
+                )
+            )
+
+        ## MERGE CONSECUTIVE TOOL RESULTS
+        tool_results: List[ToolResultObject] = []
+        while msg_i < len(messages) and messages[msg_i]["role"] in tool_message_types:
+            tool_results.append(
+                convert_openai_message_to_cohere_tool_result(
+                    messages[msg_i], tool_calls
+                )
+            )
+
+            msg_i += 1
+
+        if len(tool_results) > 0:
+            new_messages.append(
+                ChatHistoryToolResult(role="TOOL", tool_results=tool_results)
+            )
+
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise Exception(
+                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                    messages[msg_i]
+                )
+            )
+
+    return returned_message, new_messages
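+    # Usage sketch (hypothetical transcript):
+    #   msg, history = cohere_messages_pt_v2(
+    #       messages=[
+    #           {"role": "user", "content": "hi"},
+    #           {"role": "assistant", "content": "hello"},
+    #           {"role": "user", "content": "how are you?"},
+    #       ]
+    #   )
+    #   # msg == "how are you?"
+    #   # history == [{"role": "USER", "message": "hi"},
+    #   #             {"role": "CHATBOT", "message": "hello", "tool_calls": []}]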
 
 
 def cohere_message_pt(messages: list):
+    tool_calls: List = get_all_tool_calls(messages=messages)
     prompt = ""
     tool_results = []
     for message in messages:
         # check if this is a tool_call result
         if message["role"] == "tool":
-            tool_result = convert_openai_message_to_cohere_tool_result(message)
+            tool_result = convert_openai_message_to_cohere_tool_result(
+                message, tool_calls=tool_calls
+            )
             tool_results.append(tool_result)
         elif message.get("content"):
             prompt += message["content"] + "\n\n"

diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 960aa6b961..5adbc0d7b5 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -1121,7 +1121,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
     assert "hello" in mock_call.call_args.kwargs["headers"]
 
 
-@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
+# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
 @pytest.mark.parametrize("sync_mode", [True])
 @pytest.mark.parametrize("provider", ["vertex_ai"])
 @pytest.mark.asyncio
@@ -1159,7 +1159,7 @@ async def test_gemini_pro_function_calling(provider, sync_mode):
         # The result of the tool call is added to the history
         {
             "role": "tool",
-            "tool_call_id": "call_123",
+            "tool_call_id": "call_123",
             "content": "27 degrees celsius and clear in San Francisco, CA",
         },
         # Now the assistant can reply with the result of the tool call.
@@ -1381,6 +1381,7 @@ async def test_vertexai_aembedding():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 @pytest.mark.asyncio
 def test_tool_name_conversion():
     messages = [
@@ -1424,7 +1425,8 @@ def test_tool_name_conversion():
 
     # assert that the last tool response has the corresponding tool name
     assert (
-        translated_messages[-1]["parts"][0]["function_response"]["name"] == "get_weather"
+        translated_messages[-1]["parts"][0]["function_response"]["name"]
+        == "get_weather"
     )
@@ -1585,6 +1587,7 @@ def test_prompt_factory():
 
     print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
 
+
 def test_prompt_factory_nested():
     messages = [
         {"role": "user", "content": [{"type": "text", "text": "hi"}]},
@@ -1606,4 +1609,4 @@ def test_prompt_factory_nested():
         assert "text" in message["parts"][0], "Missing 'text' from 'parts'"
         assert isinstance(
             message["parts"][0]["text"], str
-        ), "'text' value not a string."
\ No newline at end of file
+        ), "'text' value not a string."

diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 5138e9b61b..72567e05df 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -408,6 +408,97 @@ def test_completion_claude_3_function_call(model):
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "gpt-3.5-turbo",
+        "claude-3-opus-20240229",
+        "command-r",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        # "azure_ai/command-r-plus"
+    ],
+)
+@pytest.mark.asyncio
+async def test_model_function_invoke(model, sync_mode):
+    try:
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+            # Assistant replies with a tool call
+            {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "id": "call_123",
+                        "type": "function",
+                        "index": 0,
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": '{"location":"San Francisco, CA"}',
+                        },
+                    }
+                ],
+            },
+            # The result of the tool call is added to the history
+            {
+                "role": "tool",
+                "tool_call_id": "call_123",
+                "content": "27 degrees celsius and clear in San Francisco, CA",
+            },
+            # Now the assistant can reply with the result of the tool call.
+        ]
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
San Francisco, CA", + } + }, + "required": ["location"], + }, + }, + } + ] + + data = { + "model": model, + "messages": messages, + "tools": tools, + } + if sync_mode: + response = litellm.completion(**data) + else: + response = await litellm.acompletion(**data) + + print(f"response: {response}") + except litellm.RateLimitError as e: + pass + except Exception as e: + if "429 Quota exceeded" in str(e): + pass + else: + pytest.fail("An unexpected exception occurred - {}".format(str(e))) + + @pytest.mark.asyncio async def test_anthropic_no_content_error(): """ diff --git a/litellm/types/llms/cohere.py b/litellm/types/llms/cohere.py new file mode 100644 index 0000000000..7112a242f9 --- /dev/null +++ b/litellm/types/llms/cohere.py @@ -0,0 +1,46 @@ +from typing import Iterable, List, Optional, Union + +from typing_extensions import Literal, Required, TypedDict + + +class CallObject(TypedDict): + name: str + parameters: dict + + +class ToolResultObject(TypedDict): + call: CallObject + outputs: List[dict] + + +class ChatHistoryToolResult(TypedDict, total=False): + role: Required[Literal["TOOL"]] + tool_results: List[ToolResultObject] + + +class ToolCallObject(TypedDict): + name: str + parameters: dict + + +class ChatHistoryUser(TypedDict, total=False): + role: Required[Literal["USER"]] + message: str + tool_calls: List[ToolCallObject] + + +class ChatHistorySystem(TypedDict, total=False): + role: Required[Literal["SYSTEM"]] + message: str + tool_calls: List[ToolCallObject] + + +class ChatHistoryChatBot(TypedDict, total=False): + role: Required[Literal["CHATBOT"]] + message: str + tool_calls: List[ToolCallObject] + + +ChatHistory = List[ + Union[ChatHistorySystem, ChatHistoryChatBot, ChatHistoryUser, ChatHistoryToolResult] +]