fix(cohere.py): fix message parsing to handle tool calling correctly

This commit is contained in:
Krrish Dholakia 2024-07-04 11:13:07 -07:00
parent 4606b020b5
commit cceb7b59db
5 changed files with 426 additions and 35 deletions

View file

@ -1,13 +1,19 @@
import os, types
import json
import os
import time
import traceback
import types
from enum import Enum
import requests # type: ignore
import time, traceback
from typing import Callable, Optional
from litellm.utils import ModelResponse, Choices, Message, Usage
import litellm
import httpx # type: ignore
from .prompt_templates.factory import cohere_message_pt
import requests # type: ignore
import litellm
from litellm.types.llms.cohere import ToolResultObject
from litellm.utils import Choices, Message, ModelResponse, Usage
from .prompt_templates.factory import cohere_message_pt, cohere_messages_pt_v2
class CohereError(Exception):
@ -112,7 +118,7 @@ class CohereChatConfig:
def validate_environment(api_key):
headers = {
"Request-Source":"unspecified:litellm",
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
@ -196,17 +202,17 @@ def completion(
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
completion_url = api_base
model = model
prompt, tool_results = cohere_message_pt(messages=messages)
most_recent_message, chat_history = cohere_messages_pt_v2(messages=messages)
## Load Config
config = litellm.CohereConfig.get_config()
@ -221,18 +227,18 @@ def completion(
_is_function_call = True
cohere_tools = construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if len(tool_results) > 0:
optional_params["tool_results"] = tool_results
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message
data = {
"model": model,
"message": prompt,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
input=most_recent_message,
api_key=api_key,
additional_args={
"complete_input_dict": data,
@ -256,7 +262,7 @@ def completion(
else:
## LOGGING
logging_obj.post_call(
input=prompt,
input=most_recent_message,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},