mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(vertex_ai_and_google_ai_studio): Support Google AI Studio Embeddings endpoint
Closes https://github.com/BerriAI/litellm/issues/5385
This commit is contained in:
parent
bd3057e495
commit
d29a7087f1
5 changed files with 110 additions and 40 deletions
|
@@ -1,4 +1,4 @@
|
||||||
from typing import Literal
|
from typing import Literal, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -37,3 +37,62 @@ def get_supports_system_message(
|
||||||
supports_system_message = False
|
supports_system_message = False
|
||||||
|
|
||||||
return supports_system_message
|
return supports_system_message
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
all_gemini_url_modes = Literal["chat", "embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_vertex_url(
|
||||||
|
mode: all_gemini_url_modes,
|
||||||
|
model: str,
|
||||||
|
stream: Optional[bool],
|
||||||
|
vertex_project: Optional[str],
|
||||||
|
vertex_location: Optional[str],
|
||||||
|
vertex_api_version: Literal["v1", "v1beta1"],
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
if mode == "chat":
|
||||||
|
### SET RUNTIME ENDPOINT ###
|
||||||
|
endpoint = "generateContent"
|
||||||
|
if stream is True:
|
||||||
|
endpoint = "streamGenerateContent"
|
||||||
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
|
||||||
|
else:
|
||||||
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||||
|
|
||||||
|
# if model is only numeric chars then it's a fine tuned gemini model
|
||||||
|
# model = 4965075652664360960
|
||||||
|
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||||
|
if model.isdigit():
|
||||||
|
# It's a fine-tuned Gemini model
|
||||||
|
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||||
|
if stream is True:
|
||||||
|
url += "?alt=sse"
|
||||||
|
|
||||||
|
return url, endpoint
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gemini_url(
|
||||||
|
mode: all_gemini_url_modes,
|
||||||
|
model: str,
|
||||||
|
stream: Optional[bool],
|
||||||
|
gemini_api_key: Optional[str],
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
if mode == "chat":
|
||||||
|
_gemini_model_name = "models/{}".format(model)
|
||||||
|
endpoint = "generateContent"
|
||||||
|
if stream is True:
|
||||||
|
endpoint = "streamGenerateContent"
|
||||||
|
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
|
||||||
|
_gemini_model_name, endpoint, gemini_api_key
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
url = (
|
||||||
|
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
|
||||||
|
_gemini_model_name, endpoint, gemini_api_key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif mode == "embedding":
|
||||||
|
pass
|
||||||
|
return url, endpoint
|
||||||
|
|
|
@ -54,10 +54,16 @@ from litellm.types.llms.vertex_ai import (
|
||||||
from litellm.types.utils import GenericStreamingChunk
|
from litellm.types.utils import GenericStreamingChunk
|
||||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||||
|
|
||||||
from ..base import BaseLLM
|
from ...base import BaseLLM
|
||||||
from .common_utils import VertexAIError, get_supports_system_message
|
from ..common_utils import (
|
||||||
from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
|
VertexAIError,
|
||||||
from .gemini_transformation import transform_system_message
|
_get_gemini_url,
|
||||||
|
_get_vertex_url,
|
||||||
|
all_gemini_url_modes,
|
||||||
|
get_supports_system_message,
|
||||||
|
)
|
||||||
|
from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
|
||||||
|
from .transformation import transform_system_message
|
||||||
|
|
||||||
context_caching_endpoints = ContextCachingEndpoints()
|
context_caching_endpoints = ContextCachingEndpoints()
|
||||||
|
|
||||||
|
@ -309,6 +315,7 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
|
||||||
"n",
|
"n",
|
||||||
"stop",
|
"stop",
|
||||||
]
|
]
|
||||||
|
|
||||||
def _map_function(self, value: List[dict]) -> List[Tools]:
|
def _map_function(self, value: List[dict]) -> List[Tools]:
|
||||||
gtool_func_declarations = []
|
gtool_func_declarations = []
|
||||||
googleSearchRetrieval: Optional[dict] = None
|
googleSearchRetrieval: Optional[dict] = None
|
||||||
|
@ -1164,6 +1171,7 @@ class VertexLLM(BaseLLM):
|
||||||
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
|
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
should_use_v1beta1_features: Optional[bool] = False,
|
should_use_v1beta1_features: Optional[bool] = False,
|
||||||
|
mode: all_gemini_url_modes = "chat",
|
||||||
) -> Tuple[Optional[str], str]:
|
) -> Tuple[Optional[str], str]:
|
||||||
"""
|
"""
|
||||||
Internal function. Returns the token and url for the call.
|
Internal function. Returns the token and url for the call.
|
||||||
|
@ -1174,18 +1182,13 @@ class VertexLLM(BaseLLM):
|
||||||
token, url
|
token, url
|
||||||
"""
|
"""
|
||||||
if custom_llm_provider == "gemini":
|
if custom_llm_provider == "gemini":
|
||||||
_gemini_model_name = "models/{}".format(model)
|
|
||||||
auth_header = None
|
auth_header = None
|
||||||
endpoint = "generateContent"
|
url, endpoint = _get_gemini_url(
|
||||||
if stream is True:
|
mode=mode,
|
||||||
endpoint = "streamGenerateContent"
|
model=model,
|
||||||
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
|
stream=stream,
|
||||||
_gemini_model_name, endpoint, gemini_api_key
|
gemini_api_key=gemini_api_key,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
|
|
||||||
_gemini_model_name, endpoint, gemini_api_key
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
auth_header, vertex_project = self._ensure_access_token(
|
auth_header, vertex_project = self._ensure_access_token(
|
||||||
credentials=vertex_credentials, project_id=vertex_project
|
credentials=vertex_credentials, project_id=vertex_project
|
||||||
|
@ -1193,23 +1196,17 @@ class VertexLLM(BaseLLM):
|
||||||
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
|
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
|
||||||
|
|
||||||
### SET RUNTIME ENDPOINT ###
|
### SET RUNTIME ENDPOINT ###
|
||||||
version = "v1beta1" if should_use_v1beta1_features is True else "v1"
|
version: Literal["v1beta1", "v1"] = (
|
||||||
endpoint = "generateContent"
|
"v1beta1" if should_use_v1beta1_features is True else "v1"
|
||||||
litellm.utils.print_verbose("vertex_project - {}".format(vertex_project))
|
)
|
||||||
if stream is True:
|
url, endpoint = _get_vertex_url(
|
||||||
endpoint = "streamGenerateContent"
|
mode=mode,
|
||||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
|
model=model,
|
||||||
else:
|
stream=stream,
|
||||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
vertex_project=vertex_project,
|
||||||
|
vertex_location=vertex_location,
|
||||||
# if model is only numeric chars then it's a fine tuned gemini model
|
vertex_api_version=version,
|
||||||
# model = 4965075652664360960
|
)
|
||||||
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
|
||||||
if model.isdigit():
|
|
||||||
# It's a fine-tuned Gemini model
|
|
||||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
|
||||||
if stream is True:
|
|
||||||
url += "?alt=sse"
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
api_base is not None
|
api_base is not None
|
||||||
|
@ -1793,8 +1790,10 @@ class VertexLLM(BaseLLM):
|
||||||
input: Union[list, str],
|
input: Union[list, str],
|
||||||
print_verbose,
|
print_verbose,
|
||||||
model_response: litellm.EmbeddingResponse,
|
model_response: litellm.EmbeddingResponse,
|
||||||
|
custom_llm_provider: Literal["gemini", "vertex_ai"],
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
logging_obj=None,
|
logging_obj=None,
|
||||||
encoding=None,
|
encoding=None,
|
||||||
vertex_project=None,
|
vertex_project=None,
|
||||||
|
@ -1804,6 +1803,17 @@ class VertexLLM(BaseLLM):
|
||||||
timeout=300,
|
timeout=300,
|
||||||
client=None,
|
client=None,
|
||||||
):
|
):
|
||||||
|
auth_header, url = self._get_token_and_url(
|
||||||
|
model=model,
|
||||||
|
gemini_api_key=api_key,
|
||||||
|
vertex_project=vertex_project,
|
||||||
|
vertex_location=vertex_location,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
|
stream=None,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
api_base=api_base,
|
||||||
|
should_use_v1beta1_features=False,
|
||||||
|
)
|
||||||
|
|
||||||
if client is None:
|
if client is None:
|
||||||
_params = {}
|
_params = {}
|
|
@ -126,12 +126,12 @@ from .llms.vertex_ai_and_google_ai_studio import (
|
||||||
vertex_ai_anthropic,
|
vertex_ai_anthropic,
|
||||||
vertex_ai_non_gemini,
|
vertex_ai_non_gemini,
|
||||||
)
|
)
|
||||||
|
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
|
VertexLLM,
|
||||||
|
)
|
||||||
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
|
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
|
||||||
VertexAIPartnerModels,
|
VertexAIPartnerModels,
|
||||||
)
|
)
|
||||||
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
|
|
||||||
VertexLLM,
|
|
||||||
)
|
|
||||||
from .llms.watsonx import IBMWatsonXAI
|
from .llms.watsonx import IBMWatsonXAI
|
||||||
from .types.llms.openai import HttpxBinaryResponseContent
|
from .types.llms.openai import HttpxBinaryResponseContent
|
||||||
from .types.utils import (
|
from .types.utils import (
|
||||||
|
@ -3568,6 +3568,7 @@ def embedding(
|
||||||
vertex_credentials=vertex_credentials,
|
vertex_credentials=vertex_credentials,
|
||||||
aembedding=aembedding,
|
aembedding=aembedding,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
|
custom_llm_provider="vertex_ai",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = vertex_ai_non_gemini.embedding(
|
response = vertex_ai_non_gemini.embedding(
|
||||||
|
|
|
@ -28,7 +28,7 @@ from litellm import (
|
||||||
completion_cost,
|
completion_cost,
|
||||||
embedding,
|
embedding,
|
||||||
)
|
)
|
||||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
|
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
_gemini_convert_messages_with_history,
|
_gemini_convert_messages_with_history,
|
||||||
)
|
)
|
||||||
from litellm.tests.test_streaming import streaming_format_tests
|
from litellm.tests.test_streaming import streaming_format_tests
|
||||||
|
@ -2065,7 +2065,7 @@ def test_prompt_factory_nested():
|
||||||
|
|
||||||
|
|
||||||
def test_get_token_url():
|
def test_get_token_url():
|
||||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
|
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
VertexLLM,
|
VertexLLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2087,7 +2087,7 @@ def test_get_token_url():
|
||||||
vertex_credentials=vertex_credentials,
|
vertex_credentials=vertex_credentials,
|
||||||
gemini_api_key="",
|
gemini_api_key="",
|
||||||
custom_llm_provider="vertex_ai_beta",
|
custom_llm_provider="vertex_ai_beta",
|
||||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
|
||||||
api_base=None,
|
api_base=None,
|
||||||
model="",
|
model="",
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -2107,7 +2107,7 @@ def test_get_token_url():
|
||||||
vertex_credentials=vertex_credentials,
|
vertex_credentials=vertex_credentials,
|
||||||
gemini_api_key="",
|
gemini_api_key="",
|
||||||
custom_llm_provider="vertex_ai_beta",
|
custom_llm_provider="vertex_ai_beta",
|
||||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
|
||||||
api_base=None,
|
api_base=None,
|
||||||
model="",
|
model="",
|
||||||
stream=False,
|
stream=False,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue