fix(embeddings_handler.py): initial working commit for google ai studio text embeddings /embedContent endpoint

This commit is contained in:
Krrish Dholakia 2024-08-27 18:14:56 -07:00
parent 77e6da78a1
commit 5b29ddd2a6
6 changed files with 111 additions and 48 deletions

View file

@ -10,7 +10,14 @@ import httpx
import litellm import litellm
from litellm import EmbeddingResponse from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import HTTPHandler from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.types.llms.vertex_ai import (
VertexAITextEmbeddingsRequestBody,
VertexAITextEmbeddingsResponseObject,
)
from litellm.types.utils import Embedding
from litellm.utils import get_formatted_prompt
from .embeddings_transformation import transform_openai_input_gemini_content
from .vertex_and_google_ai_studio_gemini import VertexLLM from .vertex_and_google_ai_studio_gemini import VertexLLM
@ -34,7 +41,7 @@ class GoogleEmbeddings(VertexLLM):
timeout=300, timeout=300,
client=None, client=None,
) -> EmbeddingResponse: ) -> EmbeddingResponse:
return model_response
auth_header, url = self._get_token_and_url( auth_header, url = self._get_token_and_url(
model=model, model=model,
gemini_api_key=api_key, gemini_api_key=api_key,
@ -63,59 +70,58 @@ class GoogleEmbeddings(VertexLLM):
optional_params = optional_params or {} optional_params = optional_params or {}
# request_data = VertexMultimodalEmbeddingRequest() ### TRANSFORMATION ###
content = transform_openai_input_gemini_content(input=input)
# if "instances" in optional_params: request_data: VertexAITextEmbeddingsRequestBody = {
# request_data["instances"] = optional_params["instances"] "content": content,
# elif isinstance(input, list): **optional_params,
# request_data["instances"] = input }
# else:
# # construct instances
# vertex_request_instance = Instance(**optional_params)
# if isinstance(input, str): headers = {
# vertex_request_instance["text"] = input "Content-Type": "application/json; charset=utf-8",
}
# request_data["instances"] = [vertex_request_instance] ## LOGGING
logging_obj.pre_call(
input=input,
api_key="",
additional_args={
"complete_input_dict": request_data,
"api_base": url,
"headers": headers,
},
)
# headers = { if aembedding is True:
# "Content-Type": "application/json; charset=utf-8", pass
# "Authorization": f"Bearer {auth_header}",
# }
# ## LOGGING response = sync_handler.post(
# logging_obj.pre_call( url=url,
# input=input, headers=headers,
# api_key="", data=json.dumps(request_data),
# additional_args={ )
# "complete_input_dict": request_data,
# "api_base": url,
# "headers": headers,
# },
# )
# if aembedding is True: if response.status_code != 200:
# pass raise Exception(f"Error: {response.status_code} {response.text}")
# response = sync_handler.post( _json_response = response.json()
# url=url, _predictions = VertexAITextEmbeddingsResponseObject(**_json_response) # type: ignore
# headers=headers,
# data=json.dumps(request_data),
# )
# if response.status_code != 200: model_response.data = [
# raise Exception(f"Error: {response.status_code} {response.text}") Embedding(
embedding=_predictions["embedding"]["values"],
index=0,
object="embedding",
)
]
# _json_response = response.json() model_response.model = model
# if "predictions" not in _json_response:
# raise litellm.InternalServerError(
# message=f"embedding response does not contain 'predictions', got {_json_response}",
# llm_provider="vertex_ai",
# model=model,
# )
# _predictions = _json_response["predictions"]
# model_response.data = _predictions input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
# model_response.model = model prompt_tokens = litellm.token_counter(model=model, text=input_text)
model_response.usage = litellm.Usage(
prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
)
# return model_response return model_response

View file

@ -3,3 +3,25 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works
""" """
from typing import List
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import ContentType, PartType
from ..common_utils import VertexAIError
def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
"""
The content to embed. Only the parts.text fields will be counted.
"""
if isinstance(input, str):
return ContentType(parts=[PartType(text=input)])
elif isinstance(input, list) and len(input) == 1:
return ContentType(parts=[PartType(text=input[0])])
else:
raise VertexAIError(
status_code=422,
message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
)

View file

@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic, vertex_ai_anthropic,
vertex_ai_non_gemini, vertex_ai_non_gemini,
) )
from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import (
GoogleEmbeddings,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM, VertexLLM,
) )
@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM() bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM() bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM() vertex_chat_completion = VertexLLM()
google_embeddings = GoogleEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels() vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI() vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI() watsonxai = IBMWatsonXAI()
@ -3533,7 +3537,7 @@ def embedding(
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
response = vertex_chat_completion.multimodal_embedding( # type: ignore response = google_embeddings.text_embeddings( # type: ignore
model=model, model=model,
input=input, input=input,
encoding=encoding, encoding=encoding,

View file

@ -697,7 +697,8 @@ async def test_gemini_embeddings():
print(f"response: {response}") print(f"response: {response}")
# stubbed endpoint is setup to return this # stubbed endpoint is setup to return this
assert response.data[0]["embedding"] == [0.1, 0.2] assert isinstance(response.data[0]["embedding"], list)
assert response.usage.prompt_tokens > 0
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -30,6 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.threads.message_content import MessageContent from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.run import Run from openai.types.beta.threads.run import Run
from openai.types.chat import ChatCompletionChunk from openai.types.chat import ChatCompletionChunk
from openai.types.embedding import Embedding as OpenAIEmbedding
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Dict, Required, TypedDict, override from typing_extensions import Dict, Required, TypedDict, override
@ -47,6 +48,9 @@ FileTypes = Union[
] ]
EmbeddingInput = Union[str, List[str]]
class NotGiven: class NotGiven:
""" """
A sentinel singleton class used to distinguish omitted keyword arguments A sentinel singleton class used to distinguish omitted keyword arguments

View file

@ -336,3 +336,29 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
class VertexAICachedContentResponseObject(TypedDict): class VertexAICachedContentResponseObject(TypedDict):
name: str name: str
model: str model: str
class TaskTypeEnum(Enum):
TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED"
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
CLASSIFICATION = "CLASSIFICATION"
CLUSTERING = "CLUSTERING"
QUESTION_ANSWERING = "QUESTION_ANSWERING"
FACT_VERIFICATION = "FACT_VERIFICATION"
class VertexAITextEmbeddingsRequestBody(TypedDict, total=False):
content: Required[ContentType]
taskType: TaskTypeEnum
title: str
outputDimensionality: int
class ContentEmbeddings(TypedDict):
values: List[int]
class VertexAITextEmbeddingsResponseObject(TypedDict):
embedding: ContentEmbeddings