fix(vertex_httpx.py): fix supported vertex params

This commit is contained in:
Krrish Dholakia 2024-07-04 21:17:35 -07:00
parent 6d2b429176
commit d0862697b8
4 changed files with 8 additions and 6 deletions

View file

@ -155,6 +155,7 @@ class VertexAIConfig:
"response_format", "response_format",
"n", "n",
"stop", "stop",
"extra_headers",
] ]
def map_openai_params(self, non_default_params: dict, optional_params: dict): def map_openai_params(self, non_default_params: dict, optional_params: dict):
@ -400,7 +401,9 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
## APPEND TOOL CALL MESSAGES ## ## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool": if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
_part = convert_to_gemini_tool_call_result(messages[msg_i], last_message_with_tool_calls) _part = convert_to_gemini_tool_call_result(
messages[msg_i], last_message_with_tool_calls
)
contents.append(ContentType(parts=[_part])) # type: ignore contents.append(ContentType(parts=[_part])) # type: ignore
msg_i += 1 msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops if msg_i == init_msg_i: # prevent infinite loops

View file

@ -1037,9 +1037,7 @@ class VertexLLM(BaseLLM):
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
"safety_settings", None "safety_settings", None
) # type: ignore ) # type: ignore
cached_content: Optional[str] = optional_params.pop( cached_content: Optional[str] = optional_params.pop("cached_content", None)
"cached_content", None
)
generation_config: Optional[GenerationConfig] = GenerationConfig( generation_config: Optional[GenerationConfig] = GenerationConfig(
**optional_params **optional_params
) )

View file

@ -985,6 +985,7 @@ def completion(
mock_delay=kwargs.get("mock_delay", None), mock_delay=kwargs.get("mock_delay", None),
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
) )
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
# azure configs # azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure" api_type = get_secret("AZURE_API_TYPE") or "azure"

View file

@ -1113,7 +1113,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
extra_headers={"hello": "world"}, extra_headers={"hello": "world"},
) )
except Exception as e: except Exception as e:
pass print("Receives error - {}\n{}".format(str(e), traceback.format_exc()))
mock_call.assert_called_once() mock_call.assert_called_once()