feat(vertex_ai_and_google_ai_studio): Support Google AI Studio Embeddings endpoint

Closes https://github.com/BerriAI/litellm/issues/5385
Krrish Dholakia 2024-08-27 16:53:11 -07:00
parent bd3057e495
commit d29a7087f1
5 changed files with 110 additions and 40 deletions

@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import Literal, Tuple
 
 import httpx
 
@@ -37,3 +37,62 @@ def get_supports_system_message(
         supports_system_message = False
 
     return supports_system_message
+
+
+from typing import Literal, Optional
+
+all_gemini_url_modes = Literal["chat", "embedding"]
+
+
+def _get_vertex_url(
+    mode: all_gemini_url_modes,
+    model: str,
+    stream: Optional[bool],
+    vertex_project: Optional[str],
+    vertex_location: Optional[str],
+    vertex_api_version: Literal["v1", "v1beta1"],
+) -> Tuple[str, str]:
+    if mode == "chat":
+        ### SET RUNTIME ENDPOINT ###
+        endpoint = "generateContent"
+        if stream is True:
+            endpoint = "streamGenerateContent"
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
+        else:
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+
+        # if model is only numeric chars then it's a fine tuned gemini model
+        # model = 4965075652664360960
+        # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+        if model.isdigit():
+            # It's a fine-tuned Gemini model
+            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+            if stream is True:
+                url += "?alt=sse"
+
+    return url, endpoint
+
+
+def _get_gemini_url(
+    mode: all_gemini_url_modes,
+    model: str,
+    stream: Optional[bool],
+    gemini_api_key: Optional[str],
+) -> Tuple[str, str]:
+    if mode == "chat":
+        _gemini_model_name = "models/{}".format(model)
+        endpoint = "generateContent"
+        if stream is True:
+            endpoint = "streamGenerateContent"
+            url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
+                _gemini_model_name, endpoint, gemini_api_key
+            )
+        else:
+            url = (
+                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+                    _gemini_model_name, endpoint, gemini_api_key
+                )
+            )
+    elif mode == "embedding":
+        pass
+
+    return url, endpoint

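For orientation, here is a minimal usage sketch of the two new helpers in chat mode. The import path is an assumption based on the relative imports added elsewhere in this commit, and every value marked "placeholder" is illustrative rather than taken from the diff. Note that mode="embedding" is still a stub in _get_gemini_url at this point in the history, so only chat-mode calls produce a URL.

from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (  # assumed path
    _get_gemini_url,
    _get_vertex_url,
)

# Vertex AI: region-scoped aiplatform.googleapis.com URL.
vertex_url, endpoint = _get_vertex_url(
    mode="chat",
    model="gemini-1.5-pro",         # placeholder
    stream=False,
    vertex_project="my-project",    # placeholder
    vertex_location="us-central1",  # placeholder
    vertex_api_version="v1",
)
# vertex_url ends in ...publishers/google/models/gemini-1.5-pro:generateContent

# Google AI Studio: generativelanguage.googleapis.com URL, API key in the query string.
gemini_url, endpoint = _get_gemini_url(
    mode="chat",
    model="gemini-1.5-flash",       # placeholder
    stream=True,
    gemini_api_key="AIza...",       # placeholder
)
# gemini_url ends in ...models/gemini-1.5-flash:streamGenerateContent?key=AIza...&alt=sse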
@@ -54,10 +54,16 @@ from litellm.types.llms.vertex_ai import (
 from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 
-from ..base import BaseLLM
-from .common_utils import VertexAIError, get_supports_system_message
-from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
-from .gemini_transformation import transform_system_message
+from ...base import BaseLLM
+from ..common_utils import (
+    VertexAIError,
+    _get_gemini_url,
+    _get_vertex_url,
+    all_gemini_url_modes,
+    get_supports_system_message,
+)
+from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+from .transformation import transform_system_message
 
 context_caching_endpoints = ContextCachingEndpoints()
@@ -309,6 +315,7 @@ class GoogleAIStudioGeminiConfig:  # key diff from VertexAI - 'frequency_penalty
             "n",
             "stop",
         ]
+
     def _map_function(self, value: List[dict]) -> List[Tools]:
         gtool_func_declarations = []
         googleSearchRetrieval: Optional[dict] = None
@@ -1164,6 +1171,7 @@ class VertexLLM(BaseLLM):
         custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
         api_base: Optional[str],
         should_use_v1beta1_features: Optional[bool] = False,
+        mode: all_gemini_url_modes = "chat",
     ) -> Tuple[Optional[str], str]:
         """
         Internal function. Returns the token and url for the call.
@@ -1174,18 +1182,13 @@ class VertexLLM(BaseLLM):
             token, url
         """
         if custom_llm_provider == "gemini":
-            _gemini_model_name = "models/{}".format(model)
             auth_header = None
-            endpoint = "generateContent"
-            if stream is True:
-                endpoint = "streamGenerateContent"
-                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
-                    _gemini_model_name, endpoint, gemini_api_key
-                )
-            else:
-                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
-                    _gemini_model_name, endpoint, gemini_api_key
-                )
+            url, endpoint = _get_gemini_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                gemini_api_key=gemini_api_key,
+            )
         else:
             auth_header, vertex_project = self._ensure_access_token(
                 credentials=vertex_credentials, project_id=vertex_project
@@ -1193,23 +1196,17 @@ class VertexLLM(BaseLLM):
             vertex_location = self.get_vertex_region(vertex_region=vertex_location)
 
             ### SET RUNTIME ENDPOINT ###
-            version = "v1beta1" if should_use_v1beta1_features is True else "v1"
-            endpoint = "generateContent"
-            litellm.utils.print_verbose("vertex_project - {}".format(vertex_project))
-            if stream is True:
-                endpoint = "streamGenerateContent"
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
-            else:
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
-
-            # if model is only numeric chars then it's a fine tuned gemini model
-            # model = 4965075652664360960
-            # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
-            if model.isdigit():
-                # It's a fine-tuned Gemini model
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
-                if stream is True:
-                    url += "?alt=sse"
+            version: Literal["v1beta1", "v1"] = (
+                "v1beta1" if should_use_v1beta1_features is True else "v1"
+            )
+            url, endpoint = _get_vertex_url(
+                mode=mode,
+                model=model,
+                stream=stream,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_api_version=version,
+            )
 
         if (
             api_base is not None
@@ -1793,8 +1790,10 @@ class VertexLLM(BaseLLM):
         input: Union[list, str],
         print_verbose,
         model_response: litellm.EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
         optional_params: dict,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         logging_obj=None,
         encoding=None,
         vertex_project=None,
@@ -1804,6 +1803,17 @@ class VertexLLM(BaseLLM):
         timeout=300,
         client=None,
     ):
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+        )
         if client is None:
             _params = {}

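The new mode parameter defaults to "chat", so existing completion and streaming callers keep their old URLs. A hedged sketch of calling the refactored method directly, mirroring the call shape used in the test changes further down; argument values are placeholders, and the full signature includes parameters not visible in this hunk:

from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)

vertex_llm = VertexLLM()
# For Google AI Studio the auth header is None; the API key travels in the
# query string of the returned URL.
auth_header, url = vertex_llm._get_token_and_url(
    model="gemini-1.5-flash",   # placeholder
    gemini_api_key="AIza...",   # placeholder
    vertex_project=None,
    vertex_location=None,
    vertex_credentials=None,
    stream=False,
    custom_llm_provider="gemini",
    api_base=None,
    should_use_v1beta1_features=False,
    mode="chat",                # new in this commit; "embedding" is the other accepted mode
)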
@@ -126,12 +126,12 @@ from .llms.vertex_ai_and_google_ai_studio import (
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
+from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+    VertexLLM,
+)
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
-    VertexLLM,
-)
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import (
@@ -3568,6 +3568,7 @@ def embedding(
                 vertex_credentials=vertex_credentials,
                 aembedding=aembedding,
                 print_verbose=print_verbose,
+                custom_llm_provider="vertex_ai",
             )
         else:
             response = vertex_ai_non_gemini.embedding(

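User-facing, this is the existing litellm.embedding() surface; the change above only makes the provider explicit when routing Vertex AI embedding requests into VertexLLM.embedding(). A hedged end-to-end sketch, where model, project, and location are placeholders:

import litellm

response = litellm.embedding(
    model="vertex_ai/textembedding-gecko",  # placeholder model name
    input=["hello world"],
    vertex_project="my-project",            # placeholder
    vertex_location="us-central1",          # placeholder
)
print(len(response.data[0]["embedding"]))   # dimensionality of the returned vector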
@@ -28,7 +28,7 @@ from litellm import (
     completion_cost,
     embedding,
 )
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     _gemini_convert_messages_with_history,
 )
 from litellm.tests.test_streaming import streaming_format_tests
@@ -2065,7 +2065,7 @@ def test_prompt_factory_nested():
 def test_get_token_url():
-    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
+    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
         VertexLLM,
     )
@@ -2087,7 +2087,7 @@ def test_get_token_url():
         vertex_credentials=vertex_credentials,
         gemini_api_key="",
         custom_llm_provider="vertex_ai_beta",
-        should_use_v1beta1_features=should_use_v1beta1_features,
+        should_use_vertex_v1beta1_features=should_use_v1beta1_features,
         api_base=None,
         model="",
         stream=False,
@@ -2107,7 +2107,7 @@ def test_get_token_url():
         vertex_credentials=vertex_credentials,
         gemini_api_key="",
         custom_llm_provider="vertex_ai_beta",
-        should_use_v1beta1_features=should_use_v1beta1_features,
+        should_use_vertex_v1beta1_features=should_use_v1beta1_features,
         api_base=None,
         model="",
         stream=False,
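Conceptually the updated test still asserts the same behavior under the renamed keyword: enabling v1beta1 features must show up in the resolved URL. A hedged sketch of that intent follows; values are placeholders, and the keyword rename presumably pairs with a signature change in one of the five changed files not shown in full here.

auth_header, url = vertex_llm._get_token_and_url(
    vertex_project="my-project",                # placeholder
    vertex_location="us-central1",              # placeholder
    vertex_credentials="path/to/credentials.json",  # placeholder
    gemini_api_key="",
    custom_llm_provider="vertex_ai_beta",
    should_use_vertex_v1beta1_features=True,    # renamed keyword
    api_base=None,
    model="gemini-1.5-pro",                     # placeholder
    stream=False,
)
assert "v1beta1" in url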