diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
index d8607b4a8a..2fef2233c0 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py
@@ -41,7 +41,7 @@ def get_supports_system_message(
 from typing import Literal, Optional
 
-all_gemini_url_modes = Literal["chat", "embedding"]
+all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
 
 
 def _get_vertex_url(
@@ -101,4 +101,10 @@ def _get_gemini_url(
         url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
             _gemini_model_name, endpoint, gemini_api_key
         )
+    elif mode == "batch_embedding":
+        endpoint = "batchEmbedContents"
+        url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+            _gemini_model_name, endpoint, gemini_api_key
+        )
+
     return url, endpoint
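For context, the hunk above registers a new `batch_embedding` URL mode and maps it to Google AI Studio's `batchEmbedContents` endpoint. A minimal sketch of the URL `_get_gemini_url` now builds for that mode; the model name and key below are placeholder values, not anything from this PR:

```python
# Illustrative only: what _get_gemini_url produces for mode="batch_embedding".
_gemini_model_name = "models/text-embedding-004"  # placeholder model
gemini_api_key = "<GEMINI_API_KEY>"  # placeholder key

endpoint = "batchEmbedContents"
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
    _gemini_model_name, endpoint, gemini_api_key
)
# -> https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key=<GEMINI_API_KEY>
```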
"api_base": url, + "headers": headers, + }, + ) + + if aembedding is True: + return self.async_batch_embeddings( # type: ignore + model=model, + api_base=api_base, + url=url, + data=request_data, + model_response=model_response, + timeout=timeout, + headers=headers, + input=input, + ) + + response = sync_handler.post( + url=url, + headers=headers, + data=json.dumps(request_data), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + _json_response = response.json() + _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore + + return process_response( + model=model, + model_response=model_response, + _predictions=_predictions, + input=input, + ) + + async def async_batch_embeddings( + self, + model: str, + api_base: Optional[str], + url: str, + data: VertexAIBatchEmbeddingsRequestBody, + model_response: EmbeddingResponse, + input: EmbeddingInput, + timeout: Optional[Union[float, httpx.Timeout]], + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> EmbeddingResponse: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params) # type: ignore + else: + async_handler = client # type: ignore + + response = await async_handler.post( + url=url, + headers=headers, + data=json.dumps(data), + ) + + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") + + _json_response = response.json() + _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore + + return process_response( + model=model, + model_response=model_response, + _predictions=_predictions, + input=input, + ) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py new file mode 100644 index 0000000000..e17c79991e --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/batch_embed_content_transformation.py @@ -0,0 +1,68 @@ +""" +Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format. + +Why separate file? Make it easy to see how transformation works +""" + +from typing import List + +from litellm import EmbeddingResponse +from litellm.types.llms.openai import EmbeddingInput +from litellm.types.llms.vertex_ai import ( + ContentType, + EmbedContentRequest, + PartType, + VertexAIBatchEmbeddingsRequestBody, + VertexAIBatchEmbeddingsResponseObject, +) +from litellm.types.utils import Embedding, Usage +from litellm.utils import get_formatted_prompt, token_counter + +from ..common_utils import VertexAIError + + +def transform_openai_input_gemini_content( + input: List[str], model: str, optional_params: dict +) -> VertexAIBatchEmbeddingsRequestBody: + """ + The content to embed. Only the parts.text fields will be counted. 
+ """ + gemini_model_name = "models/{}".format(model) + requests: List[EmbedContentRequest] = [] + for i in input: + request = EmbedContentRequest( + model=gemini_model_name, + content=ContentType(parts=[PartType(text=i)]), + **optional_params + ) + requests.append(request) + + return VertexAIBatchEmbeddingsRequestBody(requests=requests) + + +def process_response( + input: EmbeddingInput, + model_response: EmbeddingResponse, + model: str, + _predictions: VertexAIBatchEmbeddingsResponseObject, +) -> EmbeddingResponse: + + openai_embeddings: List[Embedding] = [] + for embedding in _predictions["embeddings"]: + openai_embedding = Embedding( + embedding=embedding["values"], + index=0, + object="embedding", + ) + openai_embeddings.append(openai_embedding) + + model_response.data = openai_embeddings + model_response.model = model + + input_text = get_formatted_prompt(data={"input": input}, call_type="embedding") + prompt_tokens = token_counter(model=model, text=input_text) + model_response.usage = Usage( + prompt_tokens=prompt_tokens, total_tokens=prompt_tokens + ) + + return model_response diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_handler.py similarity index 94% rename from litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py rename to litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_handler.py index 2b26d6c04d..7c1d474352 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_handler.py @@ -1,5 +1,5 @@ """ -Google AI Studio Embeddings Endpoint +Google AI Studio /embedContent Embeddings Endpoint """ import json @@ -7,7 +7,6 @@ from typing import Literal, Optional, Union import httpx -import litellm from litellm import EmbeddingResponse from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.llms.openai import EmbeddingInput @@ -15,21 +14,19 @@ from litellm.types.llms.vertex_ai import ( VertexAITextEmbeddingsRequestBody, VertexAITextEmbeddingsResponseObject, ) -from litellm.types.utils import Embedding -from litellm.utils import get_formatted_prompt -from .embeddings_transformation import ( +from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM +from .embed_content_transformation import ( process_response, transform_openai_input_gemini_content, ) -from .vertex_and_google_ai_studio_gemini import VertexLLM class GoogleEmbeddings(VertexLLM): def text_embeddings( self, model: str, - input: Union[list, str], + input: str, print_verbose, model_response: EmbeddingResponse, custom_llm_provider: Literal["gemini", "vertex_ai"], diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_transformation.py similarity index 69% rename from litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py rename to litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_transformation.py index 198811578b..bbda553175 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_transformation.py @@ -4,8 +4,6 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe Why separate file? 
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_transformation.py
similarity index 69%
rename from litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py
rename to litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_transformation.py
index 198811578b..bbda553175 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/embeddings/embed_content_transformation.py
@@ -4,8 +4,6 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe
 Why separate file? Make it easy to see how transformation works
 """
 
-from typing import List
-
 from litellm import EmbeddingResponse
 from litellm.types.llms.openai import EmbeddingInput
 from litellm.types.llms.vertex_ai import (
@@ -19,19 +17,11 @@ from litellm.utils import get_formatted_prompt, token_counter
 
 from ..common_utils import VertexAIError
 
-def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
+def transform_openai_input_gemini_content(input: str) -> ContentType:
     """
     The content to embed. Only the parts.text fields will be counted.
     """
-    if isinstance(input, str):
-        return ContentType(parts=[PartType(text=input)])
-    elif isinstance(input, list) and len(input) == 1:
-        return ContentType(parts=[PartType(text=input[0])])
-    else:
-        raise VertexAIError(
-            status_code=422,
-            message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
-        )
+    return ContentType(parts=[PartType(text=input)])
 
 
 def process_response(
diff --git a/litellm/main.py b/litellm/main.py
index e9a3d2898b..bf1b0ede8c 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -126,7 +126,10 @@ from .llms.vertex_ai_and_google_ai_studio import (
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
-from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import (
+from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
+    GoogleBatchEmbeddings,
+)
+from .llms.vertex_ai_and_google_ai_studio.embeddings.embed_content_handler import (
     GoogleEmbeddings,
 )
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
@@ -176,6 +179,7 @@ bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 google_embeddings = GoogleEmbeddings()
+google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
 watsonxai = IBMWatsonXAI()
@@ -3537,21 +3541,38 @@ def embedding(
             gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
 
-            response = google_embeddings.text_embeddings(  # type: ignore
-                model=model,
-                input=input,
-                encoding=encoding,
-                logging_obj=logging,
-                optional_params=optional_params,
-                model_response=EmbeddingResponse(),
-                vertex_project=None,
-                vertex_location=None,
-                vertex_credentials=None,
-                aembedding=aembedding,
-                print_verbose=print_verbose,
-                custom_llm_provider="gemini",
-                api_key=gemini_api_key,
-            )
+            if isinstance(input, str):
+                response = google_embeddings.text_embeddings(  # type: ignore
+                    model=model,
+                    input=input,
+                    encoding=encoding,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    model_response=EmbeddingResponse(),
+                    vertex_project=None,
+                    vertex_location=None,
+                    vertex_credentials=None,
+                    aembedding=aembedding,
+                    print_verbose=print_verbose,
+                    custom_llm_provider="gemini",
+                    api_key=gemini_api_key,
+                )
+            else:
+                response = google_batch_embeddings.batch_embeddings(  # type: ignore
+                    model=model,
+                    input=input,
+                    encoding=encoding,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    model_response=EmbeddingResponse(),
+                    vertex_project=None,
+                    vertex_location=None,
+                    vertex_credentials=None,
+                    aembedding=aembedding,
+                    print_verbose=print_verbose,
+                    custom_llm_provider="gemini",
+                    api_key=gemini_api_key,
+                )
 
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (
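End to end, the router change above is transparent to callers of `litellm.embedding`. A quick usage sketch, assuming `GEMINI_API_KEY` is set in the environment; the input strings are illustrative:

```python
import litellm

# A plain string still goes through GoogleEmbeddings (/embedContent);
# a list of strings now routes to GoogleBatchEmbeddings (/batchEmbedContents).
single = litellm.embedding(
    model="gemini/text-embedding-004",
    input="good morning from litellm",
)

batch = litellm.embedding(
    model="gemini/text-embedding-004",
    input=["good morning from litellm", "good night from litellm"],
)
print(len(batch.data))  # one embedding per input string
```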
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index a17b22f489..667674b752 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -687,19 +687,22 @@ async def test_triton_embeddings():
 
 
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "input", ["good morning from litellm", ["good morning from litellm"]]  #
+)
 @pytest.mark.asyncio
-async def test_gemini_embeddings(sync_mode):
+async def test_gemini_embeddings(sync_mode, input):
     try:
         litellm.set_verbose = True
         if sync_mode:
             response = litellm.embedding(
                 model="gemini/text-embedding-004",
-                input=["good morning from litellm"],
+                input=input,
             )
         else:
             response = await litellm.aembedding(
                 model="gemini/text-embedding-004",
-                input=["good morning from litellm"],
+                input=input,
             )
         print(f"response: {response}")
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index bacb4d2252..aeda867979 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -362,3 +362,15 @@ class ContentEmbeddings(TypedDict):
 
 class VertexAITextEmbeddingsResponseObject(TypedDict):
     embedding: ContentEmbeddings
+
+
+class EmbedContentRequest(VertexAITextEmbeddingsRequestBody):
+    model: Required[str]
+
+
+class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False):
+    requests: List[EmbedContentRequest]
+
+
+class VertexAIBatchEmbeddingsResponseObject(TypedDict):
+    embeddings: List[ContentEmbeddings]
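Finally, a minimal sketch of constructing the new TypedDicts by hand, mirroring what `transform_openai_input_gemini_content` builds; the model name and text are placeholders:

```python
from litellm.types.llms.vertex_ai import (
    ContentType,
    EmbedContentRequest,
    PartType,
    VertexAIBatchEmbeddingsRequestBody,
)

# EmbedContentRequest extends VertexAITextEmbeddingsRequestBody with a
# required "model" key, since the batch endpoint expects each request in
# the batch to carry its own model field.
request = EmbedContentRequest(
    model="models/text-embedding-004",
    content=ContentType(parts=[PartType(text="hello world")]),
)
body = VertexAIBatchEmbeddingsRequestBody(requests=[request])
```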