forked from phoenix/litellm-mirror
fix(embeddings_handler.py): initial working commit for google ai studio text embeddings /embedContent endpoint
parent 77e6da78a1
commit 5b29ddd2a6
6 changed files with 111 additions and 48 deletions
@@ -10,7 +10,14 @@ import httpx
 import litellm
 from litellm import EmbeddingResponse
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.types.llms.vertex_ai import (
+    VertexAITextEmbeddingsRequestBody,
+    VertexAITextEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding
+from litellm.utils import get_formatted_prompt

+from .embeddings_transformation import transform_openai_input_gemini_content
 from .vertex_and_google_ai_studio_gemini import VertexLLM

@@ -34,7 +41,7 @@ class GoogleEmbeddings(VertexLLM):
         timeout=300,
         client=None,
     ) -> EmbeddingResponse:
-        return model_response
+
         auth_header, url = self._get_token_and_url(
             model=model,
             gemini_api_key=api_key,
@@ -63,59 +70,58 @@ class GoogleEmbeddings(VertexLLM):

         optional_params = optional_params or {}

-        # request_data = VertexMultimodalEmbeddingRequest()
-        # if "instances" in optional_params:
-        #     request_data["instances"] = optional_params["instances"]
-        # elif isinstance(input, list):
-        #     request_data["instances"] = input
-        # else:
-        #     # construct instances
-        #     vertex_request_instance = Instance(**optional_params)
-
-        #     if isinstance(input, str):
-        #         vertex_request_instance["text"] = input
-
-        #     request_data["instances"] = [vertex_request_instance]
-
-        # headers = {
-        #     "Content-Type": "application/json; charset=utf-8",
-        #     "Authorization": f"Bearer {auth_header}",
-        # }
-
-        # ## LOGGING
-        # logging_obj.pre_call(
-        #     input=input,
-        #     api_key="",
-        #     additional_args={
-        #         "complete_input_dict": request_data,
-        #         "api_base": url,
-        #         "headers": headers,
-        #     },
-        # )
-
-        # if aembedding is True:
-        #     pass
-
-        # response = sync_handler.post(
-        #     url=url,
-        #     headers=headers,
-        #     data=json.dumps(request_data),
-        # )
-
-        # if response.status_code != 200:
-        #     raise Exception(f"Error: {response.status_code} {response.text}")
-
-        # _json_response = response.json()
-        # if "predictions" not in _json_response:
-        #     raise litellm.InternalServerError(
-        #         message=f"embedding response does not contain 'predictions', got {_json_response}",
-        #         llm_provider="vertex_ai",
-        #         model=model,
-        #     )
-        # _predictions = _json_response["predictions"]
-
-        # model_response.data = _predictions
-        # model_response.model = model
-
-        # return model_response
+        ### TRANSFORMATION ###
+        content = transform_openai_input_gemini_content(input=input)
+
+        request_data: VertexAITextEmbeddingsRequestBody = {
+            "content": content,
+            **optional_params,
+        }
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            pass
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        model_response.data = [
+            Embedding(
+                embedding=_predictions["embedding"]["values"],
+                index=0,
+                object="embedding",
+            )
+        ]
+
+        model_response.model = model
+
+        input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+        prompt_tokens = litellm.token_counter(model=model, text=input_text)
+        model_response.usage = litellm.Usage(
+            prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+        )
+
+        return model_response
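For orientation, a minimal standalone sketch of the request/response shape the handler above builds and parses. The endpoint URL, model name, and API-key query parameter are assumptions about the Google AI Studio /embedContent API, not taken from this diff; only the body and response shapes mirror the handler.

import os

import httpx

api_key = os.environ["GEMINI_API_KEY"]  # hypothetical: key read from the environment
model = "text-embedding-004"  # hypothetical model name
url = (
    "https://generativelanguage.googleapis.com/v1beta/"
    f"models/{model}:embedContent?key={api_key}"
)

# Same shape as request_data in the handler: {"content": {"parts": [{"text": ...}]}}
request_data = {"content": {"parts": [{"text": "hello world"}]}}

response = httpx.post(
    url,
    headers={"Content-Type": "application/json; charset=utf-8"},
    json=request_data,
)
response.raise_for_status()

# The handler reads the vector from {"embedding": {"values": [...]}}
values = response.json()["embedding"]["values"]
print(len(values))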
@@ -3,3 +3,25 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embedContent format

 Why separate file? Make it easy to see how transformation works
 """
+
+from typing import List
+
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import ContentType, PartType
+
+from ..common_utils import VertexAIError
+
+
+def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
+    """
+    The content to embed. Only the parts.text fields will be counted.
+    """
+    if isinstance(input, str):
+        return ContentType(parts=[PartType(text=input)])
+    elif isinstance(input, list) and len(input) == 1:
+        return ContentType(parts=[PartType(text=input[0])])
+    else:
+        raise VertexAIError(
+            status_code=422,
+            message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
+        )
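A short usage sketch of the transformation helper added above (the module path is assumed from the import statements elsewhere in this commit): a string or a single-element list maps to one Gemini content object; longer lists raise, since /embedContent returns a single vector.

from litellm.llms.vertex_ai_and_google_ai_studio.gemini.embeddings_transformation import (
    transform_openai_input_gemini_content,
)

print(transform_openai_input_gemini_content(input="hello world"))
# {'parts': [{'text': 'hello world'}]}

print(transform_openai_input_gemini_content(input=["hello world"]))
# same shape as above

# transform_openai_input_gemini_content(input=["a", "b"])  # raises VertexAIError (status 422)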
@@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio import (
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
+from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import (
+    GoogleEmbeddings,
+)
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+google_embeddings = GoogleEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
 watsonxai = IBMWatsonXAI()
@@ -3533,7 +3537,7 @@ def embedding(

         gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key

-        response = vertex_chat_completion.multimodal_embedding(  # type: ignore
+        response = google_embeddings.text_embeddings(  # type: ignore
             model=model,
             input=input,
             encoding=encoding,
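A hedged end-to-end sketch of what the re-routed branch above enables: an OpenAI-format embedding call against a Google AI Studio model through litellm.embedding(). The model id is a hypothetical example; the key is read via get_secret("GEMINI_API_KEY") as shown in the hunk.

import os

import litellm

os.environ["GEMINI_API_KEY"] = "your-api-key"  # consumed by get_secret("GEMINI_API_KEY")

response = litellm.embedding(
    model="gemini/text-embedding-004",  # hypothetical model id; the gemini/ prefix selects Google AI Studio
    input="hello world",
)
print(response.data[0]["embedding"][:5])
print(response.usage.prompt_tokens)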
@@ -697,7 +697,8 @@ async def test_gemini_embeddings():
         print(f"response: {response}")

         # stubbed endpoint is setup to return this
-        assert response.data[0]["embedding"] == [0.1, 0.2]
+        assert isinstance(response.data[0]["embedding"], list)
+        assert response.usage.prompt_tokens > 0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -30,6 +30,7 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
 from openai.types.beta.threads.message_content import MessageContent
 from openai.types.beta.threads.run import Run
 from openai.types.chat import ChatCompletionChunk
+from openai.types.embedding import Embedding as OpenAIEmbedding
 from pydantic import BaseModel, Field
 from typing_extensions import Dict, Required, TypedDict, override

@@ -47,6 +48,9 @@ FileTypes = Union[
 ]


+EmbeddingInput = Union[str, List[str]]
+
+
 class NotGiven:
     """
     A sentinel singleton class used to distinguish omitted keyword arguments
@@ -336,3 +336,29 @@ class VertexMultimodalEmbeddingRequest(TypedDict, total=False):
 class VertexAICachedContentResponseObject(TypedDict):
     name: str
     model: str
+
+
+class TaskTypeEnum(Enum):
+    TASK_TYPE_UNSPECIFIED = "TASK_TYPE_UNSPECIFIED"
+    RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
+    RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
+    SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
+    CLASSIFICATION = "CLASSIFICATION"
+    CLUSTERING = "CLUSTERING"
+    QUESTION_ANSWERING = "QUESTION_ANSWERING"
+    FACT_VERIFICATION = "FACT_VERIFICATION"
+
+
+class VertexAITextEmbeddingsRequestBody(TypedDict, total=False):
+    content: Required[ContentType]
+    taskType: TaskTypeEnum
+    title: str
+    outputDimensionality: int
+
+
+class ContentEmbeddings(TypedDict):
+    values: List[int]
+
+
+class VertexAITextEmbeddingsResponseObject(TypedDict):
+    embedding: ContentEmbeddings
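A small sketch of constructing the new request-body type defined above. Only content is Required; taskType, title, and outputDimensionality are optional (total=False). The values shown are illustrative only.

from litellm.types.llms.vertex_ai import (
    ContentType,
    PartType,
    TaskTypeEnum,
    VertexAITextEmbeddingsRequestBody,
)

body: VertexAITextEmbeddingsRequestBody = {
    "content": ContentType(parts=[PartType(text="hello world")]),
    "taskType": TaskTypeEnum.RETRIEVAL_QUERY,  # optional hint to the embedding model
    "outputDimensionality": 256,  # optional; truncates the returned vector
}
print(body["content"]["parts"][0]["text"])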