From 32d94feddd85a579bf673d9e4b800be91d46236c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 17 Apr 2024 16:20:56 -0700
Subject: [PATCH] refactor(utils.py): make it clearer how vertex ai params are
 handled

---
 litellm/llms/vertex_ai.py                     | 54 +++++++++++++++++++
 ...odel_prices_and_context_window_backup.json |  6 +++
 litellm/utils.py                              | 50 +++---------------
 model_prices_and_context_window.json          |  6 +++
 4 files changed, 74 insertions(+), 42 deletions(-)

diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 69feef63c..1d9da108b 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -87,6 +87,60 @@ class VertexAIConfig:
             and v is not None
         }
 
+    def get_supported_openai_params(self):
+        return [
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
+            "tools",
+            "tool_choice",
+            "response_format",
+            "n",
+            "stop",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "n":
+                optional_params["candidate_count"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    optional_params["stop_sequences"] = [value]
+                elif isinstance(value, list):
+                    optional_params["stop_sequences"] = value
+            if param == "max_tokens":
+                optional_params["max_output_tokens"] = value
+            if param == "response_format" and value["type"] == "json_object":
+                optional_params["response_mime_type"] = "application/json"
+            if param == "tools" and isinstance(value, list):
+                from vertexai.preview import generative_models
+
+                gtool_func_declarations = []
+                for tool in value:
+                    gtool_func_declaration = generative_models.FunctionDeclaration(
+                        name=tool["function"]["name"],
+                        description=tool["function"].get("description", ""),
+                        parameters=tool["function"].get("parameters", {}),
+                    )
+                    gtool_func_declarations.append(gtool_func_declaration)
+                optional_params["tools"] = [
+                    generative_models.Tool(
+                        function_declarations=gtool_func_declarations
+                    )
+                ]
+            if param == "tool_choice" and (
+                isinstance(value, str) or isinstance(value, dict)
+            ):
+                pass
+        return optional_params
+
 
 import asyncio
 
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index bf5adb430..a88976ef3 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -1012,6 +1012,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0215": {
@@ -1023,6 +1024,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0409": {
@@ -1034,6 +1036,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-experimental": {
@@ -1045,6 +1048,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": false,
"supports_tool_choice": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini-pro-vision": { @@ -1271,6 +1275,7 @@ "mode": "chat", "supports_function_calling": true, "supports_vision": true, + "supports_tool_choice": true, "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" }, "gemini/gemini-1.5-pro-latest": { @@ -1283,6 +1288,7 @@ "mode": "chat", "supports_function_calling": true, "supports_vision": true, + "supports_tool_choice": true, "source": "https://ai.google.dev/models/gemini" }, "gemini/gemini-pro-vision": { diff --git a/litellm/utils.py b/litellm/utils.py index dd538c7d0..1f6c7fe0d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4878,37 +4878,11 @@ def get_optional_params( ) _check_valid_arg(supported_params=supported_params) - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if stream: - optional_params["stream"] = stream - if n is not None: - optional_params["candidate_count"] = n - if stop is not None: - if isinstance(stop, str): - optional_params["stop_sequences"] = [stop] - elif isinstance(stop, list): - optional_params["stop_sequences"] = stop - if max_tokens is not None: - optional_params["max_output_tokens"] = max_tokens - if response_format is not None and response_format["type"] == "json_object": - optional_params["response_mime_type"] = "application/json" - if tools is not None and isinstance(tools, list): - from vertexai.preview import generative_models + optional_params = litellm.VertexAIConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) - gtool_func_declarations = [] - for tool in tools: - gtool_func_declaration = generative_models.FunctionDeclaration( - name=tool["function"]["name"], - description=tool["function"].get("description", ""), - parameters=tool["function"].get("parameters", {}), - ) - gtool_func_declarations.append(gtool_func_declaration) - optional_params["tools"] = [ - generative_models.Tool(function_declarations=gtool_func_declarations) - ] print_verbose( f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}" ) @@ -5610,17 +5584,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str): elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"] elif custom_llm_provider == "vertex_ai": - return [ - "temperature", - "top_p", - "max_tokens", - "stream", - "tools", - "tool_choice", - "response_format", - "n", - "stop", - ] + return litellm.VertexAIConfig().get_supported_openai_params() elif custom_llm_provider == "sagemaker": return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] elif custom_llm_provider == "aleph_alpha": @@ -10595,7 +10559,9 @@ def trim_messages( if max_tokens is None: # Check if model is valid if model in litellm.model_cost: - max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"]) + max_tokens_for_model = litellm.model_cost[model].get( + "max_input_tokens", litellm.model_cost[model]["max_tokens"] + ) max_tokens = int(max_tokens_for_model * trim_ratio) else: # if user did not specify max (input) tokens diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index bf5adb430..a88976ef3 100644 --- a/model_prices_and_context_window.json +++ 
+++ b/model_prices_and_context_window.json
@@ -1012,6 +1012,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0215": {
@@ -1023,6 +1024,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-1.5-pro-preview-0409": {
@@ -1034,6 +1036,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": true,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-experimental": {
@@ -1045,6 +1048,7 @@
         "litellm_provider": "vertex_ai-language-models",
         "mode": "chat",
         "supports_function_calling": false,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini-pro-vision": {
@@ -1271,6 +1275,7 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_vision": true,
+        "supports_tool_choice": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
     "gemini/gemini-1.5-pro-latest": {
@@ -1283,6 +1288,7 @@
         "mode": "chat",
         "supports_function_calling": true,
         "supports_vision": true,
+        "supports_tool_choice": true,
         "source": "https://ai.google.dev/models/gemini"
     },
     "gemini/gemini-pro-vision": {
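
Usage note (illustrative sketch, not part of the patch): with the translation
moved onto VertexAIConfig, the OpenAI-to-Vertex-AI param mapping can be
exercised directly. The parameter values below are made up for illustration;
the only assumption is that litellm exposes VertexAIConfig, which the patched
utils.py itself relies on.

    import litellm

    # OpenAI-style params in; Vertex AI-style params out, per map_openai_params.
    optional_params = litellm.VertexAIConfig().map_openai_params(
        non_default_params={"temperature": 0.2, "max_tokens": 256, "stop": "END"},
        optional_params={},
    )
    # Expected result per the mapping in this patch:
    # {"temperature": 0.2, "max_output_tokens": 256, "stop_sequences": ["END"]}

Note that "tool_choice" is reported by get_supported_openai_params() but is
deliberately dropped by map_openai_params() (the pass branch): callers can pass
it through _check_valid_arg without error, even though it is not forwarded to
Vertex AI. This matches the supports_tool_choice flags added to the Gemini
entries in both JSON files.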