diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index afdc598b44..31ac792d8e 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1436,6 +1436,43 @@ def test_hf_test_completion_tgi():
 
 # hf_test_completion_tgi()
 
+
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.asyncio
+async def test_openai_compatible_custom_api_base(provider):
+    litellm.set_verbose = True
+    messages = [
+        {
+            "role": "user",
+            "content": "Hello world",
+        }
+    ]
+    from openai import OpenAI
+
+    openai_client = OpenAI(api_key="fake-key")
+
+    with patch.object(
+        openai_client.chat.completions, "create", new=MagicMock()
+    ) as mock_call:
+        try:
+            response = completion(
+                model="openai/my-vllm-model",
+                messages=messages,
+                response_format={"type": "json_object"},
+                client=openai_client,
+                api_base="my-custom-api-base",
+                hello="world",
+            )
+        except Exception as e:
+            pass
+
+        mock_call.assert_called_once()
+
+        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
+
+        assert "hello" in mock_call.call_args.kwargs["extra_body"]
+
+
 # ################### Hugging Face Conversational models ########################
 # def hf_test_completion_conv():
 #     try:
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index 32c969ac72..c6bbf71f22 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -4189,12 +4189,12 @@ def test_completion_vllm():
 
     with patch.object(client.completions, "create", side_effect=mock_post) as mock_call:
         response = text_completion(
-            model="openai/gemini-1.5-flash",
-            prompt="ping",
-            client=client,
+            model="openai/gemini-1.5-flash", prompt="ping", client=client, hello="world"
         )
         print(response)
 
         assert response.usage.prompt_tokens == 2
 
         mock_call.assert_called_once()
+
+        assert "hello" in mock_call.call_args.kwargs["extra_body"]
diff --git a/litellm/utils.py b/litellm/utils.py
index 02b7bfd48f..eb62204f55 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3265,7 +3265,11 @@ def get_optional_params(
         optional_params["top_logprobs"] = top_logprobs
     if extra_headers is not None:
         optional_params["extra_headers"] = extra_headers
-    if custom_llm_provider in ["openai", "azure"] + litellm.openai_compatible_providers:
+    if (
+        custom_llm_provider
+        in ["openai", "azure", "text-completion-openai"]
+        + litellm.openai_compatible_providers
+    ):
         # for openai, azure we should pass the extra/passed params within `extra_body` https://github.com/openai/openai-python/blob/ac33853ba10d13ac149b1fa3ca6dba7d613065c9/src/openai/resources/models.py#L46
         extra_body = passed_params.pop("extra_body", {})
         for k in passed_params.keys():
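
A minimal sketch of the behavior this patch covers, following the pattern of the tests above; the model name and `api_base` are placeholders, not real endpoints:

```python
# Any unrecognized kwarg passed to litellm.completion() for an
# OpenAI-compatible provider should be forwarded inside `extra_body`,
# which the openai-python client serializes into the request payload.
from unittest.mock import MagicMock, patch

from openai import OpenAI

from litellm import completion

openai_client = OpenAI(api_key="fake-key")

with patch.object(
    openai_client.chat.completions, "create", new=MagicMock()
) as mock_create:
    try:
        completion(
            model="openai/my-vllm-model",  # placeholder model name
            messages=[{"role": "user", "content": "Hello world"}],
            client=openai_client,
            api_base="http://localhost:8000/v1",  # placeholder base URL
            hello="world",  # non-standard, provider-specific param
        )
    except Exception:
        pass  # the MagicMock response is not a valid completion object

# The unknown `hello` kwarg should surface under `extra_body`:
assert mock_create.call_args.kwargs["extra_body"]["hello"] == "world"
```

The same routing applies to `text_completion()` once `"text-completion-openai"` is added to the provider list in `get_optional_params`, which is what the `test_completion_vllm` assertion exercises.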