mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-28 04:04:31 +00:00
(feat) support tool_calling on cohere command-r
This commit is contained in:
parent
54d847fc71
commit
836029b5ab
3 changed files with 216 additions and 1 deletions
|
@ -7,6 +7,7 @@ from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||||
import litellm
|
import litellm
|
||||||
import httpx
|
import httpx
|
||||||
|
from .prompt_templates.factory import cohere_message_pt
|
||||||
|
|
||||||
|
|
||||||
class CohereError(Exception):
|
class CohereError(Exception):
|
||||||
|
@ -201,7 +202,7 @@ def completion(
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key)
|
||||||
completion_url = api_base
|
completion_url = api_base
|
||||||
model = model
|
model = model
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
prompt, tool_results = cohere_message_pt(messages=messages)
|
||||||
|
|
||||||
## Load Config
|
## Load Config
|
||||||
config = litellm.CohereConfig.get_config()
|
config = litellm.CohereConfig.get_config()
|
||||||
|
@ -216,6 +217,8 @@ def completion(
|
||||||
_is_function_call = True
|
_is_function_call = True
|
||||||
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
|
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
|
||||||
optional_params["tools"] = cohere_tools
|
optional_params["tools"] = cohere_tools
|
||||||
|
if len(tool_results) > 0:
|
||||||
|
optional_params["tool_results"] = tool_results
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
|
@ -262,6 +265,30 @@ def completion(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise CohereError(message=response.text, status_code=response.status_code)
|
raise CohereError(message=response.text, status_code=response.status_code)
|
||||||
|
|
||||||
|
## Tool calling response
|
||||||
|
cohere_tools_response = completion_response.get("tool_calls", None)
|
||||||
|
if cohere_tools_response is not None and cohere_tools_response is not []:
|
||||||
|
# convert cohere_tools_response to OpenAI response format
|
||||||
|
tool_calls = []
|
||||||
|
for tool in cohere_tools_response:
|
||||||
|
function_name = tool.get("name", "")
|
||||||
|
generation_id = tool.get("generation_id", "")
|
||||||
|
parameters = tool.get("parameters", {})
|
||||||
|
tool_call = {
|
||||||
|
"id": f"call_{generation_id}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": function_name,
|
||||||
|
"arguments": json.dumps(parameters),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
_message = litellm.Message(
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
content=None,
|
||||||
|
)
|
||||||
|
model_response.choices[0].message = _message # type: ignore
|
||||||
|
|
||||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||||
billed_units = completion_response.get("meta", {}).get("billed_units", {})
|
billed_units = completion_response.get("meta", {}).get("billed_units", {})
|
||||||
|
|
||||||
|
|
|
@ -652,6 +652,65 @@ def parse_xml_params(xml_content):
|
||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
|
def convert_openai_message_to_cohere_tool_result(message):
|
||||||
|
"""
|
||||||
|
OpenAI message with a tool result looks like:
|
||||||
|
{
|
||||||
|
"tool_call_id": "tool_1",
|
||||||
|
"role": "tool",
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
|
||||||
|
},
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Cohere tool_results look like:
|
||||||
|
{
|
||||||
|
"call": {
|
||||||
|
"name": "query_daily_sales_report",
|
||||||
|
"parameters": {
|
||||||
|
"day": "2023-09-29"
|
||||||
|
},
|
||||||
|
"generation_id": "4807c924-9003-4d6b-8069-eda03962c465"
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"date": "2023-09-29",
|
||||||
|
"summary": "Total Sales Amount: 10000, Total Units Sold: 250"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"""
|
||||||
|
|
||||||
|
tool_call_id = message.get("tool_call_id")
|
||||||
|
name = message.get("name")
|
||||||
|
content = message.get("content")
|
||||||
|
|
||||||
|
# Create the Cohere tool_result dictionary
|
||||||
|
cohere_tool_result = {
|
||||||
|
"call": {
|
||||||
|
"name": name,
|
||||||
|
"parameters": {"location": "San Francisco, CA"},
|
||||||
|
"generation_id": tool_call_id,
|
||||||
|
},
|
||||||
|
"outputs": [content],
|
||||||
|
}
|
||||||
|
return cohere_tool_result
|
||||||
|
|
||||||
|
|
||||||
|
def cohere_message_pt(messages: list):
|
||||||
|
prompt = ""
|
||||||
|
tool_results = []
|
||||||
|
for message in messages:
|
||||||
|
# check if this is a tool_call result
|
||||||
|
if message["role"] == "tool":
|
||||||
|
tool_result = convert_openai_message_to_cohere_tool_result(message)
|
||||||
|
tool_results.append(tool_result)
|
||||||
|
else:
|
||||||
|
prompt += message["content"]
|
||||||
|
return prompt, tool_results
|
||||||
|
|
||||||
|
|
||||||
def amazon_titan_pt(
|
def amazon_titan_pt(
|
||||||
messages: list,
|
messages: list,
|
||||||
): # format - https://github.com/BerriAI/litellm/issues/1896
|
): # format - https://github.com/BerriAI/litellm/issues/1896
|
||||||
|
|
|
@ -12,6 +12,7 @@ import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
from litellm import RateLimitError
|
from litellm import RateLimitError
|
||||||
|
import json
|
||||||
|
|
||||||
litellm.num_retries = 3
|
litellm.num_retries = 3
|
||||||
|
|
||||||
|
@ -99,3 +100,131 @@ def test_chat_completion_cohere_tool_calling():
|
||||||
print(response)
|
print(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
# def get_current_weather(location, unit="fahrenheit"):
|
||||||
|
# """Get the current weather in a given location"""
|
||||||
|
# if "tokyo" in location.lower():
|
||||||
|
# return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
|
||||||
|
# elif "san francisco" in location.lower():
|
||||||
|
# return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
|
||||||
|
# elif "paris" in location.lower():
|
||||||
|
# return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
|
||||||
|
# else:
|
||||||
|
# return json.dumps({"location": location, "temperature": "unknown"})
|
||||||
|
|
||||||
|
# def test_chat_completion_cohere_tool_with_result_calling():
|
||||||
|
# # end to end cohere command-r with tool calling
|
||||||
|
# # Step 1 - Send available tools
|
||||||
|
# # Step 2 - Execute results
|
||||||
|
# # Step 3 - Send results to command-r
|
||||||
|
# try:
|
||||||
|
# litellm.set_verbose = True
|
||||||
|
# import json
|
||||||
|
|
||||||
|
# # Step 1 - Send available tools
|
||||||
|
# tools = [
|
||||||
|
# {
|
||||||
|
# "type": "function",
|
||||||
|
# "function": {
|
||||||
|
# "name": "get_current_weather",
|
||||||
|
# "description": "Get the current weather in a given location",
|
||||||
|
# "parameters": {
|
||||||
|
# "type": "object",
|
||||||
|
# "properties": {
|
||||||
|
# "location": {
|
||||||
|
# "type": "string",
|
||||||
|
# "description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
# },
|
||||||
|
# "unit": {
|
||||||
|
# "type": "string",
|
||||||
|
# "enum": ["celsius", "fahrenheit"],
|
||||||
|
# },
|
||||||
|
# },
|
||||||
|
# "required": ["location"],
|
||||||
|
# },
|
||||||
|
# },
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# messages = [
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": "What is the weather like in Boston?",
|
||||||
|
# },
|
||||||
|
# ]
|
||||||
|
# response = completion(
|
||||||
|
# model="cohere_chat/command-r",
|
||||||
|
# messages=messages,
|
||||||
|
# tools=tools,
|
||||||
|
# )
|
||||||
|
# print("Response with tools to call", response)
|
||||||
|
# print(response)
|
||||||
|
|
||||||
|
# # step 2 - Execute results
|
||||||
|
# tool_calls = response.tool_calls
|
||||||
|
|
||||||
|
# available_functions = {
|
||||||
|
# "get_current_weather": get_current_weather,
|
||||||
|
# } # only one function in this example, but you can have multiple
|
||||||
|
|
||||||
|
# for tool_call in tool_calls:
|
||||||
|
# function_name = tool_call.function.name
|
||||||
|
# function_to_call = available_functions[function_name]
|
||||||
|
# function_args = json.loads(tool_call.function.arguments)
|
||||||
|
# function_response = function_to_call(
|
||||||
|
# location=function_args.get("location"),
|
||||||
|
# unit=function_args.get("unit"),
|
||||||
|
# )
|
||||||
|
# messages.append(
|
||||||
|
# {
|
||||||
|
# "tool_call_id": tool_call.id,
|
||||||
|
# "role": "tool",
|
||||||
|
# "name": function_name,
|
||||||
|
# "content": function_response,
|
||||||
|
# }
|
||||||
|
# ) # extend conversation with function response
|
||||||
|
|
||||||
|
# print("messages with tool call results", messages)
|
||||||
|
|
||||||
|
# messages = [
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": "What is the weather like in Boston?",
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "tool_call_id": "tool_1",
|
||||||
|
# "role": "tool",
|
||||||
|
# "name": "get_current_weather",
|
||||||
|
# "content": {"location": "San Francisco, CA", "unit": "fahrenheit", "temperature": "72"},
|
||||||
|
# },
|
||||||
|
# ]
|
||||||
|
# respone = completion(
|
||||||
|
# model="cohere_chat/command-r",
|
||||||
|
# messages=messages,
|
||||||
|
# tools=[
|
||||||
|
# {
|
||||||
|
# "type": "function",
|
||||||
|
# "function": {
|
||||||
|
# "name": "get_current_weather",
|
||||||
|
# "description": "Get the current weather in a given location",
|
||||||
|
# "parameters": {
|
||||||
|
# "type": "object",
|
||||||
|
# "properties": {
|
||||||
|
# "location": {
|
||||||
|
# "type": "string",
|
||||||
|
# "description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
# },
|
||||||
|
# "unit": {
|
||||||
|
# "type": "string",
|
||||||
|
# "enum": ["celsius", "fahrenheit"],
|
||||||
|
# },
|
||||||
|
# },
|
||||||
|
# "required": ["location"],
|
||||||
|
# },
|
||||||
|
# },
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
# print(respone)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue