fix(embeddings_handler.py): initial working commit for google ai studio text embeddings /embedContent endpoint

This commit is contained in:
Krrish Dholakia 2024-08-27 18:14:56 -07:00
parent 77e6da78a1
commit 5b29ddd2a6
6 changed files with 111 additions and 48 deletions

View file

@ -10,7 +10,14 @@ import httpx
import litellm
from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.types.llms.vertex_ai import (
VertexAITextEmbeddingsRequestBody,
VertexAITextEmbeddingsResponseObject,
)
from litellm.types.utils import Embedding
from litellm.utils import get_formatted_prompt
from .embeddings_transformation import transform_openai_input_gemini_content
from .vertex_and_google_ai_studio_gemini import VertexLLM
@ -34,7 +41,7 @@ class GoogleEmbeddings(VertexLLM):
timeout=300,
client=None,
) -> EmbeddingResponse:
return model_response
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=api_key,
@ -63,59 +70,58 @@ class GoogleEmbeddings(VertexLLM):
optional_params = optional_params or {}
# request_data = VertexMultimodalEmbeddingRequest()
### TRANSFORMATION ###
content = transform_openai_input_gemini_content(input=input)
# if "instances" in optional_params:
# request_data["instances"] = optional_params["instances"]
# elif isinstance(input, list):
# request_data["instances"] = input
# else:
# # construct instances
# vertex_request_instance = Instance(**optional_params)
request_data: VertexAITextEmbeddingsRequestBody = {
"content": content,
**optional_params,
}
# if isinstance(input, str):
# vertex_request_instance["text"] = input
headers = {
"Content-Type": "application/json; charset=utf-8",
}
# request_data["instances"] = [vertex_request_instance]
## LOGGING
logging_obj.pre_call(
input=input,
api_key="",
additional_args={
"complete_input_dict": request_data,
"api_base": url,
"headers": headers,
},
)
# headers = {
# "Content-Type": "application/json; charset=utf-8",
# "Authorization": f"Bearer {auth_header}",
# }
if aembedding is True:
pass
# ## LOGGING
# logging_obj.pre_call(
# input=input,
# api_key="",
# additional_args={
# "complete_input_dict": request_data,
# "api_base": url,
# "headers": headers,
# },
# )
response = sync_handler.post(
url=url,
headers=headers,
data=json.dumps(request_data),
)
# if aembedding is True:
# pass
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
# response = sync_handler.post(
# url=url,
# headers=headers,
# data=json.dumps(request_data),
# )
_json_response = response.json()
_predictions = VertexAITextEmbeddingsResponseObject(**_json_response) # type: ignore
# if response.status_code != 200:
# raise Exception(f"Error: {response.status_code} {response.text}")
model_response.data = [
Embedding(
embedding=_predictions["embedding"]["values"],
index=0,
object="embedding",
)
]
# _json_response = response.json()
# if "predictions" not in _json_response:
# raise litellm.InternalServerError(
# message=f"embedding response does not contain 'predictions', got {_json_response}",
# llm_provider="vertex_ai",
# model=model,
# )
# _predictions = _json_response["predictions"]
model_response.model = model
# model_response.data = _predictions
# model_response.model = model
input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
prompt_tokens = litellm.token_counter(model=model, text=input_text)
model_response.usage = litellm.Usage(
prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
)
# return model_response
return model_response

View file

@ -3,3 +3,25 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe
Why separate file? Make it easy to see how transformation works
"""
from typing import List
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import ContentType, PartType
from ..common_utils import VertexAIError
def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
"""
The content to embed. Only the parts.text fields will be counted.
"""
if isinstance(input, str):
return ContentType(parts=[PartType(text=input)])
elif isinstance(input, list) and len(input) == 1:
return ContentType(parts=[PartType(text=input[0])])
else:
raise VertexAIError(
status_code=422,
message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
)

View file

@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import (
GoogleEmbeddings,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
google_embeddings = GoogleEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
@ -3533,7 +3537,7 @@ def embedding(
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
response = vertex_chat_completion.multimodal_embedding( # type: ignore
response = google_embeddings.text_embeddings( # type: ignore
model=model,
input=input,
encoding=encoding,

View file

@ -697,7 +697,8 @@ async def test_gemini_embeddings():
print(f"response: {response}")
# stubbed endpoint is setup to return this
assert response.data[0]["embedding"] == [0.1, 0.2]
assert isinstance(response.data[0]["embedding"], list)
assert response.usage.prompt_tokens > 0
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -30,6 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.run import Run
from openai.types.chat import ChatCompletionChunk
from openai.types.embedding import Embedding as OpenAIEmbedding
from pydantic import BaseModel, Field
from typing_extensions import Dict, Required, TypedDict, override
@ -47,6 +48,9 @@ FileTypes = Union[
]
EmbeddingInput = Union[str, List[str]]
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments

View file

@ -336,3 +336,29 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
class VertexAICachedContentResponseObject(TypedDict):
name: str
model: str
class TaskTypeEnum(Enum):
TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED"
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
CLASSIFICATION = "CLASSIFICATION"
CLUSTERING = "CLUSTERING"
QUESTION_ANSWERING = "QUESTION_ANSWERING"
FACT_VERIFICATION = "FACT_VERIFICATION"
class VertexAITextEmbeddingsRequestBody(TypedDict, total=False):
content: Required[ContentType]
taskType: TaskTypeEnum
title: str
outputDimensionality: int
class ContentEmbeddings(TypedDict):
values: List[int]
class VertexAITextEmbeddingsResponseObject(TypedDict):
embedding: ContentEmbeddings