fix(cohere.py): fix message parsing to handle tool calling correctly

This commit is contained in:
Krrish Dholakia 2024-07-04 11:13:07 -07:00
parent 4606b020b5
commit cceb7b59db
5 changed files with 426 additions and 35 deletions

View file

@ -408,6 +408,97 @@ def test_completion_claude_3_function_call(model):
pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True])
@pytest.mark.parametrize(
"model",
[
"gpt-3.5-turbo",
"claude-3-opus-20240229",
"command-r",
"anthropic.claude-3-sonnet-20240229-v1:0",
# "azure_ai/command-r-plus"
],
)
@pytest.mark.asyncio
async def test_model_function_invoke(model, sync_mode):
try:
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
# Assistant replies with a tool call
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_123",
"type": "function",
"index": 0,
"function": {
"name": "get_weather",
"arguments": '{"location":"San Francisco, CA"}',
},
}
],
},
# The result of the tool call is added to the history
{
"role": "tool",
"tool_call_id": "call_123",
"content": "27 degrees celsius and clear in San Francisco, CA",
},
# Now the assistant can reply with the result of the tool call.
]
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
}
},
"required": ["location"],
},
},
}
]
data = {
"model": model,
"messages": messages,
"tools": tools,
}
if sync_mode:
response = litellm.completion(**data)
else:
response = await litellm.acompletion(**data)
print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except Exception as e:
if "429 Quota exceeded" in str(e):
pass
else:
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
@pytest.mark.asyncio
async def test_anthropic_no_content_error():
"""