diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index 6dd47cc223..c7b18215d0 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -229,13 +229,17 @@ class BaseLLMHTTPHandler:
         api_key: Optional[str] = None,
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        provider_config: Optional[BaseConfig] = None,
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
         extra_body: Optional[dict] = optional_params.pop("extra_body", None)
         fake_stream = fake_stream or optional_params.pop("fake_stream", False)
 
-        provider_config = ProviderConfigManager.get_provider_chat_config(
-            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        provider_config = (
+            provider_config
+            or ProviderConfigManager.get_provider_chat_config(
+                model=model, provider=litellm.LlmProviders(custom_llm_provider)
+            )
         )
         if provider_config is None:
             raise ValueError(
diff --git a/litellm/llms/fireworks_ai/chat/transformation.py b/litellm/llms/fireworks_ai/chat/transformation.py
index dc78c5bc5d..2a795bdf2f 100644
--- a/litellm/llms/fireworks_ai/chat/transformation.py
+++ b/litellm/llms/fireworks_ai/chat/transformation.py
@@ -1,15 +1,33 @@
-from typing import List, Literal, Optional, Tuple, Union, cast
+import json
+import uuid
+from typing import Any, List, Literal, Optional, Tuple, Union, cast
+
+import httpx
 
 import litellm
+from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.llm_response_utils.get_headers import (
+    get_response_headers,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionImageObject,
+    ChatCompletionToolParam,
     OpenAIChatCompletionToolParam,
 )
-from litellm.types.utils import ProviderSpecificModelInfo
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    Choices,
+    Function,
+    Message,
+    ModelResponse,
+    ProviderSpecificModelInfo,
+)
 
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import FireworksAIException
 
 
 class FireworksAIConfig(OpenAIGPTConfig):
@@ -219,6 +237,94 @@ class FireworksAIConfig(OpenAIGPTConfig):
             headers=headers,
         )
 
+    def _handle_message_content_with_tool_calls(
+        self,
+        message: Message,
+        tool_calls: Optional[List[ChatCompletionToolParam]],
+    ) -> Message:
+        """
+        Fireworks AI sends tool calls in the content field instead of tool_calls
+
+        Relevant Issue: https://github.com/BerriAI/litellm/issues/7209#issuecomment-2813208780
+        """
+        if (
+            tool_calls is not None
+            and message.content is not None
+            and message.tool_calls is None
+        ):
+            try:
+                function = Function(**json.loads(message.content))
+                if function.name != RESPONSE_FORMAT_TOOL_NAME and function.name in [
+                    tool["function"]["name"] for tool in tool_calls
+                ]:
+                    tool_call = ChatCompletionMessageToolCall(
+                        function=function, id=str(uuid.uuid4()), type="function"
+                    )
+                    message.tool_calls = [tool_call]
+
+                message.content = None
+            except Exception:
+                pass
+
+        return message
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = raw_response.json()
+        except Exception as e:
+            response_headers = getattr(raw_response, "headers", None)
+            raise FireworksAIException(
+                message="Unable to get json response - {}, Original Response: {}".format(
+                    str(e), raw_response.text
+                ),
+                status_code=raw_response.status_code,
+                headers=response_headers,
+            )
+
+        raw_response_headers = dict(raw_response.headers)
+
+        additional_headers = get_response_headers(raw_response_headers)
+
+        response = ModelResponse(**completion_response)
+
+        if response.model is not None:
+            response.model = "fireworks_ai/" + response.model
+
+        ## FIREWORKS AI sends tool calls in the content field instead of tool_calls
+        for choice in response.choices:
+            cast(
+                Choices, choice
+            ).message = self._handle_message_content_with_tool_calls(
+                message=cast(Choices, choice).message,
+                tool_calls=optional_params.get("tools", None),
+            )
+
+        response._hidden_params = {"additional_headers": additional_headers}
+
+        return response
+
     def _get_openai_compatible_provider_info(
         self, api_base: Optional[str], api_key: Optional[str]
     ) -> Tuple[Optional[str], Optional[str]]:
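Illustrative sketch, not part of the patch: the new _handle_message_content_with_tool_calls hook above rewrites a Fireworks AI assistant message that carries a JSON-encoded function call in content into a proper tool_calls entry. It mirrors the unit test added later in this diff; the weather tool and values are example data, not fixtures shipped with litellm.

    from litellm.llms.fireworks_ai.chat.transformation import FireworksAIConfig
    from litellm.types.utils import Message

    # Fireworks AI can return the tool call as a JSON string in `content`:
    message = Message(
        content='{"name": "get_current_weather", "parameters": {"location": "Boston, MA"}}',
        role="assistant",
    )
    tools = [{"type": "function", "function": {"name": "get_current_weather"}}]

    fixed = FireworksAIConfig()._handle_message_content_with_tool_calls(message, tools)
    assert fixed.content is None  # content was moved out...
    assert fixed.tool_calls[0].function.name == "get_current_weather"  # ...into tool_calls
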
diff --git a/litellm/main.py b/litellm/main.py
index 1dffb87f43..9bb1cf0c15 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1435,6 +1435,7 @@ def completion(  # type: ignore # noqa: PLR0915
                     custom_llm_provider=custom_llm_provider,
                     encoding=encoding,
                     stream=stream,
+                    provider_config=provider_config,
                 )
             except Exception as e:
                 ## LOGGING - log the original exception returned
@@ -1596,6 +1597,37 @@ def completion(  # type: ignore # noqa: PLR0915
                     additional_args={"headers": headers},
                 )
             response = _response
+        elif custom_llm_provider == "fireworks_ai":
+            ## COMPLETION CALL
+            try:
+                response = base_llm_http_handler.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    timeout=timeout,  # type: ignore
+                    client=client,
+                    custom_llm_provider=custom_llm_provider,
+                    encoding=encoding,
+                    stream=stream,
+                    provider_config=provider_config,
+                )
+            except Exception as e:
+                ## LOGGING - log the original exception returned
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=str(e),
+                    additional_args={"headers": headers},
+                )
+                raise e
+
         elif custom_llm_provider == "groq":
             api_base = (
                 api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index ac626d4657..88f9638438 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -378,12 +378,18 @@ class Function(OpenAIObject):
 
     def __init__(
         self,
-        arguments: Optional[Union[Dict, str]],
+        arguments: Optional[Union[Dict, str]] = None,
        name: Optional[str] = None,
         **params,
     ):
         if arguments is None:
-            arguments = ""
+            if params.get("parameters", None) is not None and isinstance(
+                params["parameters"], dict
+            ):
+                arguments = json.dumps(params["parameters"])
+                params.pop("parameters")
+            else:
+                arguments = ""
         elif isinstance(arguments, Dict):
             arguments = json.dumps(arguments)
         else:
@@ -392,7 +398,7 @@ class Function(OpenAIObject):
             name = name
 
         # Build a dictionary with the structure your BaseModel expects
-        data = {"arguments": arguments, "name": name, **params}
+        data = {"arguments": arguments, "name": name}
 
         super(Function, self).__init__(**data)
 
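Illustrative sketch, not part of the patch: with the Function change above, a Fireworks-style parameters dict is folded into the OpenAI-style JSON arguments string whenever arguments is omitted, while an explicit arguments value keeps its old behavior. Values are example data.

    from litellm.types.utils import Function

    # Fireworks-style payload: `parameters` instead of `arguments`.
    fn = Function(name="get_current_weather", parameters={"location": "Boston, MA"})
    assert fn.arguments == '{"location": "Boston, MA"}'

    # Passing `arguments` directly is unchanged.
    fn = Function(name="get_current_weather", arguments={"location": "Boston, MA"})
    assert fn.arguments == '{"location": "Boston, MA"}'
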
diff --git a/litellm/utils.py b/litellm/utils.py
index 6b52fe91fa..3efd188717 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6264,24 +6264,27 @@ def validate_and_fix_openai_messages(messages: List):
 
     Handles missing role for assistant messages.
     """
+    new_messages = []
     for message in messages:
         if not message.get("role"):
             message["role"] = "assistant"
         if message.get("tool_calls"):
             message["tool_calls"] = jsonify_tools(tools=message["tool_calls"])
-    return validate_chat_completion_messages(messages=messages)
+
+        convert_msg_to_dict = cast(AllMessageValues, convert_to_dict(message))
+        cleaned_message = cleanup_none_field_in_message(message=convert_msg_to_dict)
+        new_messages.append(cleaned_message)
+    return validate_chat_completion_user_messages(messages=new_messages)
 
 
-def validate_chat_completion_messages(messages: List[AllMessageValues]):
+def cleanup_none_field_in_message(message: AllMessageValues):
     """
-    Ensures all messages are valid OpenAI chat completion messages.
+    Cleans up the message by removing None-valued fields.
+
+    Removes None fields in the message - e.g. {"function": None} - since some providers raise validation errors on them.
     """
-    # 1. convert all messages to dict
-    messages = [
-        cast(AllMessageValues, convert_to_dict(cast(dict, m))) for m in messages
-    ]
-    # 2. validate user messages
-    return validate_chat_completion_user_messages(messages=messages)
+    new_message = message.copy()
+    return {k: v for k, v in new_message.items() if v is not None}
 
 
 def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
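Illustrative sketch, not part of the patch: the new cleanup_none_field_in_message helper strips None-valued keys before validation, so providers that reject explicit nulls no longer error on payloads like {"function_call": None}. The message values are example data.

    from litellm.utils import cleanup_none_field_in_message

    msg = {
        "role": "assistant",
        "content": "hi there",
        "tool_calls": None,
        "function_call": None,
    }
    # None-valued keys are dropped; everything else passes through untouched.
    assert cleanup_none_field_in_message(message=msg) == {
        "role": "assistant",
        "content": "hi there",
    }
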
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + updated_message = config._handle_message_content_with_tool_calls( + message, tool_calls + ) + assert updated_message.tool_calls is not None + assert len(updated_message.tool_calls) == 1 + assert updated_message.tool_calls[0].function.name == "get_current_weather" + assert ( + updated_message.tool_calls[0].function.arguments + == expected_tool_call.function.arguments + ) diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index 725ebfba59..bbdb8e776f 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -896,6 +896,13 @@ class BaseLLMChatTest(ABC): assert response is not None # if the provider did not return any tool calls do not make a subsequent llm api call + if response.choices[0].message.content is not None: + try: + json.loads(response.choices[0].message.content) + pytest.fail(f"Tool call returned in content instead of tool_calls") + except Exception as e: + print(f"Error: {e}") + pass if response.choices[0].message.tool_calls is None: return # Add any assertions here to check the response diff --git a/tests/llm_translation/test_fireworks_ai_translation.py b/tests/llm_translation/test_fireworks_ai_translation.py index 9e78270c92..953405c46e 100644 --- a/tests/llm_translation/test_fireworks_ai_translation.py +++ b/tests/llm_translation/test_fireworks_ai_translation.py @@ -1,6 +1,6 @@ import os import sys - +import json import pytest sys.path.insert( @@ -93,57 +93,6 @@ class TestFireworksAIChatCompletion(BaseLLMChatTest): """ pass - @pytest.mark.parametrize( - "response_format", - [ - {"type": "json_object"}, - {"type": "text"}, - ], - ) - @pytest.mark.flaky(retries=6, delay=1) - def test_json_response_format(self, response_format): - """ - Test that the JSON response format is supported by the LLM API - """ - from litellm.utils import supports_response_schema - from openai import OpenAI - from unittest.mock import patch - - client = OpenAI() - - base_completion_call_args = self.get_base_completion_call_args() - litellm.set_verbose = True - - messages = [ - { - "role": "system", - "content": "Your output should be a JSON object with no additional properties. ", - }, - { - "role": "user", - "content": "Respond with this in json. 
city=San Francisco, state=CA, weather=sunny, temp=60", - }, - ] - - with patch.object( - client.chat.completions.with_raw_response, "create" - ) as mock_post: - response = self.completion_function( - **base_completion_call_args, - messages=messages, - response_format=response_format, - client=client, - ) - - mock_post.assert_called_once() - if response_format["type"] == "json_object": - assert ( - mock_post.call_args.kwargs["response_format"]["type"] - == "json_object" - ) - else: - assert mock_post.call_args.kwargs["response_format"]["type"] == "text" - class TestFireworksAIAudioTranscription(BaseLLMAudioTranscriptionTest): def get_base_audio_transcription_call_args(self) -> dict: @@ -253,14 +202,15 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch): from openai import OpenAI from unittest.mock import patch from litellm import completion + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + client = HTTPHandler() monkeypatch.setattr(litellm, "disable_add_transform_inline_image_block", True) - client = OpenAI() - with patch.object( - client.chat.completions.with_raw_response, - "create", + client, + "post", ) as mock_post: try: completion( @@ -286,9 +236,10 @@ def test_global_disable_flag_with_transform_messages_helper(monkeypatch): mock_post.assert_called_once() print(mock_post.call_args.kwargs) + json_data = json.loads(mock_post.call_args.kwargs["data"]) assert ( "#transform=inline" - not in mock_post.call_args.kwargs["messages"][0]["content"][1]["image_url"][ + not in json_data["messages"][0]["content"][1]["image_url"][ "url" ] ) diff --git a/tests/local_testing/test_text_completion.py b/tests/local_testing/test_text_completion.py index 675f941b0e..35929f6f35 100644 --- a/tests/local_testing/test_text_completion.py +++ b/tests/local_testing/test_text_completion.py @@ -4163,7 +4163,7 @@ def test_completion_vllm(provider): def test_completion_fireworks_ai_multiple_choices(): - litellm.set_verbose = True + litellm._turn_on_debug() response = litellm.text_completion( model="fireworks_ai/llama-v3p1-8b-instruct", prompt=["halo", "hi", "halo", "hi"],