fix(vertex_httpx.py): allow passing extra headers

Closes https://github.com/BerriAI/litellm/pull/4327
This commit is contained in:
Krrish Dholakia 2024-06-20 21:14:22 -07:00
parent 3a510582c2
commit dfd5882c31
3 changed files with 5 additions and 0 deletions

View file

@ -775,6 +775,8 @@ class VertexLLM(BaseLLM):
}
if auth_header is not None:
headers["Authorization"] = f"Bearer {auth_header}"
if extra_headers is not None:
headers.update(extra_headers)
## LOGGING
logging_obj.pre_call(

View file

@ -1931,6 +1931,7 @@ def completion(
custom_llm_provider=custom_llm_provider,
client=client,
api_base=api_base,
extra_headers=extra_headers,
)
elif custom_llm_provider == "vertex_ai":

View file

@ -877,6 +877,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
response_format={"type": "json_object"},
client=client,
api_base="my-custom-api-base",
extra_headers={"hello": "world"},
)
except Exception as e:
pass
@ -884,6 +885,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
mock_call.assert_called_once()
assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"]
assert "hello" in mock_call.call_args.kwargs["headers"]
@pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")