mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-28 04:04:31 +00:00
(feat) support tool_calling on cohere command-r
This commit is contained in:
parent
54d847fc71
commit
836029b5ab
3 changed files with 216 additions and 1 deletions
|
@ -7,6 +7,7 @@ from typing import Callable, Optional
|
|||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
import httpx
|
||||
from .prompt_templates.factory import cohere_message_pt
|
||||
|
||||
|
||||
class CohereError(Exception):
|
||||
|
@ -201,7 +202,7 @@ def completion(
|
|||
headers = validate_environment(api_key)
|
||||
completion_url = api_base
|
||||
model = model
|
||||
prompt = " ".join(message["content"] for message in messages)
|
||||
prompt, tool_results = cohere_message_pt(messages=messages)
|
||||
|
||||
## Load Config
|
||||
config = litellm.CohereConfig.get_config()
|
||||
|
@ -216,6 +217,8 @@ def completion(
|
|||
_is_function_call = True
|
||||
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if len(tool_results) > 0:
|
||||
optional_params["tool_results"] = tool_results
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
|
@ -262,6 +265,30 @@ def completion(
|
|||
except Exception as e:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = completion_response.get("tool_calls", None)
|
||||
if cohere_tools_response is not None and cohere_tools_response is not []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls = []
|
||||
for tool in cohere_tools_response:
|
||||
function_name = tool.get("name", "")
|
||||
generation_id = tool.get("generation_id", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
tool_call = {
|
||||
"id": f"call_{generation_id}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(parameters),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
billed_units = completion_response.get("meta", {}).get("billed_units", {})
|
||||
|
||||
|
|
|
@ -652,6 +652,65 @@ def parse_xml_params(xml_content):
|
|||
###
|
||||
|
||||
|
||||
def convert_openai_message_to_cohere_tool_result(message):
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
{
|
||||
"tool_call_id": "tool_1",
|
||||
"role": "tool",
|
||||
"name": "get_current_weather",
|
||||
"content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
|
||||
},
|
||||
"""
|
||||
|
||||
"""
|
||||
Cohere tool_results look like:
|
||||
{
|
||||
"call": {
|
||||
"name": "query_daily_sales_report",
|
||||
"parameters": {
|
||||
"day": "2023-09-29"
|
||||
},
|
||||
"generation_id": "4807c924-9003-4d6b-8069-eda03962c465"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"date": "2023-09-29",
|
||||
"summary": "Total Sales Amount: 10000, Total Units Sold: 250"
|
||||
}
|
||||
]
|
||||
},
|
||||
"""
|
||||
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
name = message.get("name")
|
||||
content = message.get("content")
|
||||
|
||||
# Create the Cohere tool_result dictionary
|
||||
cohere_tool_result = {
|
||||
"call": {
|
||||
"name": name,
|
||||
"parameters": {"location": "San Francisco, CA"},
|
||||
"generation_id": tool_call_id,
|
||||
},
|
||||
"outputs": [content],
|
||||
}
|
||||
return cohere_tool_result
|
||||
|
||||
|
||||
def cohere_message_pt(messages: list):
|
||||
prompt = ""
|
||||
tool_results = []
|
||||
for message in messages:
|
||||
# check if this is a tool_call result
|
||||
if message["role"] == "tool":
|
||||
tool_result = convert_openai_message_to_cohere_tool_result(message)
|
||||
tool_results.append(tool_result)
|
||||
else:
|
||||
prompt += message["content"]
|
||||
return prompt, tool_results
|
||||
|
||||
|
||||
def amazon_titan_pt(
|
||||
messages: list,
|
||||
): # format - https://github.com/BerriAI/litellm/issues/1896
|
||||
|
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
import json
|
||||
|
||||
litellm.num_retries = 3
|
||||
|
||||
|
@ -99,3 +100,131 @@ def test_chat_completion_cohere_tool_calling():
|
|||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# def get_current_weather(location, unit="fahrenheit"):
|
||||
# """Get the current weather in a given location"""
|
||||
# if "tokyo" in location.lower():
|
||||
# return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
|
||||
# elif "san francisco" in location.lower():
|
||||
# return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
|
||||
# elif "paris" in location.lower():
|
||||
# return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
|
||||
# else:
|
||||
# return json.dumps({"location": location, "temperature": "unknown"})
|
||||
|
||||
# def test_chat_completion_cohere_tool_with_result_calling():
|
||||
# # end to end cohere command-r with tool calling
|
||||
# # Step 1 - Send available tools
|
||||
# # Step 2 - Execute results
|
||||
# # Step 3 - Send results to command-r
|
||||
# try:
|
||||
# litellm.set_verbose = True
|
||||
# import json
|
||||
|
||||
# # Step 1 - Send available tools
|
||||
# tools = [
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_current_weather",
|
||||
# "description": "Get the current weather in a given location",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state, e.g. San Francisco, CA",
|
||||
# },
|
||||
# "unit": {
|
||||
# "type": "string",
|
||||
# "enum": ["celsius", "fahrenheit"],
|
||||
# },
|
||||
# },
|
||||
# "required": ["location"],
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
# ]
|
||||
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "What is the weather like in Boston?",
|
||||
# },
|
||||
# ]
|
||||
# response = completion(
|
||||
# model="cohere_chat/command-r",
|
||||
# messages=messages,
|
||||
# tools=tools,
|
||||
# )
|
||||
# print("Response with tools to call", response)
|
||||
# print(response)
|
||||
|
||||
# # step 2 - Execute results
|
||||
# tool_calls = response.tool_calls
|
||||
|
||||
# available_functions = {
|
||||
# "get_current_weather": get_current_weather,
|
||||
# } # only one function in this example, but you can have multiple
|
||||
|
||||
# for tool_call in tool_calls:
|
||||
# function_name = tool_call.function.name
|
||||
# function_to_call = available_functions[function_name]
|
||||
# function_args = json.loads(tool_call.function.arguments)
|
||||
# function_response = function_to_call(
|
||||
# location=function_args.get("location"),
|
||||
# unit=function_args.get("unit"),
|
||||
# )
|
||||
# messages.append(
|
||||
# {
|
||||
# "tool_call_id": tool_call.id,
|
||||
# "role": "tool",
|
||||
# "name": function_name,
|
||||
# "content": function_response,
|
||||
# }
|
||||
# ) # extend conversation with function response
|
||||
|
||||
# print("messages with tool call results", messages)
|
||||
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "What is the weather like in Boston?",
|
||||
# },
|
||||
# {
|
||||
# "tool_call_id": "tool_1",
|
||||
# "role": "tool",
|
||||
# "name": "get_current_weather",
|
||||
# "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
|
||||
# },
|
||||
# ]
|
||||
# respone = completion(
|
||||
# model="cohere_chat/command-r",
|
||||
# messages=messages,
|
||||
# tools=[
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_current_weather",
|
||||
# "description": "Get the current weather in a given location",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state, e.g. San Francisco, CA",
|
||||
# },
|
||||
# "unit": {
|
||||
# "type": "string",
|
||||
# "enum": ["celsius", "fahrenheit"],
|
||||
# },
|
||||
# },
|
||||
# "required": ["location"],
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
# ],
|
||||
# )
|
||||
# print(respone)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue