feat(vertex_ai_and_google_ai_studio): Support Google AI Studio Embeddings endpoint

Closes https://github.com/BerriAI/litellm/issues/5385
This commit is contained in:
Krrish Dholakia 2024-08-27 16:53:11 -07:00
parent bd3057e495
commit d29a7087f1
5 changed files with 110 additions and 40 deletions

View file

@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Tuple
import httpx
@ -37,3 +37,62 @@ def get_supports_system_message(
supports_system_message = False
return supports_system_message
from typing import Literal, Optional
all_gemini_url_modes = Literal["chat", "embedding"]
def _get_vertex_url(
mode: all_gemini_url_modes,
model: str,
stream: Optional[bool],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]:
if mode == "chat":
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
# if model is only numeric chars then it's a fine tuned gemini model
# model = 4965075652664360960
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if model.isdigit():
# It's a fine-tuned Gemini model
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if stream is True:
url += "?alt=sse"
return url, endpoint
def _get_gemini_url(
mode: all_gemini_url_modes,
model: str,
stream: Optional[bool],
gemini_api_key: Optional[str],
) -> Tuple[str, str]:
if mode == "chat":
_gemini_model_name = "models/{}".format(model)
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
url = (
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
)
elif mode == "embedding":
pass
return url, endpoint

View file

@ -54,10 +54,16 @@ from litellm.types.llms.vertex_ai import (
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ..base import BaseLLM
from .common_utils import VertexAIError, get_supports_system_message
from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
from .gemini_transformation import transform_system_message
from ...base import BaseLLM
from ..common_utils import (
VertexAIError,
_get_gemini_url,
_get_vertex_url,
all_gemini_url_modes,
get_supports_system_message,
)
from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
from .transformation import transform_system_message
context_caching_endpoints = ContextCachingEndpoints()
@ -309,6 +315,7 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
"n",
"stop",
]
def _map_function(self, value: List[dict]) -> List[Tools]:
gtool_func_declarations = []
googleSearchRetrieval: Optional[dict] = None
@ -1164,6 +1171,7 @@ class VertexLLM(BaseLLM):
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
api_base: Optional[str],
should_use_v1beta1_features: Optional[bool] = False,
mode: all_gemini_url_modes = "chat",
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
@ -1174,18 +1182,13 @@ class VertexLLM(BaseLLM):
token, url
"""
if custom_llm_provider == "gemini":
_gemini_model_name = "models/{}".format(model)
auth_header = None
endpoint = "generateContent"
if stream is True:
endpoint = "streamGenerateContent"
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
_gemini_model_name, endpoint, gemini_api_key
)
else:
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
url, endpoint = _get_gemini_url(
mode=mode,
model=model,
stream=stream,
gemini_api_key=gemini_api_key,
)
else:
auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
@ -1193,23 +1196,17 @@ class VertexLLM(BaseLLM):
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
### SET RUNTIME ENDPOINT ###
version = "v1beta1" if should_use_v1beta1_features is True else "v1"
endpoint = "generateContent"
litellm.utils.print_verbose("vertex_project - {}".format(vertex_project))
if stream is True:
endpoint = "streamGenerateContent"
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
else:
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
# if model is only numeric chars then it's a fine tuned gemini model
# model = 4965075652664360960
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if model.isdigit():
# It's a fine-tuned Gemini model
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
if stream is True:
url += "?alt=sse"
version: Literal["v1beta1", "v1"] = (
"v1beta1" if should_use_v1beta1_features is True else "v1"
)
url, endpoint = _get_vertex_url(
mode=mode,
model=model,
stream=stream,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_api_version=version,
)
if (
api_base is not None
@ -1793,8 +1790,10 @@ class VertexLLM(BaseLLM):
input: Union[list, str],
print_verbose,
model_response: litellm.EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
@ -1804,6 +1803,17 @@ class VertexLLM(BaseLLM):
timeout=300,
client=None,
):
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=None,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=False,
)
if client is None:
_params = {}

View file

@ -126,12 +126,12 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexAIPartnerModels,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import (
@ -3568,6 +3568,7 @@ def embedding(
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="vertex_ai",
)
else:
response = vertex_ai_non_gemini.embedding(

View file

@ -28,7 +28,7 @@ from litellm import (
completion_cost,
embedding,
)
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
_gemini_convert_messages_with_history,
)
from litellm.tests.test_streaming import streaming_format_tests
@ -2065,7 +2065,7 @@ def test_prompt_factory_nested():
def test_get_token_url():
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@ -2087,7 +2087,7 @@ def test_get_token_url():
vertex_credentials=vertex_credentials,
gemini_api_key="",
custom_llm_provider="vertex_ai_beta",
should_use_v1beta1_features=should_use_v1beta1_features,
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
api_base=None,
model="",
stream=False,
@ -2107,7 +2107,7 @@ def test_get_token_url():
vertex_credentials=vertex_credentials,
gemini_api_key="",
custom_llm_provider="vertex_ai_beta",
should_use_v1beta1_features=should_use_v1beta1_features,
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
api_base=None,
model="",
stream=False,