diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 05ea17c7c..ad7cffd60 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -775,6 +775,8 @@ class VertexLLM(BaseLLM): } if auth_header is not None: headers["Authorization"] = f"Bearer {auth_header}" + if extra_headers is not None: + headers.update(extra_headers) ## LOGGING logging_obj.pre_call( diff --git a/litellm/main.py b/litellm/main.py index bd2652817..b0fe59e62 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1931,6 +1931,7 @@ def completion( custom_llm_provider=custom_llm_provider, client=client, api_base=api_base, + extra_headers=extra_headers, ) elif custom_llm_provider == "vertex_ai": diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index f76432447..867486b40 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -877,6 +877,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider): response_format={"type": "json_object"}, client=client, api_base="my-custom-api-base", + extra_headers={"hello": "world"}, ) except Exception as e: pass @@ -884,6 +885,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider): mock_call.assert_called_once() assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"] + assert "hello" in mock_call.call_args.kwargs["headers"] @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")