diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 7acbdfae02..d12675d6b5 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -234,6 +234,47 @@ class OpenAIConfig: and v is not None } + def get_supported_openai_params(self, model: str) -> list: + base_params = [ + "frequency_penalty", + "logit_bias", + "logprobs", + "top_logprobs", + "max_tokens", + "n", + "presence_penalty", + "seed", + "stop", + "stream", + "stream_options", + "temperature", + "top_p", + "tools", + "tool_choice", + "user", + "function_call", + "functions", + "max_retries", + "extra_headers", + ] # works across all models + + model_specific_params = [] + if ( + "gpt-3.5-turbo" in model or "gpt-4-turbo" in model or "gpt-4o" in model + ): # gpt-4 does not support 'response_format' + model_specific_params.append("response_format") + + return base_params + model_specific_params + + def map_openai_params( + self, non_default_params: dict, optional_params: dict, model: str + ) -> dict: + supported_openai_params = self.get_supported_openai_params(model) + for param, value in non_default_params.items(): + if param in supported_openai_params: + optional_params[param] = value + return optional_params + class OpenAITextCompletionConfig: """ diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 3caf0c277a..8ab3805e8e 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1,4 +1,4 @@ -import sys, os +import sys, os, json import traceback from dotenv import load_dotenv @@ -1054,6 +1054,25 @@ def test_completion_azure_gpt4_vision(): # test_completion_azure_gpt4_vision() +@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]) +def test_completion_openai_params(model): + litellm.drop_params = True + messages = [ + { + "role": "user", + "content": """Generate JSON about Bill Gates: { "full_name": "", "title": "" }""", + } + ] + + response = completion( + model=model, + messages=messages, + response_format={"type": "json_object"}, + ) + + print(f"response: {response}") + + def test_completion_fireworks_ai(): try: litellm.set_verbose = True diff --git a/litellm/utils.py b/litellm/utils.py index 6d0231e8f2..ac246fca6a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5195,7 +5195,7 @@ def get_optional_params( if unsupported_params and not litellm.drop_params: raise UnsupportedParamsError( status_code=500, - message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n", + message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n", ) def _map_and_modify_arg(supported_params: dict, provider: str, model: str): @@ -5884,12 +5884,21 @@ def get_optional_params( optional_params["extra_body"] = ( extra_body # openai client supports `extra_body` param ) - else: # assume passing in params for openai/azure openai - + elif custom_llm_provider == "openai": supported_params = get_supported_openai_params( model=model, custom_llm_provider="openai" ) _check_valid_arg(supported_params=supported_params) + optional_params = litellm.OpenAIConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + ) + else: # assume passing in params for azure openai + supported_params = get_supported_openai_params( + model=model, custom_llm_provider="azure" + ) + _check_valid_arg(supported_params=supported_params) if functions is not None: optional_params["functions"] = functions if function_call is not None: @@ -6263,7 +6272,9 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): "presence_penalty", "stop", ] - elif custom_llm_provider == "openai" or custom_llm_provider == "azure": + elif custom_llm_provider == "openai": + return litellm.OpenAIConfig().get_supported_openai_params(model=model) + elif custom_llm_provider == "azure": return [ "functions", "function_call",