diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index d121288a2c..8db4b6e85e 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -155,6 +155,7 @@ class VertexAIConfig: "response_format", "n", "stop", + "extra_headers", ] def map_openai_params(self, non_default_params: dict, optional_params: dict): @@ -400,7 +401,9 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]: ## APPEND TOOL CALL MESSAGES ## if msg_i < len(messages) and messages[msg_i]["role"] == "tool": - _part = convert_to_gemini_tool_call_result(messages[msg_i], last_message_with_tool_calls) + _part = convert_to_gemini_tool_call_result( + messages[msg_i], last_message_with_tool_calls + ) contents.append(ContentType(parts=[_part])) # type: ignore msg_i += 1 if msg_i == init_msg_i: # prevent infinite loops diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index 62b9085771..31910b7ea3 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -976,7 +976,7 @@ class VertexLLM(BaseLLM): api_base: Optional[str] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore - + auth_header, url = self._get_token_and_url( model=model, gemini_api_key=gemini_api_key, @@ -1037,9 +1037,7 @@ class VertexLLM(BaseLLM): safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( "safety_settings", None ) # type: ignore - cached_content: Optional[str] = optional_params.pop( - "cached_content", None - ) + cached_content: Optional[str] = optional_params.pop("cached_content", None) generation_config: Optional[GenerationConfig] = GenerationConfig( **optional_params ) diff --git a/litellm/main.py b/litellm/main.py index ecd03f1b61..2a7759e8a5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -985,6 +985,7 @@ def completion( mock_delay=kwargs.get("mock_delay", None), custom_llm_provider=custom_llm_provider, ) + if custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 5adbc0d7b5..9c11a42484 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -1113,7 +1113,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider): extra_headers={"hello": "world"}, ) except Exception as e: - pass + print("Receives error - {}\n{}".format(str(e), traceback.format_exc())) mock_call.assert_called_once()