Merge pull request #5393 from BerriAI/litellm_gemini_embedding_support

feat(vertex_ai_and_google_ai_studio): Support Google AI Studio Embedding Endpoint
Krish Dholakia 2024-08-28 13:46:28 -07:00 committed by GitHub
commit 996c028127
15 changed files with 481 additions and 71 deletions
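In short, this commit wires Google AI Studio embeddings into litellm.embedding / litellm.aembedding. A minimal usage sketch, mirroring the new test added in this commit (assumes GEMINI_API_KEY is set in the environment):

import litellm

# Both a single string and a list of strings are routed through the new
# Google AI Studio batchEmbedContents handler added in this commit.
response = litellm.embedding(
    model="gemini/text-embedding-004",
    input=["good morning from litellm"],
)
print(response.data[0]["embedding"][:5], response.usage.prompt_tokens)

The async path is the same call via litellm.aembedding(...); the aembedding flag routes to the handler's async method.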


@@ -848,7 +848,7 @@ from .llms.gemini import GeminiConfig
from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig
-from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexGeminiConfig,
    GoogleAIStudioGeminiConfig,
    VertexAIConfig,
@@ -865,6 +865,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transf
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import (
    VertexAIAi21Config,
)
from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig


@@ -8,7 +8,7 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparamet
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)
from litellm.types.llms.openai import FineTuningJobCreate


@@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
)
from litellm.llms.openai import HttpxBinaryResponseContent
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)


@@ -1,4 +1,4 @@
-from typing import Literal
from typing import Literal, Tuple

import httpx
@@ -37,3 +37,74 @@ def get_supports_system_message(
        supports_system_message = False

    return supports_system_message


from typing import Literal, Optional

all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]


def _get_vertex_url(
    mode: all_gemini_url_modes,
    model: str,
    stream: Optional[bool],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]:
    if mode == "chat":
        ### SET RUNTIME ENDPOINT ###
        endpoint = "generateContent"
        if stream is True:
            endpoint = "streamGenerateContent"
            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
        else:
            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

        # if model is only numeric chars then it's a fine tuned gemini model
        # model = 4965075652664360960
        # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
        if model.isdigit():
            # It's a fine-tuned Gemini model
            url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
            if stream is True:
                url += "?alt=sse"
    elif mode == "embedding":
        endpoint = "predict"
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"

    return url, endpoint


def _get_gemini_url(
    mode: all_gemini_url_modes,
    model: str,
    stream: Optional[bool],
    gemini_api_key: Optional[str],
) -> Tuple[str, str]:
    _gemini_model_name = "models/{}".format(model)
    if mode == "chat":
        endpoint = "generateContent"
        if stream is True:
            endpoint = "streamGenerateContent"
            url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
                _gemini_model_name, endpoint, gemini_api_key
            )
        else:
            url = (
                "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
                    _gemini_model_name, endpoint, gemini_api_key
                )
            )
    elif mode == "embedding":
        endpoint = "embedContent"
        url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
            _gemini_model_name, endpoint, gemini_api_key
        )
    elif mode == "batch_embedding":
        endpoint = "batchEmbedContents"
        url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
            _gemini_model_name, endpoint, gemini_api_key
        )

    return url, endpoint
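For orientation, a hedged sketch of the URLs these helpers return for the new embedding modes (the project, location, model, and key values below are illustrative placeholders, not taken from the diff):

# _get_gemini_url(mode="embedding", model="text-embedding-004", stream=None, gemini_api_key="$GEMINI_API_KEY")
# -> ("https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key=$GEMINI_API_KEY", "embedContent")
#
# _get_gemini_url(mode="batch_embedding", ...) swaps in the batchEmbedContents endpoint:
# -> ("https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key=$GEMINI_API_KEY", "batchEmbedContents")
#
# _get_vertex_url(mode="embedding", model="textembedding-gecko", stream=None,
#                 vertex_project="my-project", vertex_location="us-central1", vertex_api_version="v1")
# -> ("https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/textembedding-gecko:predict", "predict")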


@@ -11,8 +11,10 @@ from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstruc
from litellm.utils import is_cached_message

from ..common_utils import VertexAIError, get_supports_system_message
-from ..gemini_transformation import transform_system_message
-from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history
from ..gemini.transformation import transform_system_message
from ..gemini.vertex_and_google_ai_studio_gemini import (
    _gemini_convert_messages_with_history,
)


def separate_cached_messages(


@@ -0,0 +1,167 @@
"""
Google AI Studio /batchEmbedContents Embeddings Endpoint
"""

import json
from typing import List, Literal, Optional, Union

import httpx

from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
    VertexAIBatchEmbeddingsRequestBody,
    VertexAIBatchEmbeddingsResponseObject,
)

from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .batch_embed_content_transformation import (
    process_response,
    transform_openai_input_gemini_content,
)


class GoogleBatchEmbeddings(VertexLLM):
    def batch_embeddings(
        self,
        model: str,
        input: EmbeddingInput,
        print_verbose,
        model_response: EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        logging_obj=None,
        encoding=None,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        aembedding=False,
        timeout=300,
        client=None,
    ) -> EmbeddingResponse:
        auth_header, url = self._get_token_and_url(
            model=model,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="batch_embedding",
        )

        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
        else:
            sync_handler = client  # type: ignore

        optional_params = optional_params or {}

        ### TRANSFORMATION ###
        request_data = transform_openai_input_gemini_content(
            input=input, model=model, optional_params=optional_params
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
        }

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )

        if aembedding is True:
            return self.async_batch_embeddings(  # type: ignore
                model=model,
                api_base=api_base,
                url=url,
                data=request_data,
                model_response=model_response,
                timeout=timeout,
                headers=headers,
                input=input,
            )

        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore

        return process_response(
            model=model,
            model_response=model_response,
            _predictions=_predictions,
            input=input,
        )

    async def async_batch_embeddings(
        self,
        model: str,
        api_base: Optional[str],
        url: str,
        data: VertexAIBatchEmbeddingsRequestBody,
        model_response: EmbeddingResponse,
        input: EmbeddingInput,
        timeout: Optional[Union[float, httpx.Timeout]],
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> EmbeddingResponse:
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            async_handler = client  # type: ignore

        response = await async_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore

        return process_response(
            model=model,
            model_response=model_response,
            _predictions=_predictions,
            input=input,
        )
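A rough sketch of the handler's data flow under a stubbed response, using only the transformation helpers added in this commit (the embedding values are illustrative, not real model output):

from litellm import EmbeddingResponse
from litellm.llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_transformation import (
    process_response,
    transform_openai_input_gemini_content,
)

# Request body that the handler POSTs to the batchEmbedContents URL
request_body = transform_openai_input_gemini_content(
    input=["good morning from litellm"], model="text-embedding-004", optional_params={}
)
# Response shape per VertexAIBatchEmbeddingsResponseObject (values are made up)
stubbed_response = {"embeddings": [{"values": [0.011, -0.042, 0.173]}]}
result = process_response(
    input=["good morning from litellm"],
    model_response=EmbeddingResponse(),
    model="text-embedding-004",
    _predictions=stubbed_response,  # type: ignore[arg-type]
)
print(result.data[0]["embedding"], result.usage.prompt_tokens)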


@@ -0,0 +1,76 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.

Why separate file? Make it easy to see how transformation works
"""

from typing import List

from litellm import EmbeddingResponse
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
    ContentType,
    EmbedContentRequest,
    PartType,
    VertexAIBatchEmbeddingsRequestBody,
    VertexAIBatchEmbeddingsResponseObject,
)
from litellm.types.utils import Embedding, Usage
from litellm.utils import get_formatted_prompt, token_counter

from ..common_utils import VertexAIError


def transform_openai_input_gemini_content(
    input: EmbeddingInput, model: str, optional_params: dict
) -> VertexAIBatchEmbeddingsRequestBody:
    """
    The content to embed. Only the parts.text fields will be counted.
    """
    gemini_model_name = "models/{}".format(model)
    requests: List[EmbedContentRequest] = []
    if isinstance(input, str):
        request = EmbedContentRequest(
            model=gemini_model_name,
            content=ContentType(parts=[PartType(text=input)]),
            **optional_params
        )
        requests.append(request)
    else:
        for i in input:
            request = EmbedContentRequest(
                model=gemini_model_name,
                content=ContentType(parts=[PartType(text=i)]),
                **optional_params
            )
            requests.append(request)

    return VertexAIBatchEmbeddingsRequestBody(requests=requests)


def process_response(
    input: EmbeddingInput,
    model_response: EmbeddingResponse,
    model: str,
    _predictions: VertexAIBatchEmbeddingsResponseObject,
) -> EmbeddingResponse:
    openai_embeddings: List[Embedding] = []
    for embedding in _predictions["embeddings"]:
        openai_embedding = Embedding(
            embedding=embedding["values"],
            index=0,
            object="embedding",
        )
        openai_embeddings.append(openai_embedding)

    model_response.data = openai_embeddings
    model_response.model = model

    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
    prompt_tokens = token_counter(model=model, text=input_text)
    model_response.usage = Usage(
        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
    )

    return model_response
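A hedged sketch of how optional_params flow through this transform: each key is splatted into every per-item EmbedContentRequest, so Gemini-native fields such as outputDimensionality (from VertexAITextEmbeddingsRequestBody) apply to every input. The values below are illustrative.

# transform_openai_input_gemini_content(
#     input=["hello", "world"],
#     model="text-embedding-004",
#     optional_params={"outputDimensionality": 256},
# )
# produces roughly:
# {"requests": [
#     {"model": "models/text-embedding-004",
#      "content": {"parts": [{"text": "hello"}]},
#      "outputDimensionality": 256},
#     {"model": "models/text-embedding-004",
#      "content": {"parts": [{"text": "world"}]},
#      "outputDimensionality": 256},
# ]}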


@@ -54,10 +54,16 @@ from litellm.types.llms.vertex_ai import (
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

-from ..base import BaseLLM
-from .common_utils import VertexAIError, get_supports_system_message
-from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints
-from .gemini_transformation import transform_system_message
from ...base import BaseLLM
from ..common_utils import (
    VertexAIError,
    _get_gemini_url,
    _get_vertex_url,
    all_gemini_url_modes,
    get_supports_system_message,
)
from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
from .transformation import transform_system_message

context_caching_endpoints = ContextCachingEndpoints()
@@ -309,6 +315,7 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
            "n",
            "stop",
        ]

    def _map_function(self, value: List[dict]) -> List[Tools]:
        gtool_func_declarations = []
        googleSearchRetrieval: Optional[dict] = None
@@ -1164,6 +1171,7 @@ class VertexLLM(BaseLLM):
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        api_base: Optional[str],
        should_use_v1beta1_features: Optional[bool] = False,
        mode: all_gemini_url_modes = "chat",
    ) -> Tuple[Optional[str], str]:
        """
        Internal function. Returns the token and url for the call.
@@ -1174,18 +1182,13 @@ class VertexLLM(BaseLLM):
            token, url
        """
        if custom_llm_provider == "gemini":
-            _gemini_model_name = "models/{}".format(model)
            auth_header = None
-            endpoint = "generateContent"
-            if stream is True:
-                endpoint = "streamGenerateContent"
-                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
-                    _gemini_model_name, endpoint, gemini_api_key
-                )
-            else:
-                url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
-                    _gemini_model_name, endpoint, gemini_api_key
-                )
            url, endpoint = _get_gemini_url(
                mode=mode,
                model=model,
                stream=stream,
                gemini_api_key=gemini_api_key,
            )
        else:
            auth_header, vertex_project = self._ensure_access_token(
                credentials=vertex_credentials, project_id=vertex_project
@@ -1193,23 +1196,17 @@ class VertexLLM(BaseLLM):
            vertex_location = self.get_vertex_region(vertex_region=vertex_location)

            ### SET RUNTIME ENDPOINT ###
-            version = "v1beta1" if should_use_v1beta1_features is True else "v1"
-            endpoint = "generateContent"
-            litellm.utils.print_verbose("vertex_project - {}".format(vertex_project))
-            if stream is True:
-                endpoint = "streamGenerateContent"
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
-            else:
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
-
-            # if model is only numeric chars then it's a fine tuned gemini model
-            # model = 4965075652664360960
-            # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
-            if model.isdigit():
-                # It's a fine-tuned Gemini model
-                url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
-                if stream is True:
-                    url += "?alt=sse"
            version: Literal["v1beta1", "v1"] = (
                "v1beta1" if should_use_v1beta1_features is True else "v1"
            )
            url, endpoint = _get_vertex_url(
                mode=mode,
                model=model,
                stream=stream,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_api_version=version,
            )

        if (
            api_base is not None
@@ -1793,8 +1790,10 @@ class VertexLLM(BaseLLM):
        input: Union[list, str],
        print_verbose,
        model_response: litellm.EmbeddingResponse,
        custom_llm_provider: Literal["gemini", "vertex_ai"],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        logging_obj=None,
        encoding=None,
        vertex_project=None,
@@ -1804,6 +1803,18 @@ class VertexLLM(BaseLLM):
        timeout=300,
        client=None,
    ):
        auth_header, url = self._get_token_and_url(
            model=model,
            gemini_api_key=api_key,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=None,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=False,
            mode="embedding",
        )

        if client is None:
            _params = {}
@@ -1818,11 +1829,6 @@ class VertexLLM(BaseLLM):
        else:
            sync_handler = client  # type: ignore

-        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
-        auth_header, _ = self._ensure_access_token(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
        optional_params = optional_params or {}

        request_data = VertexMultimodalEmbeddingRequest()
@@ -1840,30 +1846,22 @@ class VertexLLM(BaseLLM):
        request_data["instances"] = [vertex_request_instance]

-        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
-        logging_obj.pre_call(
-            input=[],
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
-        logging_obj.pre_call(
-            input=[],
-            api_key=None,
-            additional_args={
-                "complete_input_dict": optional_params,
-                "request_str": request_str,
-            },
-        )
        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {auth_header}",
        }

        ## LOGGING
        logging_obj.pre_call(
            input=input,
            api_key="",
            additional_args={
                "complete_input_dict": request_data,
                "api_base": url,
                "headers": headers,
            },
        )

        if aembedding is True:
            return self.async_multimodal_embedding(
                model=model,


@@ -205,7 +205,7 @@ def get_vertex_client(
    vertex_credentials: Optional[str],
) -> Tuple[Any, Optional[str]]:
    args = locals()
-    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
        VertexLLM,
    )
@@ -270,7 +270,7 @@ def completion(
        from anthropic import AnthropicVertex
        from litellm.llms.anthropic import AnthropicChatCompletion
-        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
        from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
            VertexLLM,
        )
    except:


@@ -126,12 +126,15 @@ from .llms.vertex_ai_and_google_ai_studio import (
    vertex_ai_anthropic,
    vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
    GoogleBatchEmbeddings,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexAIPartnerModels,
)
-from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
-    VertexLLM,
-)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import (
@@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
@@ -3134,6 +3138,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
        or custom_llm_provider == "fireworks_ai"
        or custom_llm_provider == "ollama"
        or custom_llm_provider == "vertex_ai"
        or custom_llm_provider == "gemini"
        or custom_llm_provider == "databricks"
        or custom_llm_provider == "watsonx"
        or custom_llm_provider == "cohere"
@@ -3531,6 +3536,26 @@ def embedding(
            client=client,
            aembedding=aembedding,
        )
    elif custom_llm_provider == "gemini":
        gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key

        response = google_batch_embeddings.batch_embeddings(  # type: ignore
            model=model,
            input=input,
            encoding=encoding,
            logging_obj=logging,
            optional_params=optional_params,
            model_response=EmbeddingResponse(),
            vertex_project=None,
            vertex_location=None,
            vertex_credentials=None,
            aembedding=aembedding,
            print_verbose=print_verbose,
            custom_llm_provider="gemini",
            api_key=gemini_api_key,
        )
    elif custom_llm_provider == "vertex_ai":
        vertex_ai_project = (
            optional_params.pop("vertex_project", None)
@@ -3571,6 +3596,7 @@ def embedding(
            vertex_credentials=vertex_credentials,
            aembedding=aembedding,
            print_verbose=print_verbose,
            custom_llm_provider="vertex_ai",
        )
    else:
        response = vertex_ai_non_gemini.embedding(


@@ -28,7 +28,7 @@ from litellm import (
    completion_cost,
    embedding,
)
-from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    _gemini_convert_messages_with_history,
)
from litellm.tests.test_streaming import streaming_format_tests
@@ -2085,7 +2085,7 @@ def test_prompt_factory_nested():
def test_get_token_url():
-    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
        VertexLLM,
    )
@@ -2107,7 +2107,7 @@ def test_get_token_url():
        vertex_credentials=vertex_credentials,
        gemini_api_key="",
        custom_llm_provider="vertex_ai_beta",
-        should_use_v1beta1_features=should_use_v1beta1_features,
        should_use_vertex_v1beta1_features=should_use_v1beta1_features,
        api_base=None,
        model="",
        stream=False,
@@ -2127,7 +2127,7 @@ def test_get_token_url():
        vertex_credentials=vertex_credentials,
        gemini_api_key="",
        custom_llm_provider="vertex_ai_beta",
-        should_use_v1beta1_features=should_use_v1beta1_features,
        should_use_vertex_v1beta1_features=should_use_v1beta1_features,
        api_base=None,
        model="",
        stream=False,


@@ -695,6 +695,33 @@ async def test_triton_embeddings():
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
    "input", ["good morning from litellm", ["good morning from litellm"]]  #
)
@pytest.mark.asyncio
async def test_gemini_embeddings(sync_mode, input):
    try:
        litellm.set_verbose = True
        if sync_mode:
            response = litellm.embedding(
                model="gemini/text-embedding-004",
                input=input,
            )
        else:
            response = await litellm.aembedding(
                model="gemini/text-embedding-004",
                input=input,
            )
        print(f"response: {response}")

        # stubbed endpoint is setup to return this
        assert isinstance(response.data[0]["embedding"], list)
        assert response.usage.prompt_tokens > 0
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_databricks_embeddings(sync_mode):


@@ -30,6 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.run import Run
from openai.types.chat import ChatCompletionChunk
from openai.types.embedding import Embedding as OpenAIEmbedding
from pydantic import BaseModel, Field
from typing_extensions import Dict, Required, TypedDict, override
@@ -47,6 +48,9 @@ FileTypes = Union[
]

EmbeddingInput = Union[str, List[str]]


class NotGiven:
    """
    A sentinel singleton class used to distinguish omitted keyword arguments


@@ -336,3 +336,41 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
class VertexAICachedContentResponseObject(TypedDict):
    name: str
    model: str


class TaskTypeEnum(Enum):
    TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED"
    RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
    CLASSIFICATION = "CLASSIFICATION"
    CLUSTERING = "CLUSTERING"
    QUESTION_ANSWERING = "QUESTION_ANSWERING"
    FACT_VERIFICATION = "FACT_VERIFICATION"


class VertexAITextEmbeddingsRequestBody(TypedDict, total=False):
    content: Required[ContentType]
    taskType: TaskTypeEnum
    title: str
    outputDimensionality: int


class ContentEmbeddings(TypedDict):
    values: List[int]


class VertexAITextEmbeddingsResponseObject(TypedDict):
    embedding: ContentEmbeddings


class EmbedContentRequest(VertexAITextEmbeddingsRequestBody):
    model: Required[str]


class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False):
    requests: List[EmbedContentRequest]


class VertexAIBatchEmbeddingsResponseObject(TypedDict):
    embeddings: List[ContentEmbeddings]
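A small sketch of how these TypedDicts compose on both sides of the call, assuming the ContentType/PartType definitions already present in this module (values are illustrative):

request: EmbedContentRequest = {
    "model": "models/text-embedding-004",
    "content": {"parts": [{"text": "good morning from litellm"}]},
    "taskType": TaskTypeEnum.RETRIEVAL_QUERY,  # optional field
}
body: VertexAIBatchEmbeddingsRequestBody = {"requests": [request]}

response: VertexAIBatchEmbeddingsResponseObject = {
    "embeddings": [{"values": [1, 2, 3]}]  # ContentEmbeddings; note the List[int] annotation
}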