Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 18:24:20 +00:00
Handle fireworks ai tool calling response (#10130)
* feat(fireworks_ai/chat): handle tool calling with fireworks ai correctly. Fixes https://github.com/BerriAI/litellm/issues/7209
* fix(utils.py): handle none type in message
* fix: fix model name in test
* fix(utils.py): fix validate check for openai messages
* fix: fix model returned
* fix(main.py): fix text completion routing
* test: update testing
* test: skip test - cohere having RBAC issues
This commit is contained in:
parent b4f2b3dad1
commit 2508ca71cb
9 changed files with 242 additions and 74 deletions
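
For orientation, here is a rough sketch of the call pattern this commit fixes (the model id, tool schema, and prompt are illustrative examples, not taken from the diff). Before the change, Fireworks AI could return a tool call as a JSON string inside message.content; with the change, the same call surfaces a structured entry in message.tool_calls.

import litellm

# Illustrative only: requires FIREWORKS_AI_API_KEY; the model id is an example.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=tools,
)

# With this commit, a tool call emitted by Fireworks inside the content field
# is returned as a structured tool call instead.
print(response.choices[0].message.tool_calls)
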
@@ -229,13 +229,17 @@ class BaseLLMHTTPHandler:
        api_key: Optional[str] = None,
        headers: Optional[dict] = {},
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        provider_config: Optional[BaseConfig] = None,
    ):
        json_mode: bool = optional_params.pop("json_mode", False)
        extra_body: Optional[dict] = optional_params.pop("extra_body", None)
        fake_stream = fake_stream or optional_params.pop("fake_stream", False)

        provider_config = ProviderConfigManager.get_provider_chat_config(
            model=model, provider=litellm.LlmProviders(custom_llm_provider)
        provider_config = (
            provider_config
            or ProviderConfigManager.get_provider_chat_config(
                model=model, provider=litellm.LlmProviders(custom_llm_provider)
            )
        )
        if provider_config is None:
            raise ValueError(

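The handler change above boils down to a fallback: prefer a provider_config injected by the caller, otherwise resolve one via ProviderConfigManager, and fail loudly if neither yields a config. A minimal standalone sketch of that pattern, using generic placeholder names rather than litellm classes:

from typing import Optional

# Generic placeholder type, not a litellm class.
class ProviderConfig:
    pass

def resolve_config(model: str) -> Optional[ProviderConfig]:
    # Stand-in for ProviderConfigManager.get_provider_chat_config(...)
    return ProviderConfig()

def handle(model: str, provider_config: Optional[ProviderConfig] = None) -> ProviderConfig:
    # Prefer the config injected by the caller; fall back to lookup.
    provider_config = provider_config or resolve_config(model)
    if provider_config is None:
        raise ValueError(f"Provider config not found for model: {model}")
    return provider_config

print(handle("example-model"))  # resolved via the fallback path
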
@@ -1,15 +1,33 @@
from typing import List, Literal, Optional, Tuple, Union, cast
import json
import uuid
from typing import Any, List, Literal, Optional, Tuple, Union, cast

import httpx

import litellm
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
    get_response_headers,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionImageObject,
    ChatCompletionToolParam,
    OpenAIChatCompletionToolParam,
)
from litellm.types.utils import ProviderSpecificModelInfo
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    Choices,
    Function,
    Message,
    ModelResponse,
    ProviderSpecificModelInfo,
)

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import FireworksAIException


class FireworksAIConfig(OpenAIGPTConfig):

@@ -219,6 +237,94 @@ class FireworksAIConfig(OpenAIGPTConfig):
            headers=headers,
        )

    def _handle_message_content_with_tool_calls(
        self,
        message: Message,
        tool_calls: Optional[List[ChatCompletionToolParam]],
    ) -> Message:
        """
        Fireworks AI sends tool calls in the content field instead of tool_calls

        Relevant Issue: https://github.com/BerriAI/litellm/issues/7209#issuecomment-2813208780
        """
        if (
            tool_calls is not None
            and message.content is not None
            and message.tool_calls is None
        ):
            try:
                function = Function(**json.loads(message.content))
                if function.name != RESPONSE_FORMAT_TOOL_NAME and function.name in [
                    tool["function"]["name"] for tool in tool_calls
                ]:
                    tool_call = ChatCompletionMessageToolCall(
                        function=function, id=str(uuid.uuid4()), type="function"
                    )
                    message.tool_calls = [tool_call]

                message.content = None
            except Exception:
                pass

        return message

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = raw_response.json()
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise FireworksAIException(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )

        raw_response_headers = dict(raw_response.headers)

        additional_headers = get_response_headers(raw_response_headers)

        response = ModelResponse(**completion_response)

        if response.model is not None:
            response.model = "fireworks_ai/" + response.model

        ## FIREWORKS AI sends tool calls in the content field instead of tool_calls
        for choice in response.choices:
            cast(
                Choices, choice
            ).message = self._handle_message_content_with_tool_calls(
                message=cast(Choices, choice).message,
                tool_calls=optional_params.get("tools", None),
            )

        response._hidden_params = {"additional_headers": additional_headers}

        return response

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:

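To make the _handle_message_content_with_tool_calls helper above concrete, here is a standalone sketch of the same detection using only the standard library. The content string mirrors the shape Fireworks returns; the requested tool list is an invented example:

import json
import uuid

# Example payload: Fireworks can return the tool call as a JSON string in `content`.
content = '{"type": "function", "name": "get_current_weather", "parameters": {"location": "Boston, MA"}}'
requested_tools = [{"type": "function", "function": {"name": "get_current_weather"}}]

parsed = json.loads(content)
requested_names = [tool["function"]["name"] for tool in requested_tools]

if parsed.get("name") in requested_names:
    # Promote the content into a structured tool call, as the helper does.
    tool_call = {
        "id": str(uuid.uuid4()),
        "type": "function",
        "function": {
            "name": parsed["name"],
            "arguments": json.dumps(parsed.get("parameters", {})),
        },
    }
    print(tool_call)
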
@@ -1435,6 +1435,7 @@ def completion(  # type: ignore # noqa: PLR0915
                custom_llm_provider=custom_llm_provider,
                encoding=encoding,
                stream=stream,
                provider_config=provider_config,
            )
        except Exception as e:
            ## LOGGING - log the original exception returned

@@ -1596,6 +1597,37 @@ def completion(  # type: ignore # noqa: PLR0915
                additional_args={"headers": headers},
            )
            response = _response
        elif custom_llm_provider == "fireworks_ai":
            ## COMPLETION CALL
            try:
                response = base_llm_http_handler.completion(
                    model=model,
                    messages=messages,
                    headers=headers,
                    model_response=model_response,
                    api_key=api_key,
                    api_base=api_base,
                    acompletion=acompletion,
                    logging_obj=logging,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    timeout=timeout,  # type: ignore
                    client=client,
                    custom_llm_provider=custom_llm_provider,
                    encoding=encoding,
                    stream=stream,
                    provider_config=provider_config,
                )
            except Exception as e:
                ## LOGGING - log the original exception returned
                logging.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=str(e),
                    additional_args={"headers": headers},
                )
                raise e

        elif custom_llm_provider == "groq":
            api_base = (
                api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
@@ -378,12 +378,18 @@ class Function(OpenAIObject):

    def __init__(
        self,
        arguments: Optional[Union[Dict, str]],
        arguments: Optional[Union[Dict, str]] = None,
        name: Optional[str] = None,
        **params,
    ):
        if arguments is None:
            arguments = ""
            if params.get("parameters", None) is not None and isinstance(
                params["parameters"], dict
            ):
                arguments = json.dumps(params["parameters"])
                params.pop("parameters")
            else:
                arguments = ""
        elif isinstance(arguments, Dict):
            arguments = json.dumps(arguments)
        else:

@@ -392,7 +398,7 @@ class Function(OpenAIObject):
        name = name

        # Build a dictionary with the structure your BaseModel expects
        data = {"arguments": arguments, "name": name, **params}
        data = {"arguments": arguments, "name": name}

        super(Function, self).__init__(**data)

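Given the Function changes above, constructing a Function from a Fireworks-style payload (where the call arguments arrive under a "parameters" key) should JSON-encode those parameters into the arguments field. A small sketch, assuming litellm.types.utils.Function behaves as shown in this diff:

import json

from litellm.types.utils import Function

# Fireworks-style payload: arguments arrive under "parameters" rather than "arguments".
fn = Function(
    name="get_current_weather",
    parameters={"location": "Boston, MA", "unit": "fahrenheit"},
)

# Per the updated __init__, the parameters dict is JSON-encoded into `arguments`.
assert json.loads(fn.arguments) == {"location": "Boston, MA", "unit": "fahrenheit"}
print(fn.arguments)
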
@@ -6264,24 +6264,27 @@ def validate_and_fix_openai_messages(messages: List):

    Handles missing role for assistant messages.
    """
    new_messages = []
    for message in messages:
        if not message.get("role"):
            message["role"] = "assistant"
        if message.get("tool_calls"):
            message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
    return validate_chat_completion_messages(messages=messages)

        convert_msg_to_dict = cast(AllMessageValues, convert_to_dict(message))
        cleaned_message = cleanup_none_field_in_message(message=convert_msg_to_dict)
        new_messages.append(cleaned_message)
    return validate_chat_completion_user_messages(messages=new_messages)


def validate_chat_completion_messages(messages: List[AllMessageValues]):
def cleanup_none_field_in_message(message: AllMessageValues):
    """
    Ensures all messages are valid OpenAI chat completion messages.
    Cleans up the message by removing the none field.

    remove None fields in the message - e.g. {"function": None} - some providers raise validation errors
    """
    # 1. convert all messages to dict
    messages = [
        cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages
    ]
    # 2. validate user messages
    return validate_chat_completion_user_messages(messages=messages)
    new_message = message.copy()
    return {k: v for k, v in new_message.items() if v is not None}


def validate_chat_completion_user_messages(messages: List[AllMessageValues]):

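A short sketch of what the new cleanup step does to an assistant message that carries explicit None fields (the message dict is an invented example; the dict comprehension mirrors cleanup_none_field_in_message as shown above):

# Example assistant message with explicit None fields, as some SDKs emit.
message = {
    "role": "assistant",
    "content": "The weather in Boston is sunny.",
    "function_call": None,
    "tool_calls": None,
}

# Mirrors cleanup_none_field_in_message: drop keys whose value is None so
# downstream providers do not fail validation on e.g. {"function_call": None}.
cleaned = {k: v for k, v in message.items() if v is not None}

print(cleaned)  # {'role': 'assistant', 'content': 'The weather in Boston is sunny.'}
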
@@ -0,0 +1,59 @@
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

sys.path.insert(
    0, os.path.abspath("../../../../..")
)  # Adds the parent directory to the system path

from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig
from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message


def test_handle_message_content_with_tool_calls():
    config = FireworksAIConfig()
    message = Message(
        content='{"type": "function", "name": "get_current_weather", "parameters": {"location": "Boston, MA", "unit": "fahrenheit"}}',
        role="assistant",
        tool_calls=None,
        function_call=None,
        provider_specific_fields=None,
    )
    expected_tool_call = ChatCompletionMessageToolCall(
        function=Function(**json.loads(message.content)), id=None, type=None
    )
    tool_calls = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    updated_message = config._handle_message_content_with_tool_calls(
        message, tool_calls
    )
    assert updated_message.tool_calls is not None
    assert len(updated_message.tool_calls) == 1
    assert updated_message.tool_calls[0].function.name == "get_current_weather"
    assert (
        updated_message.tool_calls[0].function.arguments
        == expected_tool_call.function.arguments
    )

@@ -896,6 +896,13 @@ class BaseLLMChatTest(ABC):
        assert response is not None

        # if the provider did not return any tool calls do not make a subsequent llm api call
        if response.choices[0].message.content is not None:
            try:
                json.loads(response.choices[0].message.content)
                pytest.fail(f"Tool call returned in content instead of tool_calls")
            except Exception as e:
                print(f"Error: {e}")
                pass
        if response.choices[0].message.tool_calls is None:
            return
        # Add any assertions here to check the response

@@ -1,6 +1,6 @@
import os
import sys

import json
import pytest

sys.path.insert(

@@ -93,57 +93,6 @@ class TestFireworksAIChatCompletion(BaseLLMChatTest):
        """
        pass

    @pytest.mark.parametrize(
        "response_format",
        [
            {"type": "json_object"},
            {"type": "text"},
        ],
    )
    @pytest.mark.flaky(retries=6, delay=1)
    def test_json_response_format(self, response_format):
        """
        Test that the JSON response format is supported by the LLM API
        """
        from litellm.utils import supports_response_schema
        from openai import OpenAI
        from unittest.mock import patch

        client = OpenAI()

        base_completion_call_args = self.get_base_completion_call_args()
        litellm.set_verbose = True

        messages = [
            {
                "role": "system",
                "content": "Your output should be a JSON object with no additional properties. ",
            },
            {
                "role": "user",
                "content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
            },
        ]

        with patch.object(
            client.chat.completions.with_raw_response, "create"
        ) as mock_post:
            response = self.completion_function(
                **base_completion_call_args,
                messages=messages,
                response_format=response_format,
                client=client,
            )

            mock_post.assert_called_once()
            if response_format["type"] == "json_object":
                assert (
                    mock_post.call_args.kwargs["response_format"]["type"]
                    == "json_object"
                )
            else:
                assert mock_post.call_args.kwargs["response_format"]["type"] == "text"


class TestFireworksAIAudioTranscription(BaseLLMAudioTranscriptionTest):
    def get_base_audio_transcription_call_args(self) -> dict:

@@ -253,14 +202,15 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch):
    from openai import OpenAI
    from unittest.mock import patch
    from litellm import completion
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()

    monkeypatch.setattr(litellm, "disable_add_transform_inline_image_block", True)

    client = OpenAI()

    with patch.object(
        client.chat.completions.with_raw_response,
        "create",
        client,
        "post",
    ) as mock_post:
        try:
            completion(

@@ -286,9 +236,10 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch):

    mock_post.assert_called_once()
    print(mock_post.call_args.kwargs)
    json_data = json.loads(mock_post.call_args.kwargs["data"])
    assert (
        "#transform=inline"
        not in mock_post.call_args.kwargs["messages"][0]["content"][1]["image_url"][
        not in json_data["messages"][0]["content"][1]["image_url"][
            "url"
        ]
    )

@@ -4163,7 +4163,7 @@ def test_completion_vllm(provider):


def test_completion_fireworks_ai_multiple_choices():
    litellm.set_verbose = True
    litellm._turn_on_debug()
    response = litellm.text_completion(
        model="fireworks_ai/llama-v3p1-8b-instruct",
        prompt=["halo", "hi", "halo", "hi"],