Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-25 10:44:24 +00:00
Handle fireworks ai tool calling response (#10130)
* feat(fireworks_ai/chat): handle tool calling with fireworks ai correctly

  Fixes https://github.com/BerriAI/litellm/issues/7209

* fix(utils.py): handle none type in message
* fix: fix model name in test
* fix(utils.py): fix validate check for openai messages
* fix: fix model returned
* fix(main.py): fix text completion routing
* test: update testing
* test: skip test - cohere having RBAC issues
parent 4663a66b47
commit e122f2df56
9 changed files with 242 additions and 74 deletions
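Not part of the diff: a minimal usage sketch of the request path this commit fixes. It assumes the public litellm.completion API; the Fireworks model name and the tool schema are illustrative placeholders, not taken from the change.

# Hypothetical example: a fireworks_ai chat completion with tools.
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct",  # placeholder model
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=tools,
)

# With this change, a tool call that Fireworks AI returned as a JSON string in
# message.content is surfaced as a structured entry in message.tool_calls.
print(response.choices[0].message.tool_calls)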
@@ -229,14 +229,18 @@ class BaseLLMHTTPHandler:
         api_key: Optional[str] = None,
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        provider_config: Optional[BaseConfig] = None,
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
         extra_body: Optional[dict] = optional_params.pop("extra_body", None)
         fake_stream = fake_stream or optional_params.pop("fake_stream", False)

-        provider_config = ProviderConfigManager.get_provider_chat_config(
-            model=model, provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        provider_config = (
+            provider_config
+            or ProviderConfigManager.get_provider_chat_config(
+                model=model, provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        )
         if provider_config is None:
             raise ValueError(
                 f"Provider config not found for model: {model} and provider: {custom_llm_provider}"

@@ -1,15 +1,33 @@
-from typing import List, Literal, Optional, Tuple, Union, cast
+import json
+import uuid
+from typing import Any, List, Literal, Optional, Tuple, Union, cast
+
+import httpx

 import litellm
+from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.llm_response_utils.get_headers import (
+    get_response_headers,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionImageObject,
+    ChatCompletionToolParam,
     OpenAIChatCompletionToolParam,
 )
-from litellm.types.utils import ProviderSpecificModelInfo
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    Choices,
+    Function,
+    Message,
+    ModelResponse,
+    ProviderSpecificModelInfo,
+)

 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import FireworksAIException


 class FireworksAIConfig(OpenAIGPTConfig):

@@ -219,6 +237,94 @@ class FireworksAIConfig(OpenAIGPTConfig):
             headers=headers,
         )

+    def _handle_message_content_with_tool_calls(
+        self,
+        message: Message,
+        tool_calls: Optional[List[ChatCompletionToolParam]],
+    ) -> Message:
+        """
+        Fireworks AI sends tool calls in the content field instead of tool_calls
+
+        Relevant Issue: https://github.com/BerriAI/litellm/issues/7209#issuecomment-2813208780
+        """
+        if (
+            tool_calls is not None
+            and message.content is not None
+            and message.tool_calls is None
+        ):
+            try:
+                function = Function(**json.loads(message.content))
+                if function.name != RESPONSE_FORMAT_TOOL_NAME and function.name in [
+                    tool["function"]["name"] for tool in tool_calls
+                ]:
+                    tool_call = ChatCompletionMessageToolCall(
+                        function=function, id=str(uuid.uuid4()), type="function"
+                    )
+                    message.tool_calls = [tool_call]
+
+                message.content = None
+            except Exception:
+                pass
+
+        return message
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = raw_response.json()
+        except Exception as e:
+            response_headers = getattr(raw_response, "headers", None)
+            raise FireworksAIException(
+                message="Unable to get json response - {}, Original Response: {}".format(
+                    str(e), raw_response.text
+                ),
+                status_code=raw_response.status_code,
+                headers=response_headers,
+            )
+
+        raw_response_headers = dict(raw_response.headers)
+
+        additional_headers = get_response_headers(raw_response_headers)
+
+        response = ModelResponse(**completion_response)
+
+        if response.model is not None:
+            response.model = "fireworks_ai/" + response.model
+
+        ## FIREWORKS AI sends tool calls in the content field instead of tool_calls
+        for choice in response.choices:
+            cast(
+                Choices, choice
+            ).message = self._handle_message_content_with_tool_calls(
+                message=cast(Choices, choice).message,
+                tool_calls=optional_params.get("tools", None),
+            )
+
+        response._hidden_params = {"additional_headers": additional_headers}
+
+        return response
+
     def _get_openai_compatible_provider_info(
         self, api_base: Optional[str], api_key: Optional[str]
     ) -> Tuple[Optional[str], Optional[str]]:
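Not part of the diff: a sketch of the conversion the new helper performs, using illustrative values shaped like the ones in the new test further down.

# Before the fix, Fireworks AI could return the tool call as a JSON string in message.content:
from litellm.types.utils import Message

raw_message = Message(
    role="assistant",
    content='{"type": "function", "name": "get_current_weather", "parameters": {"location": "Boston, MA"}}',
    tool_calls=None,
)

# After _handle_message_content_with_tool_calls (when "get_current_weather" is one of
# the request's tools), the message is expected to carry:
#   message.content    -> None
#   message.tool_calls -> [ChatCompletionMessageToolCall(
#                              function=Function(name="get_current_weather",
#                                                arguments='{"location": "Boston, MA"}'),
#                              type="function", id="<uuid4>")]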

@@ -1435,6 +1435,7 @@ def completion( # type: ignore # noqa: PLR0915
                 custom_llm_provider=custom_llm_provider,
                 encoding=encoding,
                 stream=stream,
+                provider_config=provider_config,
             )
         except Exception as e:
             ## LOGGING - log the original exception returned

@@ -1596,6 +1597,37 @@ def completion( # type: ignore # noqa: PLR0915
                 additional_args={"headers": headers},
             )
         response = _response
+    elif custom_llm_provider == "fireworks_ai":
+        ## COMPLETION CALL
+        try:
+            response = base_llm_http_handler.completion(
+                model=model,
+                messages=messages,
+                headers=headers,
+                model_response=model_response,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                timeout=timeout,  # type: ignore
+                client=client,
+                custom_llm_provider=custom_llm_provider,
+                encoding=encoding,
+                stream=stream,
+                provider_config=provider_config,
+            )
+        except Exception as e:
+            ## LOGGING - log the original exception returned
+            logging.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=str(e),
+                additional_args={"headers": headers},
+            )
+            raise e
+
     elif custom_llm_provider == "groq":
         api_base = (
             api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there

@@ -378,11 +378,17 @@ class Function(OpenAIObject):

     def __init__(
         self,
-        arguments: Optional[Union[Dict, str]],
+        arguments: Optional[Union[Dict, str]] = None,
         name: Optional[str] = None,
         **params,
     ):
         if arguments is None:
-            arguments = ""
+            if params.get("parameters", None) is not None and isinstance(
+                params["parameters"], dict
+            ):
+                arguments = json.dumps(params["parameters"])
+                params.pop("parameters")
+            else:
+                arguments = ""
         elif isinstance(arguments, Dict):
             arguments = json.dumps(arguments)

@@ -392,7 +398,7 @@ class Function(OpenAIObject):
             name = name

         # Build a dictionary with the structure your BaseModel expects
-        data = {"arguments": arguments, "name": name, **params}
+        data = {"arguments": arguments, "name": name}

         super(Function, self).__init__(**data)
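Not part of the diff: a small sketch of the Function change above. With arguments now optional, a Fireworks-style payload that carries the call arguments under "parameters" is serialized into the OpenAI-style "arguments" field.

import json

from litellm.types.utils import Function

# A "parameters" dict is converted to a JSON string and stored as "arguments".
f = Function(name="get_current_weather", parameters={"location": "Boston, MA"})
assert f.arguments == json.dumps({"location": "Boston, MA"})

# Plain OpenAI-style construction keeps working as before.
g = Function(name="get_current_weather", arguments='{"location": "Boston, MA"}')
assert g.arguments == '{"location": "Boston, MA"}'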

@@ -6264,24 +6264,27 @@ def validate_and_fix_openai_messages(messages: List):

     Handles missing role for assistant messages.
     """
+    new_messages = []
     for message in messages:
         if not message.get("role"):
             message["role"] = "assistant"
         if message.get("tool_calls"):
             message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
-    return validate_chat_completion_messages(messages=messages)
+
+        convert_msg_to_dict = cast(AllMessageValues, convert_to_dict(message))
+        cleaned_message = cleanup_none_field_in_message(message=convert_msg_to_dict)
+        new_messages.append(cleaned_message)
+    return validate_chat_completion_user_messages(messages=new_messages)


-def validate_chat_completion_messages(messages: List[AllMessageValues]):
+def cleanup_none_field_in_message(message: AllMessageValues):
     """
-    Ensures all messages are valid OpenAI chat completion messages.
+    Cleans up the message by removing the none field.
+
+    remove None fields in the message - e.g. {"function": None} - some providers raise validation errors
     """
-    # 1. convert all messages to dict
-    messages = [
-        cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages
-    ]
-    # 2. validate user messages
-    return validate_chat_completion_user_messages(messages=messages)
+    new_message = message.copy()
+    return {k: v for k, v in new_message.items() if v is not None}


 def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
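Not part of the diff: a sketch of the new cleanup helper, assuming it is importable from litellm.utils alongside validate_and_fix_openai_messages.

from litellm.utils import cleanup_none_field_in_message

message = {
    "role": "assistant",
    "content": "hello",
    "function_call": None,  # some providers raise validation errors on explicit None fields
    "tool_calls": None,
}

# None-valued keys are dropped; everything else is preserved.
assert cleanup_none_field_in_message(message=message) == {
    "role": "assistant",
    "content": "hello",
}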

@@ -0,0 +1,59 @@
+import json
+import os
+import sys
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../../../..")
+)  # Adds the parent directory to the system path
+
+from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig
+from litellm.types.llms.openai import ChatCompletionToolCallFunctionChunk
+from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
+
+
+def test_handle_message_content_with_tool_calls():
+    config = FireworksAIConfig()
+    message = Message(
+        content='{"type": "function", "name": "get_current_weather", "parameters": {"location": "Boston, MA", "unit": "fahrenheit"}}',
+        role="assistant",
+        tool_calls=None,
+        function_call=None,
+        provider_specific_fields=None,
+    )
+    expected_tool_call = ChatCompletionMessageToolCall(
+        function=Function(**json.loads(message.content)), id=None, type=None
+    )
+    tool_calls = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    updated_message = config._handle_message_content_with_tool_calls(
+        message, tool_calls
+    )
+    assert updated_message.tool_calls is not None
+    assert len(updated_message.tool_calls) == 1
+    assert updated_message.tool_calls[0].function.name == "get_current_weather"
+    assert (
+        updated_message.tool_calls[0].function.arguments
+        == expected_tool_call.function.arguments
+    )

@@ -896,6 +896,13 @@ class BaseLLMChatTest(ABC):
         assert response is not None

         # if the provider did not return any tool calls do not make a subsequent llm api call
+        if response.choices[0].message.content is not None:
+            try:
+                json.loads(response.choices[0].message.content)
+                pytest.fail(f"Tool call returned in content instead of tool_calls")
+            except Exception as e:
+                print(f"Error: {e}")
+                pass
         if response.choices[0].message.tool_calls is None:
             return
         # Add any assertions here to check the response

@@ -1,6 +1,6 @@
 import os
 import sys
+import json
 import pytest

 sys.path.insert(

@@ -93,57 +93,6 @@ class TestFireworksAIChatCompletion(BaseLLMChatTest):
         """
         pass

-    @pytest.mark.parametrize(
-        "response_format",
-        [
-            {"type": "json_object"},
-            {"type": "text"},
-        ],
-    )
-    @pytest.mark.flaky(retries=6, delay=1)
-    def test_json_response_format(self, response_format):
-        """
-        Test that the JSON response format is supported by the LLM API
-        """
-        from litellm.utils import supports_response_schema
-        from openai import OpenAI
-        from unittest.mock import patch
-
-        client = OpenAI()
-
-        base_completion_call_args = self.get_base_completion_call_args()
-        litellm.set_verbose = True
-
-        messages = [
-            {
-                "role": "system",
-                "content": "Your output should be a JSON object with no additional properties. ",
-            },
-            {
-                "role": "user",
-                "content": "Respond with this in json. city=San Francisco, state=CA, weather=sunny, temp=60",
-            },
-        ]
-
-        with patch.object(
-            client.chat.completions.with_raw_response, "create"
-        ) as mock_post:
-            response = self.completion_function(
-                **base_completion_call_args,
-                messages=messages,
-                response_format=response_format,
-                client=client,
-            )
-
-            mock_post.assert_called_once()
-            if response_format["type"] == "json_object":
-                assert (
-                    mock_post.call_args.kwargs["response_format"]["type"]
-                    == "json_object"
-                )
-            else:
-                assert mock_post.call_args.kwargs["response_format"]["type"] == "text"
-
-
 class TestFireworksAIAudioTranscription(BaseLLMAudioTranscriptionTest):
     def get_base_audio_transcription_call_args(self) -> dict:

@@ -253,14 +202,15 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch):
     from openai import OpenAI
     from unittest.mock import patch
     from litellm import completion
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()

     monkeypatch.setattr(litellm, "disable_add_transform_inline_image_block", True)

-    client = OpenAI()
-
     with patch.object(
-        client.chat.completions.with_raw_response,
-        "create",
+        client,
+        "post",
     ) as mock_post:
         try:
             completion(

@@ -286,9 +236,10 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch):

         mock_post.assert_called_once()
         print(mock_post.call_args.kwargs)
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
         assert (
             "#transform=inline"
-            not in mock_post.call_args.kwargs["messages"][0]["content"][1]["image_url"][
+            not in json_data["messages"][0]["content"][1]["image_url"][
                 "url"
             ]
         )

@@ -4163,7 +4163,7 @@ def test_completion_vllm(provider):


 def test_completion_fireworks_ai_multiple_choices():
-    litellm.set_verbose = True
+    litellm._turn_on_debug()
     response = litellm.text_completion(
         model="fireworks_ai/llama-v3p1-8b-instruct",
         prompt=["halo", "hi", "halo", "hi"],