mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(cohere.py): fix message parsing to handle tool calling correctly
This commit is contained in:
parent
4606b020b5
commit
cceb7b59db
5 changed files with 426 additions and 35 deletions
|
@ -1,13 +1,19 @@
|
|||
import os, types
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
from enum import Enum
|
||||
import requests # type: ignore
|
||||
import time, traceback
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Choices, Message, Usage
|
||||
import litellm
|
||||
|
||||
import httpx # type: ignore
|
||||
from .prompt_templates.factory import cohere_message_pt
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.cohere import ToolResultObject
|
||||
from litellm.utils import Choices, Message, ModelResponse, Usage
|
||||
|
||||
from .prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
|
||||
|
||||
|
||||
class CohereError(Exception):
|
||||
|
@ -112,7 +118,7 @@ class CohereChatConfig:
|
|||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
"Request-Source":"unspecified:litellm",
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
@ -196,17 +202,17 @@ def completion(
|
|||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
optional_params: dict,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
headers = validate_environment(api_key)
|
||||
completion_url = api_base
|
||||
model = model
|
||||
prompt, tool_results = cohere_message_pt(messages=messages)
|
||||
most_recent_message, chat_history = cohere_messages_pt_v2(messages=messages)
|
||||
|
||||
## Load Config
|
||||
config = litellm.CohereConfig.get_config()
|
||||
|
@ -221,18 +227,18 @@ def completion(
|
|||
_is_function_call = True
|
||||
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if len(tool_results) > 0:
|
||||
optional_params["tool_results"] = tool_results
|
||||
|
||||
if isinstance(most_recent_message, dict):
|
||||
optional_params["tool_results"] = [most_recent_message]
|
||||
elif isinstance(most_recent_message, str):
|
||||
optional_params["message"] = most_recent_message
|
||||
data = {
|
||||
"model": model,
|
||||
"message": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
input=most_recent_message,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
|
@ -256,7 +262,7 @@ def completion(
|
|||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
input=most_recent_message,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue