mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
(feat) support tool_calling on cohere command-r
This commit is contained in:
parent
54d847fc71
commit
836029b5ab
3 changed files with 216 additions and 1 deletions
|
@ -7,6 +7,7 @@ from typing import Callable, Optional
|
|||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
import httpx
|
||||
from .prompt_templates.factory import cohere_message_pt
|
||||
|
||||
|
||||
class CohereError(Exception):
|
||||
|
@ -201,7 +202,7 @@ def completion(
|
|||
headers = validate_environment(api_key)
|
||||
completion_url = api_base
|
||||
model = model
|
||||
prompt = " ".join(message["content"] for message in messages)
|
||||
prompt, tool_results = cohere_message_pt(messages=messages)
|
||||
|
||||
## Load Config
|
||||
config = litellm.CohereConfig.get_config()
|
||||
|
@ -216,6 +217,8 @@ def completion(
|
|||
_is_function_call = True
|
||||
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if len(tool_results) > 0:
|
||||
optional_params["tool_results"] = tool_results
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
|
@ -262,6 +265,30 @@ def completion(
|
|||
except Exception as e:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = completion_response.get("tool_calls", None)
|
||||
if cohere_tools_response is not None and cohere_tools_response is not []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls = []
|
||||
for tool in cohere_tools_response:
|
||||
function_name = tool.get("name", "")
|
||||
generation_id = tool.get("generation_id", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
tool_call = {
|
||||
"id": f"call_{generation_id}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(parameters),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
billed_units = completion_response.get("meta", {}).get("billed_units", {})
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue