diff --git a/litellm/llms/groq/chat/transformation.py b/litellm/llms/groq/chat/transformation.py
index c3c470bc9..dddc56a2c 100644
--- a/litellm/llms/groq/chat/transformation.py
+++ b/litellm/llms/groq/chat/transformation.py
@@ -2,6 +2,7 @@
 Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions`
 """
 
+import json
 import types
 from typing import List, Optional, Tuple, Union
 
@@ -9,7 +10,12 @@ from pydantic import BaseModel
 
 import litellm
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    ChatCompletionAssistantMessage,
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
 
 from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
 
@@ -108,3 +114,60 @@ class GroqChatConfig(OpenAIGPTConfig):
             return True
 
         return False
+
+    def _create_json_tool_call_for_response_format(
+        self,
+        json_schema: dict,
+    ):
+        """
+        Handles creating a tool call for getting responses in JSON format.
+
+        Args:
+            json_schema (dict): The JSON schema the response should be in
+
+        Returns:
+            ChatCompletionToolParam: The tool call to send to the Groq API to get responses in JSON format
+        """
+        return ChatCompletionToolParam(
+            type="function",
+            function=ChatCompletionToolParamFunctionChunk(
+                name="json_tool_call",
+                parameters=json_schema,
+            ),
+        )
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool = False,
+    ) -> dict:
+        _response_format = non_default_params.get("response_format")
+        if _response_format is not None and isinstance(_response_format, dict):
+            json_schema: Optional[dict] = None
+            if "response_schema" in _response_format:
+                json_schema = _response_format["response_schema"]
+            elif "json_schema" in _response_format:
+                json_schema = _response_format["json_schema"]["schema"]
+            """
+            When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+            - You usually want to provide a single tool
+            - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+            - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
+            """
+            if json_schema is not None:
+                _tool_choice = {
+                    "type": "function",
+                    "function": {"name": "json_tool_call"},
+                }
+                _tool = self._create_json_tool_call_for_response_format(
+                    json_schema=json_schema,
+                )
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
+                non_default_params.pop("response_format", None)
+        return super().map_openai_params(
+            non_default_params, optional_params, model, drop_params
+        )
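A minimal sketch of what the transformation above is expected to produce (illustrative only; the schema and model name are made-up examples, not part of this diff):

# Illustrative sketch -- not part of the patch.
import litellm

schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
optional_params = litellm.GroqChatConfig().map_openai_params(
    non_default_params={
        "response_format": {"type": "json_schema", "json_schema": {"schema": schema}}
    },
    optional_params={},
    model="llama-3.1-70b-versatile",
    drop_params=False,
)
# Expected shape, per map_openai_params above:
#   optional_params["tools"] == [
#       {"type": "function", "function": {"name": "json_tool_call", "parameters": schema}}
#   ]
#   optional_params["tool_choice"] == {"type": "function", "function": {"name": "json_tool_call"}}
#   optional_params["json_mode"] is True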
+ """ + if json_schema is not None: + _tool_choice = { + "type": "function", + "function": {"name": "json_tool_call"}, + } + _tool = self._create_json_tool_call_for_response_format( + json_schema=json_schema, + ) + optional_params["tools"] = [_tool] + optional_params["tool_choice"] = _tool_choice + optional_params["json_mode"] = True + non_default_params.pop("response_format", None) + return super().map_openai_params( + non_default_params, optional_params, model, drop_params + ) diff --git a/litellm/llms/openai_like/chat/handler.py b/litellm/llms/openai_like/chat/handler.py index 3246f82ae..6eb86561c 100644 --- a/litellm/llms/openai_like/chat/handler.py +++ b/litellm/llms/openai_like/chat/handler.py @@ -39,6 +39,7 @@ from litellm.utils import ( ) from ..common_utils import OpenAILikeBase, OpenAILikeError +from .transformation import OpenAILikeChatConfig async def make_call( @@ -190,6 +191,7 @@ class OpenAILikeChatHandler(OpenAILikeBase): logger_fn=None, headers={}, timeout: Optional[Union[float, httpx.Timeout]] = None, + json_mode: bool = False, ) -> ModelResponse: if timeout is None: timeout = httpx.Timeout(timeout=600.0, connect=5.0) @@ -202,8 +204,6 @@ class OpenAILikeChatHandler(OpenAILikeBase): api_base, headers=headers, data=json.dumps(data), timeout=timeout ) response.raise_for_status() - - response_json = response.json() except httpx.HTTPStatusError as e: raise OpenAILikeError( status_code=e.response.status_code, @@ -214,19 +214,22 @@ class OpenAILikeChatHandler(OpenAILikeBase): except Exception as e: raise OpenAILikeError(status_code=500, message=str(e)) - logging_obj.post_call( - input=messages, - api_key="", - original_response=response_json, - additional_args={"complete_input_dict": data}, + return OpenAILikeChatConfig._transform_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + json_mode=json_mode, + custom_llm_provider=custom_llm_provider, + base_model=base_model, ) - response = ModelResponse(**response_json) - - response.model = custom_llm_provider + "/" + (response.model or "") - - if base_model is not None: - response._hidden_params["model"] = base_model - return response def completion( self, @@ -268,6 +271,7 @@ class OpenAILikeChatHandler(OpenAILikeBase): stream: bool = optional_params.pop("stream", None) or False extra_body = optional_params.pop("extra_body", {}) + json_mode = optional_params.pop("json_mode", None) if not fake_stream: optional_params["stream"] = stream @@ -390,17 +394,19 @@ class OpenAILikeChatHandler(OpenAILikeBase): ) except Exception as e: raise OpenAILikeError(status_code=500, message=str(e)) - logging_obj.post_call( - input=messages, - api_key="", - original_response=response_json, - additional_args={"complete_input_dict": data}, + return OpenAILikeChatConfig._transform_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key=api_key, + data=data, + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + json_mode=json_mode, + custom_llm_provider=custom_llm_provider, + base_model=base_model, ) - response = ModelResponse(**response_json) - - response.model = custom_llm_provider + "/" + (response.model or "") - - if base_model is not None: - response._hidden_params["model"] = base_model - - return response diff 
diff --git a/litellm/llms/openai_like/chat/transformation.py b/litellm/llms/openai_like/chat/transformation.py
new file mode 100644
index 000000000..c355cf330
--- /dev/null
+++ b/litellm/llms/openai_like/chat/transformation.py
@@ -0,0 +1,98 @@
+"""
+OpenAI-like chat completion transformation
+"""
+
+import types
+from typing import List, Optional, Tuple, Union
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
+from litellm.types.utils import ModelResponse
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class OpenAILikeChatConfig(OpenAIGPTConfig):
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("OPENAI_LIKE_API_BASE")  # type: ignore
+        dynamic_api_key = (
+            api_key or get_secret_str("OPENAI_LIKE_API_KEY") or ""
+        )  # vllm does not require an api key
+        return api_base, dynamic_api_key
+
+    @staticmethod
+    def _convert_tool_response_to_message(
+        message: ChatCompletionAssistantMessage, json_mode: bool
+    ) -> ChatCompletionAssistantMessage:
+        """
+        If json_mode is true, convert the returned tool call response into assistant content holding the JSON string.
+
+        e.g. input:
+
+        {"role": "assistant", "tool_calls": [{"id": "call_5ms4", "type": "function", "function": {"name": "json_tool_call", "arguments": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}}]}
+
+        output:
+
+        {"role": "assistant", "content": "{\"key\": \"question\", \"value\": \"What is the capital of France?\"}"}
+        """
+        if not json_mode:
+            return message
+
+        _tool_calls = message.get("tool_calls")
+
+        if _tool_calls is None or len(_tool_calls) != 1:
+            return message
+
+        message["content"] = _tool_calls[0]["function"].get("arguments") or ""
+        message["tool_calls"] = None
+
+        return message
+
+    @staticmethod
+    def _transform_response(
+        model: str,
+        response: httpx.Response,
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,  # type: ignore
+        optional_params: dict,
+        api_key: Optional[str],
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+        json_mode: bool,
+        custom_llm_provider: str,
+        base_model: Optional[str],
+    ) -> ModelResponse:
+        response_json = response.json()
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=response_json,
+            additional_args={"complete_input_dict": data},
+        )
+
+        if json_mode:
+            for choice in response_json["choices"]:
+                message = OpenAILikeChatConfig._convert_tool_response_to_message(
+                    choice.get("message"), json_mode
+                )
+                choice["message"] = message
+
+        returned_response = ModelResponse(**response_json)
+
+        returned_response.model = (
+            custom_llm_provider + "/" + (returned_response.model or "")
+        )
+
+        if base_model is not None:
+            returned_response._hidden_params["model"] = base_model
+        return returned_response
diff --git a/litellm/main.py b/litellm/main.py
index 6cbf62d91..5d433eb36 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1495,7 +1495,6 @@ def completion(  # type: ignore # noqa: PLR0915
                 timeout=timeout,  # type: ignore
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
-                organization=organization,
                 custom_llm_provider=custom_llm_provider,
                 encoding=encoding,
             )
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 5e4f851e9..9758cf345 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -1745,7 +1745,8 @@
         "output_cost_per_token": 0.00000080,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/llama3-8b-8192": {
         "max_tokens": 8192,
@@ -1755,7 +1756,8 @@
         "output_cost_per_token": 0.00000008,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/llama3-70b-8192": {
         "max_tokens": 8192,
@@ -1765,7 +1767,8 @@
         "output_cost_per_token": 0.00000079,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/llama-3.1-8b-instant": {
         "max_tokens": 8192,
@@ -1775,7 +1778,8 @@
         "output_cost_per_token": 0.00000008,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/llama-3.1-70b-versatile": {
         "max_tokens": 8192,
@@ -1785,7 +1789,8 @@
         "output_cost_per_token": 0.00000079,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/llama-3.1-405b-reasoning": {
         "max_tokens": 8192,
@@ -1795,7 +1800,8 @@
         "output_cost_per_token": 0.00000079,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/mixtral-8x7b-32768": {
         "max_tokens": 32768,
@@ -1805,7 +1811,8 @@
         "output_cost_per_token": 0.00000024,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/gemma-7b-it": {
         "max_tokens": 8192,
@@ -1815,7 +1822,8 @@
         "output_cost_per_token": 0.00000007,
         "litellm_provider": "groq",
        "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/gemma2-9b-it": {
         "max_tokens": 8192,
@@ -1825,7 +1833,8 @@
         "output_cost_per_token": 0.00000020,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/llama3-groq-70b-8192-tool-use-preview": {
         "max_tokens": 8192,
@@ -1835,7 +1844,8 @@
         "output_cost_per_token": 0.00000089,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "groq/llama3-groq-8b-8192-tool-use-preview": {
         "max_tokens": 8192,
@@ -1845,7 +1855,8 @@
         "output_cost_per_token": 0.00000019,
         "litellm_provider": "groq",
         "mode": "chat",
-        "supports_function_calling": true
+        "supports_function_calling": true,
+        "supports_response_schema": true
     },
     "cerebras/llama3.1-8b": {
         "max_tokens": 128000,
diff --git a/litellm/utils.py b/litellm/utils.py
index bf82a154b..f9079c7b8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1739,15 +1739,15 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) ->
 
     Does not raise error. Defaults to 'False'. Outputs logging.error.
""" + ## GET LLM PROVIDER ## + model, custom_llm_provider, _, _ = get_llm_provider( + model=model, custom_llm_provider=custom_llm_provider + ) + + if custom_llm_provider == "predibase": # predibase supports this globally + return True + try: - ## GET LLM PROVIDER ## - model, custom_llm_provider, _, _ = get_llm_provider( - model=model, custom_llm_provider=custom_llm_provider - ) - - if custom_llm_provider == "predibase": # predibase supports this globally - return True - ## GET MODEL INFO model_info = litellm.get_model_info( model=model, custom_llm_provider=custom_llm_provider @@ -1755,12 +1755,17 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> if model_info.get("supports_response_schema", False) is True: return True - return False except Exception: - verbose_logger.error( - f"Model not supports response_schema. You passed model={model}, custom_llm_provider={custom_llm_provider}." + ## check if provider supports response schema globally + supported_params = get_supported_openai_params( + model=model, + custom_llm_provider=custom_llm_provider, + request_type="chat_completion", ) - return False + if supported_params is not None and "response_schema" in supported_params: + return True + + return False def supports_function_calling( @@ -2710,6 +2715,7 @@ def get_optional_params( # noqa: PLR0915 non_default_params["response_format"] = type_to_response_format_param( response_format=non_default_params["response_format"] ) + if "tools" in non_default_params and isinstance( non_default_params, list ): # fixes https://github.com/BerriAI/litellm/issues/4933 @@ -3494,24 +3500,16 @@ def get_optional_params( # noqa: PLR0915 ) _check_valid_arg(supported_params=supported_params) - if temperature is not None: - optional_params["temperature"] = temperature - if max_tokens is not None: - optional_params["max_tokens"] = max_tokens - if top_p is not None: - optional_params["top_p"] = top_p - if stream is not None: - optional_params["stream"] = stream - if stop is not None: - optional_params["stop"] = stop - if tools is not None: - optional_params["tools"] = tools - if tool_choice is not None: - optional_params["tool_choice"] = tool_choice - if response_format is not None: - optional_params["response_format"] = response_format - if seed is not None: - optional_params["seed"] = seed + optional_params = litellm.GroqChatConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) elif custom_llm_provider == "deepseek": supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 5e4f851e9..9758cf345 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -1745,7 +1745,8 @@ "output_cost_per_token": 0.00000080, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/llama3-8b-8192": { "max_tokens": 8192, @@ -1755,7 +1756,8 @@ "output_cost_per_token": 0.00000008, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/llama3-70b-8192": { "max_tokens": 8192, @@ -1765,7 +1767,8 @@ "output_cost_per_token": 0.00000079, "litellm_provider": 
"groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/llama-3.1-8b-instant": { "max_tokens": 8192, @@ -1775,7 +1778,8 @@ "output_cost_per_token": 0.00000008, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/llama-3.1-70b-versatile": { "max_tokens": 8192, @@ -1785,7 +1789,8 @@ "output_cost_per_token": 0.00000079, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/llama-3.1-405b-reasoning": { "max_tokens": 8192, @@ -1795,7 +1800,8 @@ "output_cost_per_token": 0.00000079, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/mixtral-8x7b-32768": { "max_tokens": 32768, @@ -1805,7 +1811,8 @@ "output_cost_per_token": 0.00000024, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/gemma-7b-it": { "max_tokens": 8192, @@ -1815,7 +1822,8 @@ "output_cost_per_token": 0.00000007, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/gemma2-9b-it": { "max_tokens": 8192, @@ -1825,7 +1833,8 @@ "output_cost_per_token": 0.00000020, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/llama3-groq-70b-8192-tool-use-preview": { "max_tokens": 8192, @@ -1835,7 +1844,8 @@ "output_cost_per_token": 0.00000089, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "groq/llama3-groq-8b-8192-tool-use-preview": { "max_tokens": 8192, @@ -1845,7 +1855,8 @@ "output_cost_per_token": 0.00000019, "litellm_provider": "groq", "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supports_response_schema": true }, "cerebras/llama3.1-8b": { "max_tokens": 128000, diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index 5ad4a7c8c..2e4dea39b 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -93,6 +93,7 @@ class BaseLLMChatTest(ABC): assert response.choices[0].message.content is not None def test_json_response_pydantic_obj(self): + litellm.set_verbose = True from pydantic import BaseModel from litellm.utils import supports_response_schema @@ -119,6 +120,11 @@ class BaseLLMChatTest(ABC): response_format=TestModel, ) assert res is not None + + print(res.choices[0].message) + + assert res.choices[0].message.content is not None + assert res.choices[0].message.tool_calls is None except litellm.InternalServerError: pytest.skip("Model is overloaded") diff --git a/tests/llm_translation/test_groq.py b/tests/llm_translation/test_groq.py index 8522e65fa..359787b2d 100644 --- a/tests/llm_translation/test_groq.py +++ b/tests/llm_translation/test_groq.py @@ -4,7 +4,7 @@ from base_llm_unit_tests import BaseLLMChatTest class TestGroq(BaseLLMChatTest): def get_base_completion_call_args(self) -> dict: return { - "model": "groq/llama3-70b-8192", + "model": 
"groq/llama-3.1-70b-versatile", } def test_tool_call_no_arguments(self, tool_call_no_arguments): diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 6e7b0ff05..52946ca30 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -749,6 +749,7 @@ def test_convert_model_response_object(): ("gemini/gemini-1.5-pro", True), ("predibase/llama3-8b-instruct", True), ("gpt-3.5-turbo", False), + ("groq/llama3-70b-8192", True), ], ) def test_supports_response_schema(model, expected_bool):