Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-25 18:54:30 +00:00
refactor(main.py): migrate vertex gemini calls to vertex_httpx
Completes migration to vertex_httpx
parent e835f7336a
commit 86596c53e9

6 changed files with 159 additions and 206 deletions
litellm/__init__.py
@@ -800,8 +800,12 @@ from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
-from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
-from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
+from .llms.vertex_httpx import (
+    VertexGeminiConfig,
+    GoogleAIStudioGeminiConfig,
+    VertexAIConfig,
+)
+from .llms.vertex_ai import VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
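With these imports in place, `VertexAIConfig` is re-exported from `vertex_httpx` alongside the Gemini configs, so downstream code keeps working unchanged. A minimal sketch of the resulting surface, assuming these names stay re-exported at the package top level (as the `utils.py` hunks below rely on):

    import litellm

    # Same top-level names as before the refactor; only the defining module moved.
    print(litellm.VertexAIConfig().get_config())
    print(litellm.VertexGeminiConfig().get_supported_openai_params())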
litellm/llms/vertex_ai.py
@@ -42,201 +42,6 @@ class VertexAIError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
-class ExtendedGenerationConfig(dict):
-    """Extended parameters for the generation."""
-
-    def __init__(
-        self,
-        *,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        candidate_count: Optional[int] = None,
-        max_output_tokens: Optional[int] = None,
-        stop_sequences: Optional[List[str]] = None,
-        response_mime_type: Optional[str] = None,
-        frequency_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-    ):
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            candidate_count=candidate_count,
-            max_output_tokens=max_output_tokens,
-            stop_sequences=stop_sequences,
-            response_mime_type=response_mime_type,
-            frequency_penalty=frequency_penalty,
-            presence_penalty=presence_penalty,
-        )
-
-
-class VertexAIConfig:
-    """
-    Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
-    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
-
-    The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
-
-    - `temperature` (float): This controls the degree of randomness in token selection.
-
-    - `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
-
-    - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
-
-    - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
-
-    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
-
-    - `candidate_count` (int): Number of generated responses to return.
-
-    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
-
-    - `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
-
-    - `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
-
-    Note: Please make sure to modify the default parameters as required for your use case.
-    """
-
-    temperature: Optional[float] = None
-    max_output_tokens: Optional[int] = None
-    top_p: Optional[float] = None
-    top_k: Optional[int] = None
-    response_mime_type: Optional[str] = None
-    candidate_count: Optional[int] = None
-    stop_sequences: Optional[list] = None
-    frequency_penalty: Optional[float] = None
-    presence_penalty: Optional[float] = None
-
-    def __init__(
-        self,
-        temperature: Optional[float] = None,
-        max_output_tokens: Optional[int] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        response_mime_type: Optional[str] = None,
-        candidate_count: Optional[int] = None,
-        stop_sequences: Optional[list] = None,
-        frequency_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "stream",
-            "tools",
-            "tool_choice",
-            "response_format",
-            "n",
-            "stop",
-            "extra_headers",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if (
-                param == "stream" and value == True
-            ):  # sending stream = False, can cause it to get passed unchecked and raise issues
-                optional_params["stream"] = value
-            if param == "n":
-                optional_params["candidate_count"] = value
-            if param == "stop":
-                if isinstance(value, str):
-                    optional_params["stop_sequences"] = [value]
-                elif isinstance(value, list):
-                    optional_params["stop_sequences"] = value
-            if param == "max_tokens":
-                optional_params["max_output_tokens"] = value
-            if param == "response_format" and value["type"] == "json_object":
-                optional_params["response_mime_type"] = "application/json"
-            if param == "frequency_penalty":
-                optional_params["frequency_penalty"] = value
-            if param == "presence_penalty":
-                optional_params["presence_penalty"] = value
-            if param == "tools" and isinstance(value, list):
-                from vertexai.preview import generative_models
-
-                gtool_func_declarations = []
-                for tool in value:
-                    gtool_func_declaration = generative_models.FunctionDeclaration(
-                        name=tool["function"]["name"],
-                        description=tool["function"].get("description", ""),
-                        parameters=tool["function"].get("parameters", {}),
-                    )
-                    gtool_func_declarations.append(gtool_func_declaration)
-                optional_params["tools"] = [
-                    generative_models.Tool(
-                        function_declarations=gtool_func_declarations
-                    )
-                ]
-            if param == "tool_choice" and (
-                isinstance(value, str) or isinstance(value, dict)
-            ):
-                pass
-        return optional_params
-
-    def get_mapped_special_auth_params(self) -> dict:
-        """
-        Common auth params across bedrock/vertex_ai/azure/watsonx
-        """
-        return {"project": "vertex_project", "region_name": "vertex_location"}
-
-    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
-        mapped_params = self.get_mapped_special_auth_params()
-
-        for param, value in non_default_params.items():
-            if param in mapped_params:
-                optional_params[mapped_params[param]] = value
-        return optional_params
-
-    def get_eu_regions(self) -> List[str]:
-        """
-        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
-        """
-        return [
-            "europe-central2",
-            "europe-north1",
-            "europe-southwest1",
-            "europe-west1",
-            "europe-west2",
-            "europe-west3",
-            "europe-west4",
-            "europe-west6",
-            "europe-west8",
-            "europe-west9",
-        ]
-
-
 import asyncio
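The deleted `map_openai_params` carried the OpenAI-to-Vertex parameter translation, which now lives on `VertexGeminiConfig` in `vertex_httpx.py`. A condensed sketch of the renames that method performed (standalone illustration, not the litellm source):

    def sketch_map_openai_params(non_default_params: dict, optional_params: dict) -> dict:
        # Condensed from the removed method above; illustration only.
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_output_tokens"] = value  # OpenAI max_tokens -> Vertex max_output_tokens
            if param == "n":
                optional_params["candidate_count"] = value  # n -> candidate_count
            if param == "stop":
                # Vertex expects a list of stop sequences
                optional_params["stop_sequences"] = [value] if isinstance(value, str) else value
            if param == "response_format" and value.get("type") == "json_object":
                optional_params["response_mime_type"] = "application/json"
        return optional_params

    assert sketch_map_openai_params({"max_tokens": 256, "stop": "END"}, {}) == {
        "max_output_tokens": 256,
        "stop_sequences": ["END"],
    }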
litellm/llms/vertex_ai.py
@@ -445,6 +250,14 @@ def completion(
     logger_fn=None,
     acompletion: bool = False,
 ):
+    """
+    NON-GEMINI/ANTHROPIC CALLS.
+
+    This is the handler for OLDER PALM MODELS and VERTEX AI MODEL GARDEN
+
+    For Vertex AI Anthropic: `vertex_anthropic.py`
+    For Gemini: `vertex_httpx.py`
+    """
     try:
         import vertexai
     except:
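The new docstring makes the handler split explicit. In rough pseudologic (illustrative only, not the literal dispatch in `main.py`; the "claude" check is an assumption about how Anthropic models are named):

    def route_vertex_model(model: str) -> str:
        # Illustration of the handler split described in the docstring.
        if "claude" in model:  # assumption: Vertex Anthropic models carry "claude" in the name
            return "llms/vertex_ai_anthropic.py"
        if "gemini" in model:  # the actual check added to main.py in this commit
            return "llms/vertex_httpx.py"
        return "llms/vertex_ai.py"  # older PaLM models and Vertex AI Model Garden

    assert route_vertex_model("gemini-pro") == "llms/vertex_httpx.py"
    assert route_vertex_model("chat-bison") == "llms/vertex_ai.py"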
litellm/llms/vertex_httpx.py
@@ -50,6 +50,111 @@ from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 from .base import BaseLLM
 
 
+class VertexAIConfig:
+    """
+    Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
+
+    The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
+
+    - `temperature` (float): This controls the degree of randomness in token selection.
+
+    - `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
+
+    - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
+
+    - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
+
+    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+
+    - `candidate_count` (int): Number of generated responses to return.
+
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
+    - `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
+
+    - `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    temperature: Optional[float] = None
+    max_output_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    response_mime_type: Optional[str] = None
+    candidate_count: Optional[int] = None
+    stop_sequences: Optional[list] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+
+    def __init__(
+        self,
+        temperature: Optional[float] = None,
+        max_output_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        response_mime_type: Optional[str] = None,
+        candidate_count: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+        """
+        return [
+            "europe-central2",
+            "europe-north1",
+            "europe-southwest1",
+            "europe-west1",
+            "europe-west2",
+            "europe-west3",
+            "europe-west4",
+            "europe-west6",
+            "europe-west8",
+            "europe-west9",
+        ]
+
+
 class GoogleAIStudioGeminiConfig:  # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
     """
     Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
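Note that the class moves over without `get_supported_openai_params` and `map_openai_params`; those duties shift to `VertexGeminiConfig` (see the `utils.py` hunks below). What it keeps is the cross-provider auth normalization. A small usage sketch of that mapping, with placeholder project and region values:

    from litellm.llms.vertex_httpx import VertexAIConfig  # new home after this commit

    params = VertexAIConfig().map_special_auth_params(
        non_default_params={"project": "my-gcp-project", "region_name": "europe-west1"},
        optional_params={},
    )
    assert params == {"vertex_project": "my-gcp-project", "vertex_location": "europe-west1"}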
litellm/llms/vertex_httpx.py
@@ -326,6 +431,7 @@ class VertexGeminiConfig:
             "stop",
             "frequency_penalty",
             "presence_penalty",
+            "extra_headers",
         ]
 
     def map_tool_choice_values(
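With `extra_headers` in the supported list, custom HTTP headers can be forwarded on Gemini calls rather than being rejected as unsupported. A hedged sketch (header and project values are placeholders, and a real call needs configured Vertex credentials):

    import litellm

    response = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro",  # any model routed through vertex_httpx
        messages=[{"role": "user", "content": "hello"}],
        extra_headers={"x-goog-user-project": "my-gcp-project"},
    )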
litellm/llms/vertex_httpx.py
@@ -691,7 +797,9 @@ class VertexLLM(BaseLLM):
                     )
                     tools.append(_tool_response_chunk)
 
-            chat_completion_message["content"] = content_str
+            chat_completion_message["content"] = (
+                content_str if len(content_str) > 0 else None
+            )
             chat_completion_message["tool_calls"] = tools
 
             choice = litellm.Choices(
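Returning `None` instead of an empty string when the model produces only tool calls mirrors the OpenAI chat-completions format, where `content` is `null` on pure tool-call messages; the test change further down asserts exactly this.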
litellm/main.py
@@ -2080,6 +2080,28 @@ def completion(
                     headers=headers,
                     custom_prompt_dict=custom_prompt_dict,
                 )
+            elif "gemini" in model:
+                model_response = vertex_chat_completion.completion(  # type: ignore
+                    model=model,
+                    messages=messages,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=new_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    vertex_location=vertex_ai_location,
+                    vertex_project=vertex_ai_project,
+                    vertex_credentials=vertex_credentials,
+                    gemini_api_key=None,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    timeout=timeout,
+                    custom_llm_provider=custom_llm_provider,
+                    client=client,
+                    api_base=api_base,
+                    extra_headers=extra_headers,
+                )
             else:
                 model_response = vertex_ai.completion(
                     model=model,
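This is the branch the commit title refers to: any Vertex model whose name contains "gemini" now dispatches to the httpx-based `VertexLLM` handler instead of the SDK path. A minimal sketch of a call that takes the new route, assuming application-default GCP credentials and placeholder project/location values:

    import litellm

    response = litellm.completion(
        model="vertex_ai/gemini-pro",  # "gemini" in the name -> vertex_httpx handler
        messages=[{"role": "user", "content": "Say hi in one word."}],
        vertex_project="my-gcp-project",
        vertex_location="us-central1",
    )
    print(response.choices[0].message.content)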
litellm/main.py
@@ -2099,8 +2121,8 @@ def completion(
 
             if (
                 "stream" in optional_params
-                and optional_params["stream"] == True
-                and acompletion == False
+                and optional_params["stream"] is True
+                and acompletion is False
             ):
                 response = CustomStreamWrapper(
                     model_response,
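Beyond style, `is True` / `is False` tightens the check: truthy non-bool values such as `1` compare equal to `True` but fail the identity test, so only an explicit boolean now selects the synchronous streaming wrapper.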
litellm/tests/test_amazing_vertex_completion.py
@@ -501,7 +501,7 @@ async def test_async_vertexai_streaming_response():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
     response = await acompletion(
-        model="gemini-pro",
+        model=model,
         messages=messages,
         temperature=0.7,
         timeout=5,
litellm/tests/test_amazing_vertex_completion.py
@@ -1311,6 +1311,7 @@ async def test_gemini_pro_async_function_calling():
         model="gemini-pro", messages=messages, tools=tools, tool_choice="auto"
     )
     print(f"completion: {completion}")
+    print(f"message content: {completion.choices[0].message.content}")
     assert completion.choices[0].message.content is None
     assert len(completion.choices[0].message.tool_calls) == 1
 
litellm/utils.py
@@ -2824,7 +2824,6 @@ def get_optional_params(
         or model in litellm.vertex_text_models
         or model in litellm.vertex_code_text_models
         or model in litellm.vertex_language_models
-        or model in litellm.vertex_embedding_models
         or model in litellm.vertex_vision_models
     ):
         print_verbose(f"(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK")
litellm/utils.py
@@ -2834,9 +2833,15 @@ def get_optional_params(
         )
         _check_valid_arg(supported_params=supported_params)
 
-        optional_params = litellm.VertexAIConfig().map_openai_params(
+        optional_params = litellm.VertexGeminiConfig().map_openai_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
 
         print_verbose(
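Unlike the removed `VertexAIConfig.map_openai_params`, the Gemini config's version takes `model` and `drop_params`, so callers can opt into silently discarding OpenAI params Vertex does not support instead of hitting an unsupported-params error. A hedged sketch (`logit_bias` is just an example of a param Vertex Gemini does not accept):

    import litellm

    response = litellm.completion(
        model="vertex_ai/gemini-pro",
        messages=[{"role": "user", "content": "hi"}],
        logit_bias={1234: 5},  # not supported by Vertex Gemini
        drop_params=True,      # drop it instead of raising an error
    )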
litellm/utils.py
@@ -2852,7 +2857,7 @@ def get_optional_params(
             optional_params=optional_params,
             model=model,
         )
-    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
+    elif custom_llm_provider == "vertex_ai_beta":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
litellm/utils.py
@@ -3936,12 +3941,12 @@ def get_supported_openai_params(
         return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
     elif custom_llm_provider == "vertex_ai":
         if request_type == "chat_completion":
-            return litellm.VertexAIConfig().get_supported_openai_params()
+            return litellm.VertexGeminiConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "vertex_ai_beta":
         if request_type == "chat_completion":
-            return litellm.VertexAIConfig().get_supported_openai_params()
+            return litellm.VertexGeminiConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "sagemaker":
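After this change, chat completions for both `vertex_ai` and `vertex_ai_beta` report the Gemini config's parameter surface. A quick check, assuming the helper is imported from `litellm.utils` as in the hunk above:

    from litellm.utils import get_supported_openai_params

    params = get_supported_openai_params(
        model="gemini-pro", custom_llm_provider="vertex_ai"
    )
    assert "extra_headers" in params  # newly supported via VertexGeminiConfig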