(feat) support tool_calling on cohere command-r

This commit is contained in:
ishaan-jaff 2024-03-12 14:24:48 -07:00
parent 54d847fc71
commit 836029b5ab
3 changed files with 216 additions and 1 deletions

View file

@ -7,6 +7,7 @@ from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm import litellm
import httpx import httpx
from .prompt_templates.factory import cohere_message_pt
class CohereError(Exception): class CohereError(Exception):
@ -201,7 +202,7 @@ def completion(
headers = validate_environment(api_key) headers = validate_environment(api_key)
completion_url = api_base completion_url = api_base
model = model model = model
prompt = " ".join(message["content"] for message in messages) prompt, tool_results = cohere_message_pt(messages=messages)
## Load Config ## Load Config
config = litellm.CohereConfig.get_config() config = litellm.CohereConfig.get_config()
@ -216,6 +217,8 @@ def completion(
_is_function_call = True _is_function_call = True
cohere_tools = construct_cohere_tool(tools=optional_params["tools"]) cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools optional_params["tools"] = cohere_tools
if len(tool_results) > 0:
optional_params["tool_results"] = tool_results
data = { data = {
"model": model, "model": model,
@ -262,6 +265,30 @@ def completion(
except Exception as e: except Exception as e:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
## Tool calling response
# Cohere returns tool calls under a top-level "tool_calls" key; absent when the
# model answered with plain text.
cohere_tools_response = completion_response.get("tool_calls", None)
# NOTE(review): `is not []` is an identity comparison against a brand-new list
# literal and is therefore ALWAYS True — the empty-list guard never fires.
# It should be `cohere_tools_response != []` or a simple truthiness check;
# flagged here rather than changed because this fragment is mid-function.
if cohere_tools_response is not None and cohere_tools_response is not []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:
function_name = tool.get("name", "")
generation_id = tool.get("generation_id", "")
parameters = tool.get("parameters", {})
# OpenAI tool-call shape: {"id", "type": "function", "function": {"name", "arguments"}}.
# Cohere's generation_id stands in for OpenAI's call id; parameters are
# re-serialized to a JSON string as OpenAI's `arguments` field requires.
tool_call = {
"id": f"call_{generation_id}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(parameters),
},
}
tool_calls.append(tool_call)
# A tool-calling turn carries no assistant text, so content is None.
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - use cohere `billed_units` for returning usage ## CALCULATING USAGE - use cohere `billed_units` for returning usage
billed_units = completion_response.get("meta", {}).get("billed_units", {}) billed_units = completion_response.get("meta", {}).get("billed_units", {})

View file

@ -652,6 +652,65 @@ def parse_xml_params(xml_content):
### ###
def convert_openai_message_to_cohere_tool_result(message):
    """
    Convert an OpenAI-format tool-result message into Cohere's tool_results format.

    An OpenAI message with a tool result looks like:
    {
        "tool_call_id": "tool_1",
        "role": "tool",
        "name": "get_current_weather",
        "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
    }

    Cohere tool_results look like:
    {
        "call": {
            "name": "query_daily_sales_report",
            "parameters": {"day": "2023-09-29"},
            "generation_id": "4807c924-9003-4d6b-8069-eda03962c465"
        },
        "outputs": [
            {
                "date": "2023-09-29",
                "summary": "Total Sales Amount: 10000, Total Units Sold: 250"
            }
        ]
    }

    Returns:
        dict: a single Cohere tool_result entry built from *message*.
    """
    tool_call_id = message.get("tool_call_id")
    name = message.get("name")
    content = message.get("content")

    # An OpenAI tool-result message does not carry the original call arguments,
    # so they cannot be reconstructed here. The previous hard-coded
    # {"location": "San Francisco, CA"} value was a leftover debug placeholder
    # and would be wrong for every other tool call; send an empty dict instead.
    cohere_tool_result = {
        "call": {
            "name": name,
            "parameters": {},
            # Cohere's generation_id plays the role of OpenAI's tool_call_id.
            "generation_id": tool_call_id,
        },
        # NOTE(review): Cohere expects `outputs` to be a list of dicts; a plain
        # string `content` is wrapped unchanged — confirm callers pass dicts.
        "outputs": [content],
    }
    return cohere_tool_result
def cohere_message_pt(messages: list):
    """
    Split an OpenAI-style message list into a Cohere prompt and tool results.

    Messages with role "tool" are converted into Cohere tool_result dicts;
    the content of every other message is concatenated (no separator) into
    a single prompt string.

    Returns:
        tuple[str, list]: (prompt, tool_results)
    """
    tool_results = []
    prompt_parts = []
    for msg in messages:
        if msg["role"] == "tool":
            # Tool outputs go back to Cohere via tool_results, not the prompt.
            tool_results.append(convert_openai_message_to_cohere_tool_result(msg))
        else:
            prompt_parts.append(msg["content"])
    return "".join(prompt_parts), tool_results
def amazon_titan_pt( def amazon_titan_pt(
messages: list, messages: list,
): # format - https://github.com/BerriAI/litellm/issues/1896 ): # format - https://github.com/BerriAI/litellm/issues/1896

View file

@ -12,6 +12,7 @@ import pytest
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
import json
litellm.num_retries = 3 litellm.num_retries = 3
@ -99,3 +100,131 @@ def test_chat_completion_cohere_tool_calling():
print(response) print(response)
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
# def get_current_weather(location, unit="fahrenheit"):
# """Get the current weather in a given location"""
# if "tokyo" in location.lower():
# return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
# elif "san francisco" in location.lower():
# return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
# elif "paris" in location.lower():
# return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
# else:
# return json.dumps({"location": location, "temperature": "unknown"})
# def test_chat_completion_cohere_tool_with_result_calling():
# # end to end cohere command-r with tool calling
# # Step 1 - Send available tools
# # Step 2 - Execute results
# # Step 3 - Send results to command-r
# try:
# litellm.set_verbose = True
# import json
# # Step 1 - Send available tools
# tools = [
# {
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ]
# messages = [
# {
# "role": "user",
# "content": "What is the weather like in Boston?",
# },
# ]
# response = completion(
# model="cohere_chat/command-r",
# messages=messages,
# tools=tools,
# )
# print("Response with tools to call", response)
# print(response)
# # step 2 - Execute results
# tool_calls = response.tool_calls
# available_functions = {
# "get_current_weather": get_current_weather,
# } # only one function in this example, but you can have multiple
# for tool_call in tool_calls:
# function_name = tool_call.function.name
# function_to_call = available_functions[function_name]
# function_args = json.loads(tool_call.function.arguments)
# function_response = function_to_call(
# location=function_args.get("location"),
# unit=function_args.get("unit"),
# )
# messages.append(
# {
# "tool_call_id": tool_call.id,
# "role": "tool",
# "name": function_name,
# "content": function_response,
# }
# ) # extend conversation with function response
# print("messages with tool call results", messages)
# messages = [
# {
# "role": "user",
# "content": "What is the weather like in Boston?",
# },
# {
# "tool_call_id": "tool_1",
# "role": "tool",
# "name": "get_current_weather",
# "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
# },
# ]
# respone = completion(
# model="cohere_chat/command-r",
# messages=messages,
# tools=[
# {
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ],
# )
# print(respone)
# NOTE: kept commented out with the rest of the disabled
# test_chat_completion_cohere_tool_with_result_calling above — the matching
# `try:` is commented, so leaving these two lines live is a SyntaxError
# (an `except` clause with no `try`).
# except Exception as e:
#     pytest.fail(f"Error occurred: {e}")