Handle fireworks ai tool calling response (#10130)

* feat(fireworks_ai/chat): handle tool calling with fireworks ai correctly

Fixes https://github.com/BerriAI/litellm/issues/7209

* fix(utils.py): handle none type in message

* fix: fix model name in test

* fix(utils.py): fix validate check for openai messages

* fix: fix model returned

* fix(main.py): fix text completion routing

* test: update testing

* test: skip test - cohere having RBAC issues
Krish Dholakia, 2025-04-19 09:37:45 -07:00, committed by GitHub
parent b4f2b3dad1
commit 2508ca71cb
9 changed files with 242 additions and 74 deletions


@@ -229,13 +229,17 @@ class BaseLLMHTTPHandler:
         api_key: Optional[str] = None,
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        provider_config: Optional[BaseConfig] = None,
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
         extra_body: Optional[dict] = optional_params.pop("extra_body", None)
         fake_stream = fake_stream or optional_params.pop("fake_stream", False)
-        provider_config = ProviderConfigManager.get_provider_chat_config(
-            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        provider_config = (
+            provider_config
+            or ProviderConfigManager.get_provider_chat_config(
+                model=model, provider=litellm.LlmProviders(custom_llm_provider)
+            )
         )
         if provider_config is None:
             raise ValueError(
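
The new optional provider_config parameter lets the caller inject a pre-resolved config, keeping the registry lookup as a fallback. A minimal standalone sketch of that pattern (names and registry contents are placeholders, not litellm internals):

from typing import Optional

REGISTRY = {"fireworks_ai": {"api_base": "https://api.fireworks.ai/inference/v1"}}

def resolve_from_registry(provider: str) -> Optional[dict]:
    # Stand-in for ProviderConfigManager.get_provider_chat_config
    return REGISTRY.get(provider)

def get_config(provider: str, provider_config: Optional[dict] = None) -> dict:
    # An explicitly injected config wins; otherwise fall back to the registry.
    config = provider_config or resolve_from_registry(provider)
    if config is None:
        raise ValueError(f"no provider config for {provider}")
    return config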


@@ -1,15 +1,33 @@
-from typing import List, Literal, Optional, Tuple, Union, cast
+import json
+import uuid
+from typing import Any, List, Literal, Optional, Tuple, Union, cast
+
+import httpx

 import litellm
+from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.llm_response_utils.get_headers import (
+    get_response_headers,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionImageObject,
+    ChatCompletionToolParam,
     OpenAIChatCompletionToolParam,
 )
-from litellm.types.utils import ProviderSpecificModelInfo
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    Choices,
+    Function,
+    Message,
+    ModelResponse,
+    ProviderSpecificModelInfo,
+)

 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import FireworksAIException


 class FireworksAIConfig(OpenAIGPTConfig):

@@ -219,6 +237,94 @@ class FireworksAIConfig(OpenAIGPTConfig):
             headers=headers,
         )

+    def _handle_message_content_with_tool_calls(
+        self,
+        message: Message,
+        tool_calls: Optional[List[ChatCompletionToolParam]],
+    ) -> Message:
+        """
+        Fireworks AI sends tool calls in the content field instead of tool_calls
+
+        Relevant Issue: https://github.com/BerriAI/litellm/issues/7209#issuecomment-2813208780
+        """
+        if (
+            tool_calls is not None
+            and message.content is not None
+            and message.tool_calls is None
+        ):
+            try:
+                function = Function(**json.loads(message.content))
+                if function.name != RESPONSE_FORMAT_TOOL_NAME and function.name in [
+                    tool["function"]["name"] for tool in tool_calls
+                ]:
+                    tool_call = ChatCompletionMessageToolCall(
+                        function=function, id=str(uuid.uuid4()), type="function"
+                    )
+                    message.tool_calls = [tool_call]
+                    message.content = None
+            except Exception:
+                pass
+        return message
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = raw_response.json()
+        except Exception as e:
+            response_headers = getattr(raw_response, "headers", None)
+            raise FireworksAIException(
+                message="Unable to get json response - {}, Original Response: {}".format(
+                    str(e), raw_response.text
+                ),
+                status_code=raw_response.status_code,
+                headers=response_headers,
+            )
+
+        raw_response_headers = dict(raw_response.headers)
+        additional_headers = get_response_headers(raw_response_headers)
+        response = ModelResponse(**completion_response)
+        if response.model is not None:
+            response.model = "fireworks_ai/" + response.model
+
+        ## FIREWORKS AI sends tool calls in the content field instead of tool_calls
+        for choice in response.choices:
+            cast(
+                Choices, choice
+            ).message = self._handle_message_content_with_tool_calls(
+                message=cast(Choices, choice).message,
+                tool_calls=optional_params.get("tools", None),
+            )
+
+        response._hidden_params = {"additional_headers": additional_headers}
+        return response
+
     def _get_openai_compatible_provider_info(
         self, api_base: Optional[str], api_key: Optional[str]
     ) -> Tuple[Optional[str], Optional[str]]:
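
To make the new behavior concrete, here is a standalone sketch of the same normalization on plain dicts (shapes mirror the test added later in this commit, not litellm internals): Fireworks can return the tool invocation as a JSON string in content, so it is parsed, matched against the requested tool names, and promoted into tool_calls.

import json
import uuid

def promote_content_tool_call(message: dict, requested_tools: list) -> dict:
    # Sketch of _handle_message_content_with_tool_calls using plain dicts.
    if message.get("tool_calls") is not None or message.get("content") is None:
        return message
    try:
        parsed = json.loads(message["content"])
    except json.JSONDecodeError:
        return message  # ordinary text content - leave untouched
    requested_names = {t["function"]["name"] for t in requested_tools}
    if isinstance(parsed, dict) and parsed.get("name") in requested_names:
        message["tool_calls"] = [
            {
                "id": str(uuid.uuid4()),
                "type": "function",
                "function": {
                    "name": parsed["name"],
                    # OpenAI-style arguments are a JSON-encoded string
                    "arguments": json.dumps(parsed.get("parameters", {})),
                },
            }
        ]
        message["content"] = None
    return message

msg = {
    "role": "assistant",
    "content": '{"type": "function", "name": "get_current_weather", "parameters": {"location": "Boston, MA"}}',
}
tools = [{"type": "function", "function": {"name": "get_current_weather"}}]
print(promote_content_tool_call(msg, tools)["tool_calls"][0]["function"])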


@@ -1435,6 +1435,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 custom_llm_provider=custom_llm_provider,
                 encoding=encoding,
                 stream=stream,
+                provider_config=provider_config,
             )
         except Exception as e:
             ## LOGGING - log the original exception returned

@@ -1596,6 +1597,37 @@ def completion(  # type: ignore # noqa: PLR0915
                 additional_args={"headers": headers},
             )
             response = _response
+    elif custom_llm_provider == "fireworks_ai":
+        ## COMPLETION CALL
+        try:
+            response = base_llm_http_handler.completion(
+                model=model,
+                messages=messages,
+                headers=headers,
+                model_response=model_response,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                timeout=timeout,  # type: ignore
+                client=client,
+                custom_llm_provider=custom_llm_provider,
+                encoding=encoding,
+                stream=stream,
+                provider_config=provider_config,
+            )
+        except Exception as e:
+            ## LOGGING - log the original exception returned
+            logging.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+                additional_args={"headers": headers},
+            )
+            raise e
     elif custom_llm_provider == "groq":
         api_base = (
             api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
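
With the new branch in place, Fireworks AI requests route through base_llm_http_handler and pick up the tool-call normalization above. A usage sketch (the model name is a placeholder and FIREWORKS_AI_API_KEY is assumed to be set):

import litellm

response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
)
# After this fix, the invocation surfaces here instead of in message.content.
print(response.choices[0].message.tool_calls)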


@@ -378,12 +378,18 @@ class Function(OpenAIObject):
     def __init__(
         self,
-        arguments: Optional[Union[Dict, str]],
+        arguments: Optional[Union[Dict, str]] = None,
         name: Optional[str] = None,
         **params,
     ):
         if arguments is None:
-            arguments = ""
+            if params.get("parameters", None) is not None and isinstance(
+                params["parameters"], dict
+            ):
+                arguments = json.dumps(params["parameters"])
+                params.pop("parameters")
+            else:
+                arguments = ""
         elif isinstance(arguments, Dict):
             arguments = json.dumps(arguments)
         else:

@@ -392,7 +398,7 @@
             name = name

         # Build a dictionary with the structure your BaseModel expects
-        data = {"arguments": arguments, "name": name, **params}
+        data = {"arguments": arguments, "name": name}

         super(Function, self).__init__(**data)
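
A quick illustration of the constructor change: a Fireworks-style payload that carries parameters (a dict) instead of arguments is now accepted and normalized to the OpenAI-style JSON string, and stray keyword arguments no longer leak into the model fields.

from litellm.types.utils import Function

f = Function(name="get_current_weather", parameters={"location": "Boston, MA"})
print(f.arguments)  # '{"location": "Boston, MA"}'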


@@ -6264,24 +6264,27 @@ def validate_and_fix_openai_messages(messages: List):

     Handles missing role for assistant messages.
     """
+    new_messages = []
     for message in messages:
         if not message.get("role"):
             message["role"] = "assistant"
         if message.get("tool_calls"):
             message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
-    return validate_chat_completion_messages(messages=messages)
+        convert_msg_to_dict = cast(AllMessageValues, convert_to_dict(message))
+        cleaned_message = cleanup_none_field_in_message(message=convert_msg_to_dict)
+        new_messages.append(cleaned_message)
+    return validate_chat_completion_user_messages(messages=new_messages)


-def validate_chat_completion_messages(messages: List[AllMessageValues]):
+def cleanup_none_field_in_message(message: AllMessageValues):
     """
-    Ensures all messages are valid OpenAI chat completion messages.
+    Cleans up the message by removing the none field.
+
+    remove None fields in the message - e.g. {"function": None} - some providers raise validation errors
     """
-    # 1. convert all messages to dict
-    messages = [
-        cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages
-    ]
-    # 2. validate user messages
-    return validate_chat_completion_user_messages(messages=messages)
+    new_message = message.copy()
+    return {k: v for k, v in new_message.items() if v is not None}


 def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
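
The effect of the new cleanup step, shown on a plain dict: explicit None fields, which some providers reject with validation errors, are stripped before the messages are validated.

message = {
    "role": "assistant",
    "content": "hello",
    "function_call": None,  # e.g. {"function_call": None} trips some providers
    "tool_calls": None,
}
cleaned = {k: v for k, v in message.items() if v is not None}
assert cleaned == {"role": "assistant", "content": "hello"}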


@@ -0,0 +1,59 @@
+import json
+import os
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig
+from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
+from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
+
+
+def test_handle_message_content_with_tool_calls():
+    config = FireworksAIConfig()
+    message = Message(
+        content='{"type": "function", "name": "get_current_weather", "parameters": {"location": "Boston, MA", "unit": "fahrenheit"}}',
+        role="assistant",
+        tool_calls=None,
+        function_call=None,
+        provider_specific_fields=None,
+    )
+    expected_tool_call = ChatCompletionMessageToolCall(
+        function=Function(**json.loads(message.content)), id=None, type=None
+    )
+    tool_calls = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    updated_message = config._handle_message_content_with_tool_calls(
+        message, tool_calls
+    )
+    assert updated_message.tool_calls is not None
+    assert len(updated_message.tool_calls) == 1
+    assert updated_message.tool_calls[0].function.name == "get_current_weather"
+    assert (
+        updated_message.tool_calls[0].function.arguments
+        == expected_tool_call.function.arguments
+    )


@@ -896,6 +896,13 @@ class BaseLLMChatTest(ABC):
         assert response is not None
         # if the provider did not return any tool calls do not make a subsequent llm api call
+        if response.choices[0].message.content is not None:
+            try:
+                json.loads(response.choices[0].message.content)
+                pytest.fail(f"Tool call returned in content instead of tool_calls")
+            except Exception as e:
+                print(f"Error: {e}")
+                pass
         if response.choices[0].message.tool_calls is None:
             return
         # Add any assertions here to check the response


@@ -1,6 +1,6 @@
 import os
 import sys
+import json

 import pytest

 sys.path.insert(

@@ -93,57 +93,6 @@ class TestFireworksAIChatCompletion(BaseLLMChatTest):
         """
         pass

-    @pytest.mark.parametrize(
-        "response_format",
-        [
-            {"type": "json_object"},
-            {"type": "text"},
-        ],
-    )
-    @pytest.mark.flaky(retries=6, delay=1)
-    def test_json_response_format(self, response_format):
-        """
-        Test that the JSON response format is supported by the LLM API
-        """
-        from litellm.utils import supports_response_schema
-        from openai import OpenAI
-        from unittest.mock import patch
-
-        client = OpenAI()
-        base_completion_call_args = self.get_base_completion_call_args()
-        litellm.set_verbose = True
-
-        messages = [
-            {
-                "role": "system",
-                "content": "Your output should be a JSON object with no additional properties. ",
-            },
-            {
-                "role": "user",
-                "content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
-            },
-        ]
-
-        with patch.object(
-            client.chat.completions.with_raw_response, "create"
-        ) as mock_post:
-            response = self.completion_function(
-                **base_completion_call_args,
-                messages=messages,
-                response_format=response_format,
-                client=client,
-            )
-
-            mock_post.assert_called_once()
-
-            if response_format["type"] == "json_object":
-                assert (
-                    mock_post.call_args.kwargs["response_format"]["type"]
-                    == "json_object"
-                )
-            else:
-                assert mock_post.call_args.kwargs["response_format"]["type"] == "text"


 class TestFireworksAIAudioTranscription(BaseLLMAudioTranscriptionTest):
     def get_base_audio_transcription_call_args(self) -> dict:

@@ -253,14 +202,15 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch):
-    from openai import OpenAI
     from unittest.mock import patch
     from litellm import completion
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()

     monkeypatch.setattr(litellm, "disable_add_transform_inline_image_block", True)
-    client = OpenAI()
     with patch.object(
-        client.chat.completions.with_raw_response,
-        "create",
+        client,
+        "post",
     ) as mock_post:
         try:
             completion(

@@ -286,9 +236,10 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch):
         mock_post.assert_called_once()
         print(mock_post.call_args.kwargs)
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
         assert (
             "#transform=inline"
-            not in mock_post.call_args.kwargs["messages"][0]["content"][1]["image_url"][
+            not in json_data["messages"][0]["content"][1]["image_url"][
                 "url"
             ]
         )


@@ -4163,7 +4163,7 @@ def test_completion_vllm(provider):
 def test_completion_fireworks_ai_multiple_choices():
-    litellm.set_verbose = True
+    litellm._turn_on_debug()
     response = litellm.text_completion(
         model="fireworks_ai/llama-v3p1-8b-instruct",
         prompt=["halo", "hi", "halo", "hi"],