fix(vertex_httpx.py): fix supported vertex params

Author: Krrish Dholakia
Date: 2024-07-04 21:17:35 -07:00
parent 6d2b429176
commit d0862697b8

4 changed files with 8 additions and 6 deletions
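
The headline change adds "extra_headers" to the list of OpenAI params that VertexAIConfig reports as supported, so caller-supplied headers are forwarded rather than rejected as unsupported. A minimal sketch of the call this enables, assuming a Gemini model on Vertex AI (the model name and header key/value below are illustrative, not from this commit):

    import litellm

    # "extra_headers" now passes the supported-params check for Vertex AI;
    # the header here is a placeholder, not something this commit prescribes.
    response = litellm.completion(
        model="vertex_ai/gemini-1.5-pro",
        messages=[{"role": "user", "content": "hello"}],
        extra_headers={"x-goog-user-project": "my-project"},
    )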

@@ -155,6 +155,7 @@ class VertexAIConfig:
             "response_format",
             "n",
             "stop",
+            "extra_headers",
         ]
 
     def map_openai_params(self, non_default_params: dict, optional_params: dict):
@@ -400,7 +401,9 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]:
         ## APPEND TOOL CALL MESSAGES ##
         if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
-            _part = convert_to_gemini_tool_call_result(messages[msg_i], last_message_with_tool_calls)
+            _part = convert_to_gemini_tool_call_result(
+                messages[msg_i], last_message_with_tool_calls
+            )
             contents.append(ContentType(parts=[_part]))  # type: ignore
             msg_i += 1
 
     if msg_i == init_msg_i:  # prevent infinite loops
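
For context, this loop walks OpenAI-style messages and converts role-"tool" entries into Gemini function-response parts; the change above only reflows one call across multiple lines. A sketch of the message shape the "tool" branch consumes (ids and content are illustrative):

    # OpenAI-style tool-result message handled by the branch above;
    # "tool_call_id" must match an id from the assistant's earlier tool call.
    tool_result_message = {
        "role": "tool",
        "tool_call_id": "call_abc123",
        "content": '{"temperature_f": 72}',
    }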

@@ -1037,9 +1037,7 @@ class VertexLLM(BaseLLM):
         safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
             "safety_settings", None
         )  # type: ignore
-        cached_content: Optional[str] = optional_params.pop(
-            "cached_content", None
-        )
+        cached_content: Optional[str] = optional_params.pop("cached_content", None)
         generation_config: Optional[GenerationConfig] = GenerationConfig(
             **optional_params
         )
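
The surrounding code pops Vertex-specific keys (safety_settings, cached_content) out of optional_params before the remainder is splatted into GenerationConfig, so only generation fields reach the constructor; the change above just collapses one pop onto a single line. A standalone sketch of that pop-then-splat pattern, with plain dicts standing in for the typed objects:

    # Illustrative only: pop provider-specific keys first, then treat what
    # remains as generation-config fields.
    optional_params = {
        "temperature": 0.2,
        "max_output_tokens": 256,
        "cached_content": "projects/p/locations/l/cachedContents/c",  # placeholder
    }
    cached_content = optional_params.pop("cached_content", None)
    generation_config = {**optional_params}  # stands in for GenerationConfig(**optional_params)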

@@ -985,6 +985,7 @@ def completion(
             mock_delay=kwargs.get("mock_delay", None),
+            custom_llm_provider=custom_llm_provider,
         )
 
     if custom_llm_provider == "azure":
         # azure configs
         api_type = get_secret("AZURE_API_TYPE") or "azure"
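
This hunk appears to pass custom_llm_provider into litellm's mock-completion path, so mocked responses carry the right provider. A hedged usage sketch of that mocking entry point, assuming the public mock_response kwarg (model and values are illustrative):

    import litellm

    # Returns a canned response without hitting the provider; the provider
    # prefix still flows through, which is what the hunk above wires up.
    response = litellm.completion(
        model="vertex_ai/gemini-1.5-pro",
        messages=[{"role": "user", "content": "ping"}],
        mock_response="pong",
    )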

@@ -1113,7 +1113,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
             extra_headers={"hello": "world"},
         )
     except Exception as e:
-        pass
+        print("Receives error - {}\n{}".format(str(e), traceback.format_exc()))
 
     mock_call.assert_called_once()
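
The test no longer swallows the expected exception silently: it prints the error and traceback, then still verifies the mocked client was invoked exactly once. A self-contained sketch of that mock-and-assert pattern (the function under test here is hypothetical, not litellm's):

    import traceback
    from unittest.mock import MagicMock

    def do_request(client):
        # Hypothetical code under test: delegates one call to an injected client.
        return client.post("/generateContent")

    mock_call = MagicMock(side_effect=Exception("simulated HTTP failure"))
    client = MagicMock(post=mock_call)
    try:
        do_request(client)
    except Exception as e:
        # Log instead of `pass`, mirroring the change above.
        print("Receives error - {}\n{}".format(str(e), traceback.format_exc()))
    mock_call.assert_called_once()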