(feat) support tool_calling on cohere command-r

ishaan-jaff 2024-03-12 14:24:48 -07:00
parent 54d847fc71
commit 836029b5ab
3 changed files with 216 additions and 1 deletion
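
For context on what this commit enables end to end, here is a minimal usage sketch; the model string, provider prefix, and tool schema below are illustrative assumptions rather than values taken from this diff:

import litellm

# Hypothetical example: OpenAI-style tools passed through to a Cohere Command-R model.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="cohere_chat/command-r",  # assumed provider prefix / model string
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)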


@@ -7,6 +7,7 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, Choices, Message, Usage
 import litellm
 import httpx
+from .prompt_templates.factory import cohere_message_pt
 class CohereError(Exception):
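
The only change in this hunk is the new import of cohere_message_pt from the prompt-template factory. The factory implementation isn't shown here; as a rough, hypothetical sketch of the contract the call sites below rely on, the helper takes an OpenAI-style messages list and returns a single prompt string plus any tool-result messages split out for Cohere's chat API:

# Hypothetical sketch only; the real cohere_message_pt lives in
# prompt_templates/factory.py and may build these structures differently.
def cohere_message_pt_sketch(messages: list):
    prompt = ""
    tool_results = []
    for message in messages:
        if message.get("role") == "tool":
            # Tool outputs are collected separately so they can be sent as
            # `tool_results` on the Cohere request instead of as inline text.
            tool_results.append(
                {
                    "call": {"name": message.get("name", ""), "parameters": {}},
                    "outputs": [{"result": message.get("content", "")}],
                }
            )
        else:
            prompt += (message.get("content") or "") + " "
    return prompt.strip(), tool_results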
@@ -201,7 +202,7 @@ def completion(
     headers = validate_environment(api_key)
     completion_url = api_base
     model = model
-    prompt = " ".join(message["content"] for message in messages)
+    prompt, tool_results = cohere_message_pt(messages=messages)
     ## Load Config
     config = litellm.CohereConfig.get_config()
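
Previously the prompt was a plain space-join of every message's content; cohere_message_pt now also returns tool_results, which a later hunk attaches to the request via optional_params["tool_results"]. Assuming the call/outputs tool_results shape documented for Cohere's chat API, and request keys other than "model" that are not visible in this hunk, the outgoing payload would look roughly like:

# Illustrative request payload with made-up values; `tool_results` follows the
# call/outputs shape documented for Cohere's chat API.
data = {
    "model": "command-r",
    "message": "What is the weather in Boston?",
    "tools": [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameter_definitions": {
                "location": {"type": "str", "description": "City name", "required": True},
            },
        }
    ],
    "tool_results": [
        {
            "call": {"name": "get_current_weather", "parameters": {"location": "Boston"}},
            "outputs": [{"temperature": "22C"}],
        }
    ],
}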
@@ -216,6 +217,8 @@ def completion(
         _is_function_call = True
         cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
         optional_params["tools"] = cohere_tools
+    if len(tool_results) > 0:
+        optional_params["tool_results"] = tool_results
     data = {
         "model": model,
@@ -262,6 +265,30 @@ def completion(
         except Exception as e:
             raise CohereError(message=response.text, status_code=response.status_code)
+        ## Tool calling response
+        cohere_tools_response = completion_response.get("tool_calls", None)
+        if cohere_tools_response is not None and cohere_tools_response != []:
+            # convert cohere_tools_response to OpenAI response format
+            tool_calls = []
+            for tool in cohere_tools_response:
+                function_name = tool.get("name", "")
+                generation_id = tool.get("generation_id", "")
+                parameters = tool.get("parameters", {})
+                tool_call = {
+                    "id": f"call_{generation_id}",
+                    "type": "function",
+                    "function": {
+                        "name": function_name,
+                        "arguments": json.dumps(parameters),
+                    },
+                }
+                tool_calls.append(tool_call)
+            _message = litellm.Message(
+                tool_calls=tool_calls,
+                content=None,
+            )
+            model_response.choices[0].message = _message # type: ignore
         ## CALCULATING USAGE - use cohere `billed_units` for returning usage
         billed_units = completion_response.get("meta", {}).get("billed_units", {})
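
With the conversion loop above, Cohere tool calls are surfaced to callers in OpenAI's tool_calls shape. A small worked example of that mapping, using made-up values, followed by a note on the billed_units fields the usage block reads:

import json

# Made-up Cohere-style tool call, run through the same mapping as the loop above.
cohere_tool = {
    "name": "get_current_weather",
    "parameters": {"location": "Boston"},
    "generation_id": "abc123",
}
openai_tool_call = {
    "id": f"call_{cohere_tool['generation_id']}",
    "type": "function",
    "function": {
        "name": cohere_tool["name"],
        "arguments": json.dumps(cohere_tool["parameters"]),
    },
}
# openai_tool_call ==
# {"id": "call_abc123", "type": "function",
#  "function": {"name": "get_current_weather",
#               "arguments": "{\"location\": \"Boston\"}"}}

# The usage block reads Cohere's billed_units from the response metadata, e.g.
# {"meta": {"billed_units": {"input_tokens": 50, "output_tokens": 12}}}.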