fix(cohere.py): fix message parsing to handle tool calling correctly

Krrish Dholakia 2024-07-04 11:13:07 -07:00
parent 4606b020b5
commit cceb7b59db
5 changed files with 426 additions and 35 deletions


@@ -1,13 +1,19 @@
-import os, types
 import json
+import os
+import time
+import traceback
+import types
 from enum import Enum
-import requests  # type: ignore
-import time, traceback
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Choices, Message, Usage
-import litellm
+
 import httpx  # type: ignore
-from .prompt_templates.factory import cohere_message_pt
+import requests  # type: ignore
+
+import litellm
+from litellm.types.llms.cohere import ToolResultObject
+from litellm.utils import Choices, Message, ModelResponse, Usage
+
+from .prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
 
 
 class CohereError(Exception):
@@ -112,7 +118,7 @@ class CohereChatConfig:
 def validate_environment(api_key):
     headers = {
-        "Request-Source":"unspecified:litellm",
+        "Request-Source": "unspecified:litellm",
         "accept": "application/json",
         "content-type": "application/json",
     }
@@ -196,17 +202,17 @@ def completion(
     api_base: str,
     model_response: ModelResponse,
     print_verbose: Callable,
+    optional_params: dict,
     encoding,
     api_key,
     logging_obj,
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
     completion_url = api_base
     model = model
-    prompt, tool_results = cohere_message_pt(messages=messages)
+    most_recent_message, chat_history = cohere_messages_pt_v2(messages=messages)
 
     ## Load Config
     config = litellm.CohereConfig.get_config()
@@ -221,18 +227,18 @@ def completion(
         _is_function_call = True
         cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
         optional_params["tools"] = cohere_tools
-    if len(tool_results) > 0:
-        optional_params["tool_results"] = tool_results
+    if isinstance(most_recent_message, dict):
+        optional_params["tool_results"] = [most_recent_message]
+    elif isinstance(most_recent_message, str):
+        optional_params["message"] = most_recent_message
 
     data = {
         "model": model,
-        "message": prompt,
         **optional_params,
     }
 
     ## LOGGING
     logging_obj.pre_call(
-        input=prompt,
+        input=most_recent_message,
         api_key=api_key,
         additional_args={
             "complete_input_dict": data,
@@ -256,7 +262,7 @@ def completion(
     else:
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=most_recent_message,
             api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
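
For context, a minimal sketch (not part of the diff) of how the reworked completion() routes the value returned by cohere_messages_pt_v2: a dict (a Cohere tool_result) is sent as tool_results, while a plain string becomes the chat message. The helper name route_most_recent_message is hypothetical, used only for illustration.

from typing import Union

def route_most_recent_message(
    most_recent_message: Union[str, dict], optional_params: dict
) -> dict:
    # Mirrors the branch added above: a dict means the last OpenAI message
    # was a tool result; a string means it was ordinary text.
    if isinstance(most_recent_message, dict):
        optional_params["tool_results"] = [most_recent_message]
    elif isinstance(most_recent_message, str):
        optional_params["message"] = most_recent_message
    return optional_params

print(route_most_recent_message("What's the weather in SF?", {}))
# -> {'message': "What's the weather in SF?"}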


@@ -1415,16 +1415,37 @@ def convert_to_documents(
     return documents
 
-def convert_openai_message_to_cohere_tool_result(message):
+from litellm.types.llms.cohere import (
+    CallObject,
+    ChatHistory,
+    ChatHistoryChatBot,
+    ChatHistorySystem,
+    ChatHistoryToolResult,
+    ChatHistoryUser,
+    ToolCallObject,
+    ToolResultObject,
+)
+
+
+def convert_openai_message_to_cohere_tool_result(
+    message, tool_calls: List
+) -> ToolResultObject:
     """
     OpenAI message with a tool result looks like:
     {
         "tool_call_id": "tool_1",
         "role": "tool",
-        "name": "get_current_weather",
         "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
     },
     """
+
+    """
+    OpenAI message with a function call looks like:
+    {
+        "role": "function",
+        "name": "get_current_weather",
+        "content": "function result goes here",
+    }
+    """
 
     """
     Cohere tool_results look like:
@@ -1434,7 +1455,6 @@ def convert_openai_message_to_cohere_tool_result(message):
             "parameters": {
                 "day": "2023-09-29"
             },
-            "generation_id": "4807c924-9003-4d6b-8069-eda03962c465"
         },
         "outputs": [
             {
@@ -1444,30 +1464,255 @@ def convert_openai_message_to_cohere_tool_result(message):
         ]
     },
     """
-    tool_call_id = message.get("tool_call_id")
-    name = message.get("name")
-    content = message.get("content")
-
-    # Create the Cohere tool_result dictionary
-    cohere_tool_result = {
-        "call": {
-            "name": name,
-            "parameters": {"location": "San Francisco, CA"},
-            "generation_id": tool_call_id,
-        },
-        "outputs": convert_to_documents(content),
-    }
-    return cohere_tool_result
+    content_str: str = message.get("content", "")
+    if len(content_str) > 0:
+        try:
+            content = json.loads(content_str)
+        except json.JSONDecodeError:
+            content = {"result": content_str}
+    else:
+        content = {}
+    name = ""
+    arguments = {}
+    # Recover name from last message with tool calls
+    if len(tool_calls) > 0:
+        tools = tool_calls
+        msg_tool_call_id = message.get("tool_call_id", None)
+        for tool in tools:
+            prev_tool_call_id = tool.get("id", None)
+            if (
+                msg_tool_call_id
+                and prev_tool_call_id
+                and msg_tool_call_id == prev_tool_call_id
+            ):
+                name = tool.get("function", {}).get("name", "")
+                arguments_str = tool.get("function", {}).get("arguments", "")
+                if arguments_str is not None and len(arguments_str) > 0:
+                    arguments = json.loads(arguments_str)
+
+    if message["role"] == "function":
+        name = message.get("name")
+        cohere_tool_result: ToolResultObject = {
+            "call": CallObject(name=name, parameters=arguments),
+            "outputs": [content],
+        }
+        return cohere_tool_result
+    else:
+        # We can't determine from openai message format whether it's a successful or
+        # error call result so default to the successful result template
+        cohere_tool_result = {
+            "call": CallObject(name=name, parameters=arguments),
+            "outputs": [content],
+        }
+        return cohere_tool_result
+
+
+def get_all_tool_calls(messages: List) -> List:
+    """
+    Returns extracted list of `tool_calls`.
+    Done to handle openai no longer returning tool call 'name' in tool results.
+    """
+    tool_calls: List = []
+    for m in messages:
+        if m.get("tool_calls", None) is not None:
+            if isinstance(m["tool_calls"], list):
+                tool_calls.extend(m["tool_calls"])
+
+    return tool_calls
+
+
+def convert_to_cohere_tool_invoke(tool_calls: list) -> List[ToolCallObject]:
+    """
+    OpenAI tool invokes:
+    {
+        "role": "assistant",
+        "content": null,
+        "tool_calls": [
+            {
+                "id": "call_abc123",
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "arguments": "{\n\"location\": \"Boston, MA\"\n}"
+                }
+            }
+        ]
+    },
+    """
+
+    """
+    Cohere tool invokes:
+    {
+        "role": "CHATBOT",
+        "tool_calls": [{"name": "get_weather", "parameters": {"location": "San Francisco, CA"}}]
+    }
+    """
+    cohere_tool_invoke: List[ToolCallObject] = [
+        {
+            "name": get_attribute_or_key(
+                get_attribute_or_key(tool, "function"), "name"
+            ),
+            "parameters": json.loads(
+                get_attribute_or_key(
+                    get_attribute_or_key(tool, "function"), "arguments"
+                )
+            ),
+        }
+        for tool in tool_calls
+        if get_attribute_or_key(tool, "type") == "function"
+    ]
+
+    return cohere_tool_invoke
+
+
+def cohere_messages_pt_v2(
+    messages: List,
+) -> Tuple[Union[str, ToolResultObject], ChatHistory]:
+    """
+    Returns a tuple(Union[tool_result, message], chat_history)
+
+    - if last message is tool result -> return 'tool_result'
+    - if last message is text -> return message (str)
+    - return preceding messages as 'chat_history'
+
+    Note:
+    - cannot specify message if the last entry in chat history contains tool results
+    - message must be at least 1 token long or tool results must be specified.
+    """
+    tool_calls: List = get_all_tool_calls(messages=messages)
+
+    ## GET MOST RECENT MESSAGE
+    most_recent_message = messages.pop(-1)
+    returned_message: Union[ToolResultObject, str] = ""
+    if (
+        most_recent_message.get("role", "") is not None
+        and most_recent_message["role"] == "tool"
+    ):
+        # tool result
+        returned_message = convert_openai_message_to_cohere_tool_result(
+            most_recent_message, tool_calls
+        )
+    else:
+        content: Union[str, List] = most_recent_message.get("content")
+        if isinstance(content, str):
+            returned_message = content
+        else:
+            for chunk in content:
+                if chunk.get("type") == "text":
+                    returned_message += chunk.get("text")
+
+    ## CREATE CHAT HISTORY
+    user_message_types = {"user"}
+    tool_message_types = {"tool", "function"}
+    # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, merge them.
+    new_messages: ChatHistory = []
+    msg_i = 0
+
+    while msg_i < len(messages):
+        user_content: str = ""
+        init_msg_i = msg_i
+        ## MERGE CONSECUTIVE USER CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
+            if isinstance(messages[msg_i]["content"], list):
+                for m in messages[msg_i]["content"]:
+                    if m.get("type", "") == "text":
+                        user_content += m["text"]
+            else:
+                user_content += messages[msg_i]["content"]
+            msg_i += 1
+
+        if len(user_content) > 0:
+            new_messages.append(ChatHistoryUser(role="USER", message=user_content))
+
+        system_content: str = ""
+        ## MERGE CONSECUTIVE SYSTEM CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "system":
+            if isinstance(messages[msg_i]["content"], list):
+                for m in messages[msg_i]["content"]:
+                    if m.get("type", "") == "text":
+                        system_content += m["text"]
+            else:
+                system_content += messages[msg_i]["content"]
+            msg_i += 1
+
+        if len(system_content) > 0:
+            new_messages.append(
+                ChatHistorySystem(role="SYSTEM", message=system_content)
+            )
+
+        assistant_content: str = ""
+        assistant_tool_calls: List[ToolCallObject] = []
+        ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+        while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+            assistant_text = (
+                messages[msg_i].get("content") or ""
+            )  # either string or none
+            if assistant_text:
+                assistant_content += assistant_text
+
+            if messages[msg_i].get(
+                "tool_calls", []
+            ):  # support assistant tool invoke conversion
+                assistant_tool_calls.extend(
+                    convert_to_cohere_tool_invoke(messages[msg_i]["tool_calls"])
+                )
+
+            if messages[msg_i].get("function_call"):
+                assistant_tool_calls.extend(
+                    convert_to_cohere_tool_invoke(messages[msg_i]["function_call"])
+                )
+
+            msg_i += 1
+
+        if len(assistant_content) > 0:
+            new_messages.append(
+                ChatHistoryChatBot(
+                    role="CHATBOT",
+                    message=assistant_content,
+                    tool_calls=assistant_tool_calls,
+                )
+            )
+
+        ## MERGE CONSECUTIVE TOOL RESULTS
+        tool_results: List[ToolResultObject] = []
+        while msg_i < len(messages) and messages[msg_i]["role"] in tool_message_types:
+            tool_results.append(
+                convert_openai_message_to_cohere_tool_result(
+                    messages[msg_i], tool_calls
+                )
+            )
+            msg_i += 1
+
+        if len(tool_results) > 0:
+            new_messages.append(
+                ChatHistoryToolResult(role="TOOL", tool_results=tool_results)
+            )
+
+        if msg_i == init_msg_i:  # prevent infinite loops
+            raise Exception(
+                "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                    messages[msg_i]
+                )
+            )
+
+    return returned_message, new_messages
+
+
 def cohere_message_pt(messages: list):
+    tool_calls: List = get_all_tool_calls(messages=messages)
     prompt = ""
     tool_results = []
     for message in messages:
         # check if this is a tool_call result
         if message["role"] == "tool":
-            tool_result = convert_openai_message_to_cohere_tool_result(message)
+            tool_result = convert_openai_message_to_cohere_tool_result(
+                message, tool_calls=tool_calls
+            )
             tool_results.append(tool_result)
         elif message.get("content"):
             prompt += message["content"] + "\n\n"
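
To make the new conversion concrete, here is a small worked example (not part of the diff) tracing cohere_messages_pt_v2 over a typical tool-calling conversation. The expected values in the comments follow from the code above: the tool message's content is not valid JSON, so it is wrapped as {"result": ...}, and the name/parameters are recovered by matching tool_call_id against the assistant's earlier tool_calls.

messages = [
    {"role": "user", "content": "What's the weather in SF?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"location": "San Francisco, CA"}',
                },
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_123", "content": "72F and sunny"},
]

most_recent_message, chat_history = cohere_messages_pt_v2(messages=messages)

# most_recent_message is a Cohere tool_result:
# {"call": {"name": "get_weather", "parameters": {"location": "San Francisco, CA"}},
#  "outputs": [{"result": "72F and sunny"}]}

# chat_history keeps the preceding turns; note the assistant turn is dropped
# here because its text content is empty (only non-empty text appends CHATBOT):
# [{"role": "USER", "message": "What's the weather in SF?"}]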


@@ -1121,7 +1121,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
     assert "hello" in mock_call.call_args.kwargs["headers"]
 
-@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
+# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
 @pytest.mark.parametrize("sync_mode", [True])
 @pytest.mark.parametrize("provider", ["vertex_ai"])
 @pytest.mark.asyncio
@@ -1159,7 +1159,7 @@ async def test_gemini_pro_function_calling(provider, sync_mode):
             # The result of the tool call is added to the history
             {
                 "role": "tool",
                 "tool_call_id": "call_123",
                 "content": "27 degrees celsius and clear in San Francisco, CA",
             },
             # Now the assistant can reply with the result of the tool call.
@@ -1381,6 +1381,7 @@ async def test_vertexai_aembedding():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 @pytest.mark.asyncio
 def test_tool_name_conversion():
     messages = [
@@ -1424,7 +1425,8 @@ def test_tool_name_conversion():
     # assert that the last tool response has the corresponding tool name
     assert (
-        translated_messages[-1]["parts"][0]["function_response"]["name"] == "get_weather"
+        translated_messages[-1]["parts"][0]["function_response"]["name"]
+        == "get_weather"
     )
@@ -1585,6 +1587,7 @@ def test_prompt_factory():
     print(f"\n\ntranslated_messages: {translated_messages}\ntranslated_messages")
 
+
 def test_prompt_factory_nested():
     messages = [
         {"role": "user", "content": [{"type": "text", "text": "hi"}]},
@@ -1606,4 +1609,4 @@ def test_prompt_factory_nested():
         assert "text" in message["parts"][0], "Missing 'text' from 'parts'"
         assert isinstance(
             message["parts"][0]["text"], str
         ), "'text' value not a string."


@@ -408,6 +408,97 @@ def test_completion_claude_3_function_call(model):
         pytest.fail(f"Error occurred: {e}")
 
+
+@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "gpt-3.5-turbo",
+        "claude-3-opus-20240229",
+        "command-r",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        # "azure_ai/command-r-plus"
+    ],
+)
+@pytest.mark.asyncio
+async def test_model_function_invoke(model, sync_mode):
+    try:
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+            # Assistant replies with a tool call
+            {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "id": "call_123",
+                        "type": "function",
+                        "index": 0,
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": '{"location":"San Francisco, CA"}',
+                        },
+                    }
+                ],
+            },
+            # The result of the tool call is added to the history
+            {
+                "role": "tool",
+                "tool_call_id": "call_123",
+                "content": "27 degrees celsius and clear in San Francisco, CA",
+            },
+            # Now the assistant can reply with the result of the tool call.
+        ]
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            }
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+
+        data = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        if sync_mode:
+            response = litellm.completion(**data)
+        else:
+            response = await litellm.acompletion(**data)
+
+        print(f"response: {response}")
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        if "429 Quota exceeded" in str(e):
+            pass
+        else:
+            pytest.fail("An unexpected exception occurred - {}".format(str(e)))
+
+
 @pytest.mark.asyncio
 async def test_anthropic_no_content_error():
     """


@@ -0,0 +1,46 @@
+from typing import Iterable, List, Optional, Union
+
+from typing_extensions import Literal, Required, TypedDict
+
+
+class CallObject(TypedDict):
+    name: str
+    parameters: dict
+
+
+class ToolResultObject(TypedDict):
+    call: CallObject
+    outputs: List[dict]
+
+
+class ChatHistoryToolResult(TypedDict, total=False):
+    role: Required[Literal["TOOL"]]
+    tool_results: List[ToolResultObject]
+
+
+class ToolCallObject(TypedDict):
+    name: str
+    parameters: dict
+
+
+class ChatHistoryUser(TypedDict, total=False):
+    role: Required[Literal["USER"]]
+    message: str
+    tool_calls: List[ToolCallObject]
+
+
+class ChatHistorySystem(TypedDict, total=False):
+    role: Required[Literal["SYSTEM"]]
+    message: str
+    tool_calls: List[ToolCallObject]
+
+
+class ChatHistoryChatBot(TypedDict, total=False):
+    role: Required[Literal["CHATBOT"]]
+    message: str
+    tool_calls: List[ToolCallObject]
+
+
+ChatHistory = List[
+    Union[ChatHistorySystem, ChatHistoryChatBot, ChatHistoryUser, ChatHistoryToolResult]
+]
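
As a usage note, a short sketch (not part of the diff) showing how these TypedDicts compose into a ChatHistory. Since the ChatHistory* classes are declared with total=False, message and tool_calls may be omitted, while role is Required:

history: ChatHistory = [
    ChatHistorySystem(role="SYSTEM", message="You are a helpful assistant."),
    ChatHistoryUser(role="USER", message="What's the weather in Boston?"),
    ChatHistoryChatBot(
        role="CHATBOT",
        message="Calling the weather tool.",
        tool_calls=[
            ToolCallObject(name="get_weather", parameters={"location": "Boston, MA"})
        ],
    ),
    ChatHistoryToolResult(
        role="TOOL",
        tool_results=[
            ToolResultObject(
                call=CallObject(name="get_weather", parameters={"location": "Boston, MA"}),
                outputs=[{"temperature": "27C"}],
            )
        ],
    ),
]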