refactor(utils.py): make it clearer how vertex ai params are handled '

'
This commit is contained in:
Krrish Dholakia 2024-04-17 16:20:56 -07:00
parent 409bd5b4ab
commit 32d94feddd
4 changed files with 74 additions and 42 deletions

View file

@ -87,6 +87,60 @@ class VertexAIConfig:
and v is not None and v is not None
} }
def get_supported_openai_params(self):
return [
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
"response_format",
"n",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stream":
optional_params["stream"] = value
if param == "n":
optional_params["candidate_count"] = value
if param == "stop":
if isinstance(value, str):
optional_params["stop_sequences"] = [value]
elif isinstance(value, list):
optional_params["stop_sequences"] = value
if param == "max_tokens":
optional_params["max_output_tokens"] = value
if param == "response_format" and value["type"] == "json_object":
optional_params["response_mime_type"] = "application/json"
if param == "tools" and isinstance(value, list):
from vertexai.preview import generative_models
gtool_func_declarations = []
for tool in value:
gtool_func_declaration = generative_models.FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
generative_models.Tool(
function_declarations=gtool_func_declarations
)
]
if param == "tool_choice" and (
isinstance(value, str) or isinstance(value, dict)
):
pass
return optional_params
import asyncio import asyncio

View file

@ -1012,6 +1012,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-1.5-pro-preview-0215": { "gemini-1.5-pro-preview-0215": {
@ -1023,6 +1024,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-1.5-pro-preview-0409": { "gemini-1.5-pro-preview-0409": {
@ -1034,6 +1036,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-experimental": { "gemini-experimental": {
@ -1045,6 +1048,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": false, "supports_function_calling": false,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-pro-vision": { "gemini-pro-vision": {
@ -1271,6 +1275,7 @@
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini/gemini-1.5-pro-latest": { "gemini/gemini-1.5-pro-latest": {
@ -1283,6 +1288,7 @@
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true,
"source": "https://ai.google.dev/models/gemini" "source": "https://ai.google.dev/models/gemini"
}, },
"gemini/gemini-pro-vision": { "gemini/gemini-pro-vision": {

View file

@ -4878,37 +4878,11 @@ def get_optional_params(
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
if temperature is not None: optional_params = litellm.VertexAIConfig().map_openai_params(
optional_params["temperature"] = temperature non_default_params=non_default_params,
if top_p is not None: optional_params=optional_params,
optional_params["top_p"] = top_p )
if stream:
optional_params["stream"] = stream
if n is not None:
optional_params["candidate_count"] = n
if stop is not None:
if isinstance(stop, str):
optional_params["stop_sequences"] = [stop]
elif isinstance(stop, list):
optional_params["stop_sequences"] = stop
if max_tokens is not None:
optional_params["max_output_tokens"] = max_tokens
if response_format is not None and response_format["type"] == "json_object":
optional_params["response_mime_type"] = "application/json"
if tools is not None and isinstance(tools, list):
from vertexai.preview import generative_models
gtool_func_declarations = []
for tool in tools:
gtool_func_declaration = generative_models.FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
generative_models.Tool(function_declarations=gtool_func_declarations)
]
print_verbose( print_verbose(
f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}" f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
) )
@ -5610,17 +5584,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"] return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
return [ return litellm.VertexAIConfig().get_supported_openai_params()
"temperature",
"top_p",
"max_tokens",
"stream",
"tools",
"tool_choice",
"response_format",
"n",
"stop",
]
elif custom_llm_provider == "sagemaker": elif custom_llm_provider == "sagemaker":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
elif custom_llm_provider == "aleph_alpha": elif custom_llm_provider == "aleph_alpha":
@ -10595,7 +10559,9 @@ def trim_messages(
if max_tokens is None: if max_tokens is None:
# Check if model is valid # Check if model is valid
if model in litellm.model_cost: if model in litellm.model_cost:
max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"]) max_tokens_for_model = litellm.model_cost[model].get(
"max_input_tokens", litellm.model_cost[model]["max_tokens"]
)
max_tokens = int(max_tokens_for_model * trim_ratio) max_tokens = int(max_tokens_for_model * trim_ratio)
else: else:
# if user did not specify max (input) tokens # if user did not specify max (input) tokens

View file

@ -1012,6 +1012,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-1.5-pro-preview-0215": { "gemini-1.5-pro-preview-0215": {
@ -1023,6 +1024,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-1.5-pro-preview-0409": { "gemini-1.5-pro-preview-0409": {
@ -1034,6 +1036,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-experimental": { "gemini-experimental": {
@ -1045,6 +1048,7 @@
"litellm_provider": "vertex_ai-language-models", "litellm_provider": "vertex_ai-language-models",
"mode": "chat", "mode": "chat",
"supports_function_calling": false, "supports_function_calling": false,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini-pro-vision": { "gemini-pro-vision": {
@ -1271,6 +1275,7 @@
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini/gemini-1.5-pro-latest": { "gemini/gemini-1.5-pro-latest": {
@ -1283,6 +1288,7 @@
"mode": "chat", "mode": "chat",
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true,
"source": "https://ai.google.dev/models/gemini" "source": "https://ai.google.dev/models/gemini"
}, },
"gemini/gemini-pro-vision": { "gemini/gemini-pro-vision": {