feat(batch_embed_content_transformation.py): support google ai studio /batchEmbedContent endpoint

Allows for multiple strings to be given for embedding
This commit is contained in:
Krrish Dholakia 2024-08-27 19:23:50 -07:00
parent 4bb59b7b2c
commit 57330d2d0d
8 changed files with 303 additions and 39 deletions

View file

@ -41,7 +41,7 @@ def get_supports_system_message(
from typing import Literal, Optional from typing import Literal, Optional
all_gemini_url_modes = Literal["chat", "embedding"] all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
def _get_vertex_url( def _get_vertex_url(
@ -101,4 +101,10 @@ def _get_gemini_url(
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format( url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key _gemini_model_name, endpoint, gemini_api_key
) )
elif mode == "batch_embedding":
endpoint = "batchEmbedContents"
url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
_gemini_model_name, endpoint, gemini_api_key
)
return url, endpoint return url, endpoint

View file

@ -0,0 +1,167 @@
"""
Google AI Studio /batchEmbedContents Embeddings Endpoint
"""
import json
from typing import List, Literal, Optional, Union
import httpx
from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
VertexAIBatchEmbeddingsRequestBody,
VertexAIBatchEmbeddingsResponseObject,
)
from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .batch_embed_content_transformation import (
process_response,
transform_openai_input_gemini_content,
)
class GoogleBatchEmbeddings(VertexLLM):
def batch_embeddings(
self,
model: str,
input: List[str],
print_verbose,
model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=False,
timeout=300,
client=None,
) -> EmbeddingResponse:
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=api_key,
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_credentials=vertex_credentials,
stream=None,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=False,
mode="batch_embedding",
)
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
else:
sync_handler = client # type: ignore
optional_params = optional_params or {}
### TRANSFORMATION ###
request_data = transform_openai_input_gemini_content(
input=input, model=model, optional_params=optional_params
)
headers = {
"Content-Type": "application/json; charset=utf-8",
}
## LOGGING
logging_obj.pre_call(
input=input,
api_key="",
additional_args={
"complete_input_dict": request_data,
"api_base": url,
"headers": headers,
},
)
if aembedding is True:
return self.async_batch_embeddings( # type: ignore
model=model,
api_base=api_base,
url=url,
data=request_data,
model_response=model_response,
timeout=timeout,
headers=headers,
input=input,
)
response = sync_handler.post(
url=url,
headers=headers,
data=json.dumps(request_data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
_json_response = response.json()
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
return process_response(
model=model,
model_response=model_response,
_predictions=_predictions,
input=input,
)
async def async_batch_embeddings(
self,
model: str,
api_base: Optional[str],
url: str,
data: VertexAIBatchEmbeddingsRequestBody,
model_response: EmbeddingResponse,
input: EmbeddingInput,
timeout: Optional[Union[float, httpx.Timeout]],
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> EmbeddingResponse:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params) # type: ignore
else:
async_handler = client # type: ignore
response = await async_handler.post(
url=url,
headers=headers,
data=json.dumps(data),
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} {response.text}")
_json_response = response.json()
_predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
return process_response(
model=model,
model_response=model_response,
_predictions=_predictions,
input=input,
)

View file

@ -0,0 +1,68 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
Why separate file? Make it easy to see how transformation works
"""
from typing import List
from litellm import EmbeddingResponse
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
ContentType,
EmbedContentRequest,
PartType,
VertexAIBatchEmbeddingsRequestBody,
VertexAIBatchEmbeddingsResponseObject,
)
from litellm.types.utils import Embedding, Usage
from litellm.utils import get_formatted_prompt, token_counter
from ..common_utils import VertexAIError
def transform_openai_input_gemini_content(
input: List[str], model: str, optional_params: dict
) -> VertexAIBatchEmbeddingsRequestBody:
"""
The content to embed. Only the parts.text fields will be counted.
"""
gemini_model_name = "models/{}".format(model)
requests: List[EmbedContentRequest] = []
for i in input:
request = EmbedContentRequest(
model=gemini_model_name,
content=ContentType(parts=[PartType(text=i)]),
**optional_params
)
requests.append(request)
return VertexAIBatchEmbeddingsRequestBody(requests=requests)
def process_response(
input: EmbeddingInput,
model_response: EmbeddingResponse,
model: str,
_predictions: VertexAIBatchEmbeddingsResponseObject,
) -> EmbeddingResponse:
openai_embeddings: List[Embedding] = []
for embedding in _predictions["embeddings"]:
openai_embedding = Embedding(
embedding=embedding["values"],
index=0,
object="embedding",
)
openai_embeddings.append(openai_embedding)
model_response.data = openai_embeddings
model_response.model = model
input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
prompt_tokens = token_counter(model=model, text=input_text)
model_response.usage = Usage(
prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
)
return model_response

View file

@ -1,5 +1,5 @@
""" """
Google AI Studio Embeddings Endpoint Google AI Studio /embedContent Embeddings Endpoint
""" """
import json import json
@ -7,7 +7,6 @@ from typing import Literal, Optional, Union
import httpx import httpx
import litellm
from litellm import EmbeddingResponse from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import EmbeddingInput from litellm.types.llms.openai import EmbeddingInput
@ -15,21 +14,19 @@ from litellm.types.llms.vertex_ai import (
VertexAITextEmbeddingsRequestBody, VertexAITextEmbeddingsRequestBody,
VertexAITextEmbeddingsResponseObject, VertexAITextEmbeddingsResponseObject,
) )
from litellm.types.utils import Embedding
from litellm.utils import get_formatted_prompt
from .embeddings_transformation import ( from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .embed_content_transformation import (
process_response, process_response,
transform_openai_input_gemini_content, transform_openai_input_gemini_content,
) )
from .vertex_and_google_ai_studio_gemini import VertexLLM
class GoogleEmbeddings(VertexLLM): class GoogleEmbeddings(VertexLLM):
def text_embeddings( def text_embeddings(
self, self,
model: str, model: str,
input: Union[list, str], input: str,
print_verbose, print_verbose,
model_response: EmbeddingResponse, model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"], custom_llm_provider: Literal["gemini", "vertex_ai"],

View file

@ -4,8 +4,6 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe
Why separate file? Make it easy to see how transformation works Why separate file? Make it easy to see how transformation works
""" """
from typing import List
from litellm import EmbeddingResponse from litellm import EmbeddingResponse
from litellm.types.llms.openai import EmbeddingInput from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import ( from litellm.types.llms.vertex_ai import (
@ -19,19 +17,11 @@ from litellm.utils import get_formatted_prompt, token_counter
from ..common_utils import VertexAIError from ..common_utils import VertexAIError
def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType: def transform_openai_input_gemini_content(input: str) -> ContentType:
""" """
The content to embed. Only the parts.text fields will be counted. The content to embed. Only the parts.text fields will be counted.
""" """
if isinstance(input, str):
return ContentType(parts=[PartType(text=input)]) return ContentType(parts=[PartType(text=input)])
elif isinstance(input, list) and len(input) == 1:
return ContentType(parts=[PartType(text=input[0])])
else:
raise VertexAIError(
status_code=422,
message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
)
def process_response( def process_response(

View file

@ -126,7 +126,10 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic, vertex_ai_anthropic,
vertex_ai_non_gemini, vertex_ai_non_gemini,
) )
from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import ( from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
GoogleBatchEmbeddings,
)
from .llms.vertex_ai_and_google_ai_studio.embeddings.embed_content_handler import (
GoogleEmbeddings, GoogleEmbeddings,
) )
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
@ -176,6 +179,7 @@ bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM() bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM() vertex_chat_completion = VertexLLM()
google_embeddings = GoogleEmbeddings() google_embeddings = GoogleEmbeddings()
google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels() vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI() vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI() watsonxai = IBMWatsonXAI()
@ -3537,6 +3541,7 @@ def embedding(
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
if isinstance(input, str):
response = google_embeddings.text_embeddings( # type: ignore response = google_embeddings.text_embeddings( # type: ignore
model=model, model=model,
input=input, input=input,
@ -3552,6 +3557,22 @@ def embedding(
custom_llm_provider="gemini", custom_llm_provider="gemini",
api_key=gemini_api_key, api_key=gemini_api_key,
) )
else:
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
)
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
vertex_ai_project = ( vertex_ai_project = (

View file

@ -687,19 +687,22 @@ async def test_triton_embeddings():
@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
"input", ["good morning from litellm", ["good morning from litellm"]] #
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gemini_embeddings(sync_mode): async def test_gemini_embeddings(sync_mode, input):
try: try:
litellm.set_verbose = True litellm.set_verbose = True
if sync_mode: if sync_mode:
response = litellm.embedding( response = litellm.embedding(
model="gemini/text-embedding-004", model="gemini/text-embedding-004",
input=["good morning from litellm"], input=input,
) )
else: else:
response = await litellm.aembedding( response = await litellm.aembedding(
model="gemini/text-embedding-004", model="gemini/text-embedding-004",
input=["good morning from litellm"], input=input,
) )
print(f"response: {response}") print(f"response: {response}")

View file

@ -362,3 +362,15 @@ class ContentEmbeddings(TypedDict):
class VertexAITextEmbeddingsResponseObject(TypedDict): class VertexAITextEmbeddingsResponseObject(TypedDict):
embedding: ContentEmbeddings embedding: ContentEmbeddings
class EmbedContentRequest(VertexAITextEmbeddingsRequestBody):
model: Required[str]
class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False):
requests: List[EmbedContentRequest]
class VertexAIBatchEmbeddingsResponseObject(TypedDict):
embeddings: List[ContentEmbeddings]