diff --git a/litellm/__init__.py b/litellm/__init__.py index 581db4fcb8..27b4fc4408 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -848,7 +848,7 @@ from .llms.gemini import GeminiConfig from .llms.nlp_cloud import NLPCloudConfig from .llms.aleph_alpha import AlephAlphaConfig from .llms.petals import PetalsConfig -from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexGeminiConfig, GoogleAIStudioGeminiConfig, VertexAIConfig, @@ -865,6 +865,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transf from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import ( VertexAIAi21Config, ) + from .llms.sagemaker.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py index e87a9bf3c4..618cf510af 100644 --- a/litellm/llms/fine_tuning_apis/vertex_ai.py +++ b/litellm/llms/fine_tuning_apis/vertex_ai.py @@ -8,7 +8,7 @@ from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparamet from litellm._logging import verbose_logger from litellm.llms.base import BaseLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) from litellm.types.llms.openai import FineTuningJobCreate diff --git a/litellm/llms/text_to_speech/vertex_ai.py b/litellm/llms/text_to_speech/vertex_ai.py index b9fca53250..0aac32eb50 100644 --- a/litellm/llms/text_to_speech/vertex_ai.py +++ b/litellm/llms/text_to_speech/vertex_ai.py @@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, ) from litellm.llms.openai import HttpxBinaryResponseContent -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py index 8faf7a3afa..2fef2233c0 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Optional, Tuple import httpx @@ -37,3 +37,72 @@ def get_supports_system_message( supports_system_message = False return supports_system_message + + +all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"] + + +def _get_vertex_url( + mode: all_gemini_url_modes, + model: str, + stream: Optional[bool], + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_api_version: Literal["v1", "v1beta1"], +) -> Tuple[str, str]: + if mode == "chat": + ### SET RUNTIME ENDPOINT ### + endpoint = "generateContent" + if stream is True: + endpoint = "streamGenerateContent" + url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse" + else: + url = 
f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + + # if model is only numeric chars then it's a fine tuned gemini model + # model = 4965075652664360960 + # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" + if model.isdigit(): + # It's a fine-tuned Gemini model + url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" + if stream is True: + url += "?alt=sse" + elif mode == "embedding": + endpoint = "predict" + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + + return url, endpoint + + +def _get_gemini_url( + mode: all_gemini_url_modes, + model: str, + stream: Optional[bool], + gemini_api_key: Optional[str], +) -> Tuple[str, str]: + _gemini_model_name = "models/{}".format(model) + if mode == "chat": + endpoint = "generateContent" + if stream is True: + endpoint = "streamGenerateContent" + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format( + _gemini_model_name, endpoint, gemini_api_key + ) + else: + url = ( + "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( + _gemini_model_name, endpoint, gemini_api_key + ) + ) + elif mode == "embedding": + endpoint = "embedContent" + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( + _gemini_model_name, endpoint, gemini_api_key + ) + elif mode == "batch_embedding": + endpoint = "batchEmbedContents" + url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( + _gemini_model_name, endpoint, gemini_api_key + ) + + return url, endpoint diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py index 944ae00bc9..d394cafd3a 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/context_caching/transformation.py @@ -11,8 +11,10 @@ from litellm.types.llms.vertex_ai import CachedContentRequestBody, SystemInstruc from litellm.utils import is_cached_message from ..common_utils import VertexAIError, get_supports_system_message -from ..gemini_transformation import transform_system_message -from ..vertex_and_google_ai_studio_gemini import _gemini_convert_messages_with_history +from ..gemini.transformation import transform_system_message +from ..gemini.vertex_and_google_ai_studio_gemini import ( + _gemini_convert_messages_with_history, +) def separate_cached_messages( diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_handler.py new file mode 100644 index 0000000000..d05688deea --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_handler.py @@ -0,0 +1,167 @@ +""" +Google AI Studio /batchEmbedContents Embeddings Endpoint +""" + +import json +from typing import List, Literal, Optional, Union + +import httpx + +from litellm import EmbeddingResponse +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from 
litellm.types.llms.openai import EmbeddingInput +from litellm.types.llms.vertex_ai import ( + VertexAIBatchEmbeddingsRequestBody, + VertexAIBatchEmbeddingsResponseObject, +) + +from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM +from .batch_embed_content_transformation import ( + process_response, + transform_openai_input_gemini_content, +) + + +class GoogleBatchEmbeddings(VertexLLM): + def batch_embeddings( + self, + model: str, + input: EmbeddingInput, + print_verbose, + model_response: EmbeddingResponse, + custom_llm_provider: Literal["gemini", "vertex_ai"], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + logging_obj=None, + encoding=None, + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + aembedding=False, + timeout=300, + client=None, + ) -> EmbeddingResponse: + + auth_header, url = self._get_token_and_url( + model=model, + gemini_api_key=api_key, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=None, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=False, + mode="batch_embedding", + ) + + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore + else: + sync_handler = client # type: ignore + + optional_params = optional_params or {} + + ### TRANSFORMATION ### + request_data = transform_openai_input_gemini_content( + input=input, model=model, optional_params=optional_params + ) + + headers = { + "Content-Type": "application/json; charset=utf-8", + } + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": request_data, + "api_base": url, + "headers": headers, + }, + ) + + if aembedding is True: + return self.async_batch_embeddings( # type: ignore + model=model, + api_base=api_base, + url=url, + data=request_data, + model_response=model_response, + timeout=timeout, + headers=headers, + input=input, + ) + + response = sync_handler.post( + url=url, + headers=headers, + data=json.dumps(request_data), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + _json_response = response.json() + _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore + + return process_response( + model=model, + model_response=model_response, + _predictions=_predictions, + input=input, + ) + + async def async_batch_embeddings( + self, + model: str, + api_base: Optional[str], + url: str, + data: VertexAIBatchEmbeddingsRequestBody, + model_response: EmbeddingResponse, + input: EmbeddingInput, + timeout: Optional[Union[float, httpx.Timeout]], + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> EmbeddingResponse: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params) # type: ignore + else: + async_handler = client # type: ignore + + response = await async_handler.post( + url=url, + headers=headers, + 
data=json.dumps(data), +        ) + +        if response.status_code != 200: +            raise Exception(f"Error: {response.status_code} {response.text}") + +        _json_response = response.json() +        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore + +        return process_response( +            model=model, +            model_response=model_response, +            _predictions=_predictions, +            input=input, +        ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py new file mode 100644 index 0000000000..f1785e58f1 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py @@ -0,0 +1,76 @@ +""" +Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. + +Why a separate file? To make it easy to see how the transformation works. +""" + +from typing import List + +from litellm import EmbeddingResponse +from litellm.types.llms.openai import EmbeddingInput +from litellm.types.llms.vertex_ai import ( +    ContentType, +    EmbedContentRequest, +    PartType, +    VertexAIBatchEmbeddingsRequestBody, +    VertexAIBatchEmbeddingsResponseObject, +) +from litellm.types.utils import Embedding, Usage +from litellm.utils import get_formatted_prompt, token_counter + +from ..common_utils import VertexAIError + + +def transform_openai_input_gemini_content( +    input: EmbeddingInput, model: str, optional_params: dict +) -> VertexAIBatchEmbeddingsRequestBody: +    """ +    Map the OpenAI-style input to EmbedContentRequests; only the parts.text fields are embedded. +    """ +    gemini_model_name = "models/{}".format(model) +    requests: List[EmbedContentRequest] = [] +    if isinstance(input, str): +        request = EmbedContentRequest( +            model=gemini_model_name, +            content=ContentType(parts=[PartType(text=input)]), +            **optional_params +        ) +        requests.append(request) +    else: +        for i in input: +            request = EmbedContentRequest( +                model=gemini_model_name, +                content=ContentType(parts=[PartType(text=i)]), +                **optional_params +            ) +            requests.append(request) + +    return VertexAIBatchEmbeddingsRequestBody(requests=requests) + + +def process_response( +    input: EmbeddingInput, +    model_response: EmbeddingResponse, +    model: str, +    _predictions: VertexAIBatchEmbeddingsResponseObject, +) -> EmbeddingResponse: + +    openai_embeddings: List[Embedding] = [] +    for idx, embedding in enumerate(_predictions["embeddings"]): +        openai_embedding = Embedding( +            embedding=embedding["values"], +            index=idx, +            object="embedding", +        ) +        openai_embeddings.append(openai_embedding) + +    model_response.data = openai_embeddings +    model_response.model = model + +    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding") +    prompt_tokens = token_counter(model=model, text=input_text) +    model_response.usage = Usage( +        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens +    ) + +    return model_response diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py similarity index 100% rename from litellm/llms/vertex_ai_and_google_ai_studio/gemini_transformation.py rename to litellm/llms/vertex_ai_and_google_ai_studio/gemini/transformation.py diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py similarity index 96% rename from 
litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py rename to litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py index 5392f253f8..819a94cb0c 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py @@ -54,10 +54,16 @@ from litellm.types.llms.vertex_ai import ( from litellm.types.utils import GenericStreamingChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage -from ..base import BaseLLM -from .common_utils import VertexAIError, get_supports_system_message -from .context_caching.vertex_ai_context_caching import ContextCachingEndpoints -from .gemini_transformation import transform_system_message +from ...base import BaseLLM +from ..common_utils import ( + VertexAIError, + _get_gemini_url, + _get_vertex_url, + all_gemini_url_modes, + get_supports_system_message, +) +from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints +from .transformation import transform_system_message context_caching_endpoints = ContextCachingEndpoints() @@ -309,6 +315,7 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty "n", "stop", ] + def _map_function(self, value: List[dict]) -> List[Tools]: gtool_func_declarations = [] googleSearchRetrieval: Optional[dict] = None @@ -1164,6 +1171,7 @@ class VertexLLM(BaseLLM): custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], api_base: Optional[str], should_use_v1beta1_features: Optional[bool] = False, + mode: all_gemini_url_modes = "chat", ) -> Tuple[Optional[str], str]: """ Internal function. Returns the token and url for the call. @@ -1174,18 +1182,13 @@ class VertexLLM(BaseLLM): token, url """ if custom_llm_provider == "gemini": - _gemini_model_name = "models/{}".format(model) auth_header = None - endpoint = "generateContent" - if stream is True: - endpoint = "streamGenerateContent" - url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format( - _gemini_model_name, endpoint, gemini_api_key - ) - else: - url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( - _gemini_model_name, endpoint, gemini_api_key - ) + url, endpoint = _get_gemini_url( + mode=mode, + model=model, + stream=stream, + gemini_api_key=gemini_api_key, + ) else: auth_header, vertex_project = self._ensure_access_token( credentials=vertex_credentials, project_id=vertex_project @@ -1193,23 +1196,17 @@ class VertexLLM(BaseLLM): vertex_location = self.get_vertex_region(vertex_region=vertex_location) ### SET RUNTIME ENDPOINT ### - version = "v1beta1" if should_use_v1beta1_features is True else "v1" - endpoint = "generateContent" - litellm.utils.print_verbose("vertex_project - {}".format(vertex_project)) - if stream is True: - endpoint = "streamGenerateContent" - url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse" - else: - url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" - - # if model is only numeric chars then it's a fine tuned gemini model - # model = 4965075652664360960 - # send to this url: url = 
f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" - if model.isdigit(): - # It's a fine-tuned Gemini model - url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}" - if stream is True: - url += "?alt=sse" + version: Literal["v1beta1", "v1"] = ( + "v1beta1" if should_use_v1beta1_features is True else "v1" + ) + url, endpoint = _get_vertex_url( + mode=mode, + model=model, + stream=stream, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_api_version=version, + ) if ( api_base is not None @@ -1793,8 +1790,10 @@ class VertexLLM(BaseLLM): input: Union[list, str], print_verbose, model_response: litellm.EmbeddingResponse, + custom_llm_provider: Literal["gemini", "vertex_ai"], optional_params: dict, api_key: Optional[str] = None, + api_base: Optional[str] = None, logging_obj=None, encoding=None, vertex_project=None, @@ -1804,6 +1803,18 @@ class VertexLLM(BaseLLM): timeout=300, client=None, ): + auth_header, url = self._get_token_and_url( + model=model, + gemini_api_key=api_key, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=None, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=False, + mode="embedding", + ) if client is None: _params = {} @@ -1818,11 +1829,6 @@ class VertexLLM(BaseLLM): else: sync_handler = client # type: ignore - url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict" - - auth_header, _ = self._ensure_access_token( - credentials=vertex_credentials, project_id=vertex_project - ) optional_params = optional_params or {} request_data = VertexMultimodalEmbeddingRequest() @@ -1840,30 +1846,22 @@ class VertexLLM(BaseLLM): request_data["instances"] = [vertex_request_instance] - request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\"" - logging_obj.pre_call( - input=[], - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - logging_obj.pre_call( - input=[], - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - headers = { "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {auth_header}", } + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": request_data, + "api_base": url, + "headers": headers, + }, + ) + if aembedding is True: return self.async_multimodal_embedding( model=model, diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py index b13b87bc67..e85160a43c 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py @@ -205,7 +205,7 @@ def get_vertex_client( vertex_credentials: Optional[str], ) -> Tuple[Any, Optional[str]]: args = locals() - from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + from 
litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) @@ -270,7 +270,7 @@ def completion( from anthropic import AnthropicVertex from litellm.llms.anthropic import AnthropicChatCompletion - from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) except: diff --git a/litellm/main.py b/litellm/main.py index 7eac01ff6f..f20cfa9966 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -126,12 +126,15 @@ from .llms.vertex_ai_and_google_ai_studio import ( vertex_ai_anthropic, vertex_ai_non_gemini, ) +from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import ( + GoogleBatchEmbeddings, +) +from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( + VertexLLM, +) from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import ( VertexAIPartnerModels, ) -from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( - VertexLLM, -) from .llms.watsonx import IBMWatsonXAI from .types.llms.openai import HttpxBinaryResponseContent from .types.utils import ( @@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() bedrock_converse_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() +google_batch_embeddings = GoogleBatchEmbeddings() vertex_partner_models_chat_completion = VertexAIPartnerModels() vertex_text_to_speech = VertexTextToSpeechAPI() watsonxai = IBMWatsonXAI() @@ -3134,6 +3138,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "ollama" or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "gemini" or custom_llm_provider == "databricks" or custom_llm_provider == "watsonx" or custom_llm_provider == "cohere" @@ -3531,6 +3536,26 @@ def embedding( client=client, aembedding=aembedding, ) + elif custom_llm_provider == "gemini": + + gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key + + response = google_batch_embeddings.batch_embeddings( # type: ignore + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + aembedding=aembedding, + print_verbose=print_verbose, + custom_llm_provider="gemini", + api_key=gemini_api_key, + ) + elif custom_llm_provider == "vertex_ai": vertex_ai_project = ( optional_params.pop("vertex_project", None) @@ -3571,6 +3596,7 @@ def embedding( vertex_credentials=vertex_credentials, aembedding=aembedding, print_verbose=print_verbose, + custom_llm_provider="vertex_ai", ) else: response = vertex_ai_non_gemini.embedding( diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index ac18a65f28..fa33bab3b6 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -28,7 +28,7 @@ from litellm import ( completion_cost, embedding, ) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( _gemini_convert_messages_with_history, ) from litellm.tests.test_streaming import 
streaming_format_tests @@ -2085,7 +2085,7 @@ def test_prompt_factory_nested():   def test_get_token_url(): -    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import ( +    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, )  diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 54c823e4dc..ec85a782dd 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -695,6 +695,33 @@ async def test_triton_embeddings(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize( +    "input", ["good morning from litellm", ["good morning from litellm"]] +) +@pytest.mark.asyncio +async def test_gemini_embeddings(sync_mode, input): +    try: +        litellm.set_verbose = True +        if sync_mode: +            response = litellm.embedding( +                model="gemini/text-embedding-004", +                input=input, +            ) +        else: +            response = await litellm.aembedding( +                model="gemini/text-embedding-004", +                input=input, +            ) +        print(f"response: {response}") + +        # sanity check the returned embedding vector + token usage +        assert isinstance(response.data[0]["embedding"], list) +        assert response.usage.prompt_tokens > 0 +    except Exception as e: +        pytest.fail(f"Error occurred: {e}") + + @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_databricks_embeddings(sync_mode): diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 470f72c5b6..138441a7eb 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -30,6 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage from openai.types.beta.threads.message_content import MessageContent from openai.types.beta.threads.run import Run from openai.types.chat import ChatCompletionChunk +from openai.types.embedding import Embedding as OpenAIEmbedding from pydantic import BaseModel, Field from typing_extensions import Dict, Required, TypedDict, override @@ -47,6 +48,9 @@ FileTypes = Union[ ] +EmbeddingInput = Union[str, List[str]] + + class NotGiven: """ A sentinel singleton class used to distinguish omitted keyword arguments diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 90730d75fe..aeda867979 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -336,3 +336,41 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False): class VertexAICachedContentResponseObject(TypedDict): name: str model: str + + +class TaskTypeEnum(Enum): +    TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED" +    RETRIEVAL_QUERY = "RETRIEVAL_QUERY" +    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT" +    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY" +    CLASSIFICATION = "CLASSIFICATION" +    CLUSTERING = "CLUSTERING" +    QUESTION_ANSWERING = "QUESTION_ANSWERING" +    FACT_VERIFICATION = "FACT_VERIFICATION" + + 
+class VertexAITextEmbeddingsRequestBody(TypedDict, total=False): +    content: Required[ContentType] +    taskType: TaskTypeEnum +    title: str +    outputDimensionality: int + + +class ContentEmbeddings(TypedDict): +    values: List[float] + + +class VertexAITextEmbeddingsResponseObject(TypedDict): +    embedding: ContentEmbeddings + + +class EmbedContentRequest(VertexAITextEmbeddingsRequestBody): +    model: Required[str] + + +class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False): +    requests: List[EmbedContentRequest] + + +class VertexAIBatchEmbeddingsResponseObject(TypedDict): +    embeddings: List[ContentEmbeddings]
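
Usage sketch: the hunks above route `model="gemini/<model>"` embedding calls through the new `GoogleBatchEmbeddings.batch_embeddings` handler, which builds a Google AI Studio /batchEmbedContents request. The snippet below mirrors the new test; it assumes `GEMINI_API_KEY` is exported and that `text-embedding-004` (the model name used in the test) is available.

import asyncio
import litellm

# Sync path: embedding() dispatches on custom_llm_provider == "gemini"
# (see the litellm/main.py hunk above).
response = litellm.embedding(
    model="gemini/text-embedding-004",
    input=["good morning from litellm"],
)
print(len(response.data[0]["embedding"]))  # dimensionality of the returned vector

# Async path: aembedding() now also recognizes the "gemini" provider.
async def main() -> None:
    response = await litellm.aembedding(
        model="gemini/text-embedding-004",
        input="good morning from litellm",
    )
    print(response.usage.prompt_tokens)  # usage is estimated via token_counter

asyncio.run(main())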