forked from phoenix/litellm-mirror
fix(embeddings_handler.py): initial working commit for google ai studio text embeddings /embedContent endpoint
This commit is contained in:
parent
77e6da78a1
commit
5b29ddd2a6
6 changed files with 111 additions and 48 deletions
|
@ -10,7 +10,14 @@ import httpx
|
|||
import litellm
|
||||
from litellm import EmbeddingResponse
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
VertexAITextEmbeddingsRequestBody,
|
||||
VertexAITextEmbeddingsResponseObject,
|
||||
)
|
||||
from litellm.types.utils import Embedding
|
||||
from litellm.utils import get_formatted_prompt
|
||||
|
||||
from .embeddings_transformation import transform_openai_input_gemini_content
|
||||
from .vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
|
||||
|
||||
|
@ -34,7 +41,7 @@ class GoogleEmbeddings(VertexLLM):
|
|||
timeout=300,
|
||||
client=None,
|
||||
) -> EmbeddingResponse:
|
||||
return model_response
|
||||
|
||||
auth_header, url = self._get_token_and_url(
|
||||
model=model,
|
||||
gemini_api_key=api_key,
|
||||
|
@ -63,59 +70,58 @@ class GoogleEmbeddings(VertexLLM):
|
|||
|
||||
optional_params = optional_params or {}
|
||||
|
||||
# request_data = VertexMultimodalEmbeddingRequest()
|
||||
### TRANSFORMATION ###
|
||||
content = transform_openai_input_gemini_content(input=input)
|
||||
|
||||
# if "instances" in optional_params:
|
||||
# request_data["instances"] = optional_params["instances"]
|
||||
# elif isinstance(input, list):
|
||||
# request_data["instances"] = input
|
||||
# else:
|
||||
# # construct instances
|
||||
# vertex_request_instance = Instance(**optional_params)
|
||||
request_data: VertexAITextEmbeddingsRequestBody = {
|
||||
"content": content,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
# if isinstance(input, str):
|
||||
# vertex_request_instance["text"] = input
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
}
|
||||
|
||||
# request_data["instances"] = [vertex_request_instance]
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": request_data,
|
||||
"api_base": url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# headers = {
|
||||
# "Content-Type": "application/json; charset=utf-8",
|
||||
# "Authorization": f"Bearer {auth_header}",
|
||||
# }
|
||||
if aembedding is True:
|
||||
pass
|
||||
|
||||
# ## LOGGING
|
||||
# logging_obj.pre_call(
|
||||
# input=input,
|
||||
# api_key="",
|
||||
# additional_args={
|
||||
# "complete_input_dict": request_data,
|
||||
# "api_base": url,
|
||||
# "headers": headers,
|
||||
# },
|
||||
# )
|
||||
response = sync_handler.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
# if aembedding is True:
|
||||
# pass
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
# response = sync_handler.post(
|
||||
# url=url,
|
||||
# headers=headers,
|
||||
# data=json.dumps(request_data),
|
||||
# )
|
||||
_json_response = response.json()
|
||||
_predictions = VertexAITextEmbeddingsResponseObject(**_json_response) # type: ignore
|
||||
|
||||
# if response.status_code != 200:
|
||||
# raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
model_response.data = [
|
||||
Embedding(
|
||||
embedding=_predictions["embedding"]["values"],
|
||||
index=0,
|
||||
object="embedding",
|
||||
)
|
||||
]
|
||||
|
||||
# _json_response = response.json()
|
||||
# if "predictions" not in _json_response:
|
||||
# raise litellm.InternalServerError(
|
||||
# message=f"embedding response does not contain 'predictions', got {_json_response}",
|
||||
# llm_provider="vertex_ai",
|
||||
# model=model,
|
||||
# )
|
||||
# _predictions = _json_response["predictions"]
|
||||
model_response.model = model
|
||||
|
||||
# model_response.data = _predictions
|
||||
# model_response.model = model
|
||||
input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
|
||||
prompt_tokens = litellm.token_counter(model=model, text=input_text)
|
||||
model_response.usage = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
|
||||
)
|
||||
|
||||
# return model_response
|
||||
return model_response
|
||||
|
|
|
@ -3,3 +3,25 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe
|
|||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from litellm.types.llms.openai import EmbeddingInput
|
||||
from litellm.types.llms.vertex_ai import ContentType, PartType
|
||||
|
||||
from ..common_utils import VertexAIError
|
||||
|
||||
|
||||
def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
|
||||
"""
|
||||
The content to embed. Only the parts.text fields will be counted.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
return ContentType(parts=[PartType(text=input)])
|
||||
elif isinstance(input, list) and len(input) == 1:
|
||||
return ContentType(parts=[PartType(text=input[0])])
|
||||
else:
|
||||
raise VertexAIError(
|
||||
status_code=422,
|
||||
message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
|
||||
)
|
||||
|
|
|
@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio import (
|
|||
vertex_ai_anthropic,
|
||||
vertex_ai_non_gemini,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import (
|
||||
GoogleEmbeddings,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
|
@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion()
|
|||
bedrock_chat_completion = BedrockLLM()
|
||||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
google_embeddings = GoogleEmbeddings()
|
||||
vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
||||
vertex_text_to_speech = VertexTextToSpeechAPI()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
|
@ -3533,7 +3537,7 @@ def embedding(
|
|||
|
||||
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
|
||||
|
||||
response = vertex_chat_completion.multimodal_embedding( # type: ignore
|
||||
response = google_embeddings.text_embeddings( # type: ignore
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
|
|
|
@ -697,7 +697,8 @@ async def test_gemini_embeddings():
|
|||
print(f"response: {response}")
|
||||
|
||||
# stubbed endpoint is setup to return this
|
||||
assert response.data[0]["embedding"] == [0.1, 0.2]
|
||||
assert isinstance(response.data[0]["embedding"], list)
|
||||
assert response.usage.prompt_tokens > 0
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
|
|||
from openai.types.beta.threads.message_content import MessageContent
|
||||
from openai.types.beta.threads.run import Run
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
from openai.types.embedding import Embedding as OpenAIEmbedding
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Dict, Required, TypedDict, override
|
||||
|
||||
|
@ -47,6 +48,9 @@ FileTypes = Union[
|
|||
]
|
||||
|
||||
|
||||
EmbeddingInput = Union[str, List[str]]
|
||||
|
||||
|
||||
class NotGiven:
|
||||
"""
|
||||
A sentinel singleton class used to distinguish omitted keyword arguments
|
||||
|
|
|
@ -336,3 +336,29 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
|
|||
class VertexAICachedContentResponseObject(TypedDict):
|
||||
name: str
|
||||
model: str
|
||||
|
||||
|
||||
class TaskTypeEnum(Enum):
|
||||
TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED"
|
||||
RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
|
||||
RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
|
||||
SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
|
||||
CLASSIFICATION = "CLASSIFICATION"
|
||||
CLUSTERING = "CLUSTERING"
|
||||
QUESTION_ANSWERING = "QUESTION_ANSWERING"
|
||||
FACT_VERIFICATION = "FACT_VERIFICATION"
|
||||
|
||||
|
||||
class VertexAITextEmbeddingsRequestBody(TypedDict, total=False):
|
||||
content: Required[ContentType]
|
||||
taskType: TaskTypeEnum
|
||||
title: str
|
||||
outputDimensionality: int
|
||||
|
||||
|
||||
class ContentEmbeddings(TypedDict):
|
||||
values: List[int]
|
||||
|
||||
|
||||
class VertexAITextEmbeddingsResponseObject(TypedDict):
|
||||
embedding: ContentEmbeddings
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue