mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(vertex_ai_and_google_ai_studio): Support Google AI Studio Embeddings endpoint
Closes https://github.com/BerriAI/litellm/issues/5385
This commit is contained in:
parent
bd3057e495
commit
d29a7087f1
5 changed files with 110 additions and 40 deletions
|
@ -1,4 +1,4 @@
|
|||
from typing import Literal
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -37,3 +37,62 @@ def get_supports_system_message(
|
|||
supports_system_message = False
|
||||
|
||||
return supports_system_message
|
||||
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
all_gemini_url_modes = Literal["chat", "embedding"]
|
||||
|
||||
|
||||
def _get_vertex_url(
|
||||
mode: all_gemini_url_modes,
|
||||
model: str,
|
||||
stream: Optional[bool],
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_api_version: Literal["v1", "v1beta1"],
|
||||
) -> Tuple[str, str]:
|
||||
if mode == "chat":
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint = "generateContent"
|
||||
if stream is True:
|
||||
endpoint = "streamGenerateContent"
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
|
||||
else:
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||
|
||||
# if model is only numeric chars then it's a fine tuned gemini model
|
||||
# model = 4965075652664360960
|
||||
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
if model.isdigit():
|
||||
# It's a fine-tuned Gemini model
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
if stream is True:
|
||||
url += "?alt=sse"
|
||||
|
||||
return url, endpoint
|
||||
|
||||
|
||||
def _get_gemini_url(
|
||||
mode: all_gemini_url_modes,
|
||||
model: str,
|
||||
stream: Optional[bool],
|
||||
gemini_api_key: Optional[str],
|
||||
) -> Tuple[str, str]:
|
||||
if mode == "chat":
|
||||
_gemini_model_name = "models/{}".format(model)
|
||||
endpoint = "generateContent"
|
||||
if stream is True:
|
||||
endpoint = "streamGenerateContent"
|
||||
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
|
||||
_gemini_model_name, endpoint, gemini_api_key
|
||||
)
|
||||
else:
|
||||
url = (
|
||||
"https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
|
||||
_gemini_model_name, endpoint, gemini_api_key
|
||||
)
|
||||
)
|
||||
elif mode == "embedding":
|
||||
pass
|
||||
return url, endpoint
|
||||
|
|
|
@ -54,10 +54,16 @@ from litellm.types.llms.vertex_ai import (
|
|||
from litellm.types.utils import GenericStreamingChunk
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||
|
||||
from ..base import BaseLLM
|
||||
from .common_utils import VertexAIError, get_supports_system_message
|
||||
from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
|
||||
from .gemini_transformation import transform_system_message
|
||||
from ...base import BaseLLM
|
||||
from ..common_utils import (
|
||||
VertexAIError,
|
||||
_get_gemini_url,
|
||||
_get_vertex_url,
|
||||
all_gemini_url_modes,
|
||||
get_supports_system_message,
|
||||
)
|
||||
from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
|
||||
from .transformation import transform_system_message
|
||||
|
||||
context_caching_endpoints = ContextCachingEndpoints()
|
||||
|
||||
|
@ -309,6 +315,7 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
|
|||
"n",
|
||||
"stop",
|
||||
]
|
||||
|
||||
def _map_function(self, value: List[dict]) -> List[Tools]:
|
||||
gtool_func_declarations = []
|
||||
googleSearchRetrieval: Optional[dict] = None
|
||||
|
@ -1164,6 +1171,7 @@ class VertexLLM(BaseLLM):
|
|||
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
|
||||
api_base: Optional[str],
|
||||
should_use_v1beta1_features: Optional[bool] = False,
|
||||
mode: all_gemini_url_modes = "chat",
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Internal function. Returns the token and url for the call.
|
||||
|
@ -1174,18 +1182,13 @@ class VertexLLM(BaseLLM):
|
|||
token, url
|
||||
"""
|
||||
if custom_llm_provider == "gemini":
|
||||
_gemini_model_name = "models/{}".format(model)
|
||||
auth_header = None
|
||||
endpoint = "generateContent"
|
||||
if stream is True:
|
||||
endpoint = "streamGenerateContent"
|
||||
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
|
||||
_gemini_model_name, endpoint, gemini_api_key
|
||||
)
|
||||
else:
|
||||
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
|
||||
_gemini_model_name, endpoint, gemini_api_key
|
||||
)
|
||||
url, endpoint = _get_gemini_url(
|
||||
mode=mode,
|
||||
model=model,
|
||||
stream=stream,
|
||||
gemini_api_key=gemini_api_key,
|
||||
)
|
||||
else:
|
||||
auth_header, vertex_project = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
|
@ -1193,23 +1196,17 @@ class VertexLLM(BaseLLM):
|
|||
vertex_location = self.get_vertex_region(vertex_region=vertex_location)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
version = "v1beta1" if should_use_v1beta1_features is True else "v1"
|
||||
endpoint = "generateContent"
|
||||
litellm.utils.print_verbose("vertex_project - {}".format(vertex_project))
|
||||
if stream is True:
|
||||
endpoint = "streamGenerateContent"
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
|
||||
else:
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
|
||||
|
||||
# if model is only numeric chars then it's a fine tuned gemini model
|
||||
# model = 4965075652664360960
|
||||
# send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
if model.isdigit():
|
||||
# It's a fine-tuned Gemini model
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
|
||||
if stream is True:
|
||||
url += "?alt=sse"
|
||||
version: Literal["v1beta1", "v1"] = (
|
||||
"v1beta1" if should_use_v1beta1_features is True else "v1"
|
||||
)
|
||||
url, endpoint = _get_vertex_url(
|
||||
mode=mode,
|
||||
model=model,
|
||||
stream=stream,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_api_version=version,
|
||||
)
|
||||
|
||||
if (
|
||||
api_base is not None
|
||||
|
@ -1793,8 +1790,10 @@ class VertexLLM(BaseLLM):
|
|||
input: Union[list, str],
|
||||
print_verbose,
|
||||
model_response: litellm.EmbeddingResponse,
|
||||
custom_llm_provider: Literal["gemini", "vertex_ai"],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
logging_obj=None,
|
||||
encoding=None,
|
||||
vertex_project=None,
|
||||
|
@ -1804,6 +1803,17 @@ class VertexLLM(BaseLLM):
|
|||
timeout=300,
|
||||
client=None,
|
||||
):
|
||||
auth_header, url = self._get_token_and_url(
|
||||
model=model,
|
||||
gemini_api_key=api_key,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
stream=None,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=False,
|
||||
)
|
||||
|
||||
if client is None:
|
||||
_params = {}
|
|
@ -126,12 +126,12 @@ from .llms.vertex_ai_and_google_ai_studio import (
|
|||
vertex_ai_anthropic,
|
||||
vertex_ai_non_gemini,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
|
||||
VertexAIPartnerModels,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .types.llms.openai import HttpxBinaryResponseContent
|
||||
from .types.utils import (
|
||||
|
@ -3568,6 +3568,7 @@ def embedding(
|
|||
vertex_credentials=vertex_credentials,
|
||||
aembedding=aembedding,
|
||||
print_verbose=print_verbose,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
else:
|
||||
response = vertex_ai_non_gemini.embedding(
|
||||
|
|
|
@ -28,7 +28,7 @@ from litellm import (
|
|||
completion_cost,
|
||||
embedding,
|
||||
)
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
_gemini_convert_messages_with_history,
|
||||
)
|
||||
from litellm.tests.test_streaming import streaming_format_tests
|
||||
|
@ -2065,7 +2065,7 @@ def test_prompt_factory_nested():
|
|||
|
||||
|
||||
def test_get_token_url():
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
|
||||
|
@ -2087,7 +2087,7 @@ def test_get_token_url():
|
|||
vertex_credentials=vertex_credentials,
|
||||
gemini_api_key="",
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
||||
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
|
||||
api_base=None,
|
||||
model="",
|
||||
stream=False,
|
||||
|
@ -2107,7 +2107,7 @@ def test_get_token_url():
|
|||
vertex_credentials=vertex_credentials,
|
||||
gemini_api_key="",
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
||||
should_use_vertex_v1beta1_features=should_use_v1beta1_features,
|
||||
api_base=None,
|
||||
model="",
|
||||
stream=False,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue