diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py index bc0d4ac16..98cebfc31 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_handler.py @@ -10,7 +10,14 @@ import httpx import litellm from litellm import EmbeddingResponse from litellm.llms.custom_httpx.http_handler import HTTPHandler +from litellm.types.llms.vertex_ai import ( + VertexAITextEmbeddingsRequestBody, + VertexAITextEmbeddingsResponseObject, +) +from litellm.types.utils import Embedding +from litellm.utils import get_formatted_prompt +from .embeddings_transformation import transform_openai_input_gemini_content from .vertex_and_google_ai_studio_gemini import VertexLLM @@ -34,7 +41,7 @@ class GoogleEmbeddings(VertexLLM): timeout=300, client=None, ) -> EmbeddingResponse: - return model_response + auth_header, url = self._get_token_and_url( model=model, gemini_api_key=api_key, @@ -63,59 +70,58 @@ class GoogleEmbeddings(VertexLLM): optional_params = optional_params or {} - # request_data = VertexMultimodalEmbeddingRequest() + ### TRANSFORMATION ### + content = transform_openai_input_gemini_content(input=input) - # if "instances" in optional_params: - # request_data["instances"] = optional_params["instances"] - # elif isinstance(input, list): - # request_data["instances"] = input - # else: - # # construct instances - # vertex_request_instance = Instance(**optional_params) + request_data: VertexAITextEmbeddingsRequestBody = { + "content": content, + **optional_params, + } - # if isinstance(input, str): - # vertex_request_instance["text"] = input + headers = { + "Content-Type": "application/json; charset=utf-8", + } - # request_data["instances"] = [vertex_request_instance] + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": request_data, + 
"api_base": url, + "headers": headers, + }, + ) - # headers = { - # "Content-Type": "application/json; charset=utf-8", - # "Authorization": f"Bearer {auth_header}", - # } + if aembedding is True: + pass - # ## LOGGING - # logging_obj.pre_call( - # input=input, - # api_key="", - # additional_args={ - # "complete_input_dict": request_data, - # "api_base": url, - # "headers": headers, - # }, - # ) + response = sync_handler.post( + url=url, + headers=headers, + data=json.dumps(request_data), + ) - # if aembedding is True: - # pass + if response.status_code != 200: + raise Exception(f"Error: {response.status_code} {response.text}") - # response = sync_handler.post( - # url=url, - # headers=headers, - # data=json.dumps(request_data), - # ) + _json_response = response.json() + _predictions = VertexAITextEmbeddingsResponseObject(**_json_response) # type: ignore - # if response.status_code != 200: - # raise Exception(f"Error: {response.status_code} {response.text}") + model_response.data = [ + Embedding( + embedding=_predictions["embedding"]["values"], + index=0, + object="embedding", + ) + ] - # _json_response = response.json() - # if "predictions" not in _json_response: - # raise litellm.InternalServerError( - # message=f"embedding response does not contain 'predictions', got {_json_response}", - # llm_provider="vertex_ai", - # model=model, - # ) - # _predictions = _json_response["predictions"] + model_response.model = model - # model_response.data = _predictions - # model_response.model = model + input_text = get_formatted_prompt(data={"input": input}, call_type="embedding") + prompt_tokens = litellm.token_counter(model=model, text=input_text) + model_response.usage = litellm.Usage( + prompt_tokens=prompt_tokens, total_tokens=prompt_tokens + ) - # return model_response + return model_response diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py index 
2e3d156f5..dd5abfa38 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/embeddings_transformation.py @@ -3,3 +3,25 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe Why separate file? Make it easy to see how transformation works """ + +from typing import List + +from litellm.types.llms.openai import EmbeddingInput +from litellm.types.llms.vertex_ai import ContentType, PartType + +from ..common_utils import VertexAIError + + +def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType: + """ + Wrap an OpenAI-style embedding input in a ContentType with a single text part. A str, or a list holding exactly one item, is accepted; any other input raises VertexAIError (status_code=422). Only the parts.text fields of the content will be counted. + """ + if isinstance(input, str): + return ContentType(parts=[PartType(text=input)]) + elif isinstance(input, list) and len(input) == 1: + return ContentType(parts=[PartType(text=input[0])]) + else: + raise VertexAIError( + status_code=422, + message="/embedContent only generates a single text embedding vector. 
File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues", + ) diff --git a/litellm/main.py b/litellm/main.py index 8896f1faf..e9a3d2898 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio import ( vertex_ai_anthropic, vertex_ai_non_gemini, ) +from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import ( + GoogleEmbeddings, +) from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) @@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() bedrock_converse_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() +google_embeddings = GoogleEmbeddings() vertex_partner_models_chat_completion = VertexAIPartnerModels() vertex_text_to_speech = VertexTextToSpeechAPI() watsonxai = IBMWatsonXAI() @@ -3533,7 +3537,7 @@ def embedding( gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key - response = vertex_chat_completion.multimodal_embedding( # type: ignore + response = google_embeddings.text_embeddings( # type: ignore model=model, input=input, encoding=encoding, diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 2fbc70f02..c318264d4 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -697,7 +697,8 @@ async def test_gemini_embeddings(): print(f"response: {response}") # stubbed endpoint is setup to return this - assert response.data[0]["embedding"] == [0.1, 0.2] + assert isinstance(response.data[0]["embedding"], list) + assert response.usage.prompt_tokens > 0 except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 470f72c5b..138441a7e 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -30,6 +30,7 @@ from openai.types.beta.threads.message import 
Message as OpenAIMessage from openai.types.beta.threads.message_content import MessageContent from openai.types.beta.threads.run import Run from openai.types.chat import ChatCompletionChunk +from openai.types.embedding import Embedding as OpenAIEmbedding from pydantic import BaseModel, Field from typing_extensions import Dict, Required, TypedDict, override @@ -47,6 +48,9 @@ FileTypes = Union[ ] +EmbeddingInput = Union[str, List[str]] + + class NotGiven: """ A sentinel singleton class used to distinguish omitted keyword arguments diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 90730d75f..bacb4d225 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -336,3 +336,29 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False): class VertexAICachedContentResponseObject(TypedDict): name: str model: str + + +class TaskTypeEnum(Enum): # allowed values for the request body's taskType field + TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED" + RETRIEVAL_QUERY = "RETRIEVAL_QUERY" + RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT" + SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY" + CLASSIFICATION = "CLASSIFICATION" + CLUSTERING = "CLUSTERING" + QUESTION_ANSWERING = "QUESTION_ANSWERING" + FACT_VERIFICATION = "FACT_VERIFICATION" + + +class VertexAITextEmbeddingsRequestBody(TypedDict, total=False): # only content is required; the other keys are optional + content: Required[ContentType] + taskType: TaskTypeEnum + title: str + outputDimensionality: int + + +class ContentEmbeddings(TypedDict): # value stored under the response's "embedding" key + values: List[float] # embedding vector components are floats, not ints + + +class VertexAITextEmbeddingsResponseObject(TypedDict): # typed view of the JSON body returned by the embeddings call + embedding: ContentEmbeddings