diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 26f3b814b..b433301ef 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -123,6 +123,67 @@ class VertexAIConfig:
             and v is not None
         }
 
+    def get_supported_openai_params(self):
+        return [
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
+            "tools",
+            "tool_choice",
+            "response_format",
+            "n",
+            "stop",
+            "extra_headers",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if (
+                param == "stream" and value == True
+            ):  # sending stream = False, can cause it to get passed unchecked and raise issues
+                optional_params["stream"] = value
+            if param == "n":
+                optional_params["candidate_count"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    optional_params["stop_sequences"] = [value]
+                elif isinstance(value, list):
+                    optional_params["stop_sequences"] = value
+            if param == "max_tokens":
+                optional_params["max_output_tokens"] = value
+            if param == "response_format" and value["type"] == "json_object":
+                optional_params["response_mime_type"] = "application/json"
+            if param == "frequency_penalty":
+                optional_params["frequency_penalty"] = value
+            if param == "presence_penalty":
+                optional_params["presence_penalty"] = value
+            if param == "tools" and isinstance(value, list):
+                from vertexai.preview import generative_models
+
+                gtool_func_declarations = []
+                for tool in value:
+                    gtool_func_declaration = generative_models.FunctionDeclaration(
+                        name=tool["function"]["name"],
+                        description=tool["function"].get("description", ""),
+                        parameters=tool["function"].get("parameters", {}),
+                    )
+                    gtool_func_declarations.append(gtool_func_declaration)
+                optional_params["tools"] = [
+                    generative_models.Tool(
+                        function_declarations=gtool_func_declarations
+                    )
+                ]
+            if param == "tool_choice" and (
+                isinstance(value, str) or isinstance(value, dict)
+            ):
+                pass
+        return optional_params
+
     def get_mapped_special_auth_params(self) -> dict:
         """
         Common auth params across bedrock/vertex_ai/azure/watsonx
diff --git a/litellm/utils.py b/litellm/utils.py
index 1f0aa4c71..0f52df63c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4216,7 +4216,7 @@ def get_supported_openai_params(
             if model.startswith("meta/"):
                 return litellm.VertexAILlama3Config().get_supported_openai_params()
 
-            return litellm.VertexAIConfig().get_supported_openai_params()
+            return litellm.VertexGeminiConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "vertex_ai_beta":