mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

feat(batch_embed_content_transformation.py): support google ai studio /batchEmbedContent endpoint

Allows for multiple strings to be given for embedding

This commit is contained in: parent 4bb59b7b2c, commit 57330d2d0d
8 changed files with 303 additions and 39 deletions

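At the call site, the new behavior looks roughly like this (a usage sketch based on the tests in this commit; the second string is illustrative):

    import litellm

    # a plain string still goes through the existing /embedContent route
    response = litellm.embedding(
        model="gemini/text-embedding-004",
        input="good morning from litellm",
    )

    # a list of strings is now routed to the new /batchEmbedContents route
    batch_response = litellm.embedding(
        model="gemini/text-embedding-004",
        input=["good morning from litellm", "good night from litellm"],
    )
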
@@ -41,7 +41,7 @@ def get_supports_system_message(
 from typing import Literal, Optional

-all_gemini_url_modes = Literal["chat", "embedding"]
+all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]


 def _get_vertex_url(

@@ -101,4 +101,10 @@ def _get_gemini_url(
         url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
             _gemini_model_name, endpoint, gemini_api_key
         )
+    elif mode == "batch_embedding":
+        endpoint = "batchEmbedContents"
+        url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+            _gemini_model_name, endpoint, gemini_api_key
+        )

     return url, endpoint

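For illustration, the URL the new batch_embedding branch produces, assuming a hypothetical model name and placeholder key:

    endpoint = "batchEmbedContents"
    url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
        "models/text-embedding-004", endpoint, "GEMINI_API_KEY"
    )
    # https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:batchEmbedContents?key=GEMINI_API_KEY
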
@@ -0,0 +1,167 @@
+"""
+Google AI Studio /batchEmbedContents Embeddings Endpoint
+"""
+
+import json
+from typing import List, Literal, Optional, Union
+
+import httpx
+
+from litellm import EmbeddingResponse
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+    VertexAIBatchEmbeddingsRequestBody,
+    VertexAIBatchEmbeddingsResponseObject,
+)
+
+from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from .batch_embed_content_transformation import (
+    process_response,
+    transform_openai_input_gemini_content,
+)
+
+
+class GoogleBatchEmbeddings(VertexLLM):
+    def batch_embeddings(
+        self,
+        model: str,
+        input: List[str],
+        print_verbose,
+        model_response: EmbeddingResponse,
+        custom_llm_provider: Literal["gemini", "vertex_ai"],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        logging_obj=None,
+        encoding=None,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        aembedding=False,
+        timeout=300,
+        client=None,
+    ) -> EmbeddingResponse:
+        auth_header, url = self._get_token_and_url(
+            model=model,
+            gemini_api_key=api_key,
+            vertex_project=vertex_project,
+            vertex_location=vertex_location,
+            vertex_credentials=vertex_credentials,
+            stream=None,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            should_use_v1beta1_features=False,
+            mode="batch_embedding",
+        )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        optional_params = optional_params or {}
+
+        ### TRANSFORMATION ###
+        request_data = transform_openai_input_gemini_content(
+            input=input, model=model, optional_params=optional_params
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+        }
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_data,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.async_batch_embeddings(  # type: ignore
+                model=model,
+                api_base=api_base,
+                url=url,
+                data=request_data,
+                model_response=model_response,
+                timeout=timeout,
+                headers=headers,
+                input=input,
+            )
+
+        response = sync_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
+
+    async def async_batch_embeddings(
+        self,
+        model: str,
+        api_base: Optional[str],
+        url: str,
+        data: VertexAIBatchEmbeddingsRequestBody,
+        model_response: EmbeddingResponse,
+        input: EmbeddingInput,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        response = await async_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )

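Stripped of the handler plumbing, the request this class ends up sending is roughly the following (a sketch; URL, key, and texts are placeholders, and the expected response shape is taken from process_response below):

    import json
    import httpx

    url = (
        "https://generativelanguage.googleapis.com/v1beta/"
        "models/text-embedding-004:batchEmbedContents?key=GEMINI_API_KEY"
    )
    body = {
        "requests": [
            {"model": "models/text-embedding-004", "content": {"parts": [{"text": "first string"}]}},
            {"model": "models/text-embedding-004", "content": {"parts": [{"text": "second string"}]}},
        ]
    }
    response = httpx.post(
        url,
        headers={"Content-Type": "application/json; charset=utf-8"},
        content=json.dumps(body),
    )
    # a 200 response carries {"embeddings": [{"values": [...]}, {"values": [...]}]}
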
@@ -0,0 +1,68 @@
+"""
+Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List
+
+from litellm import EmbeddingResponse
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    EmbedContentRequest,
+    PartType,
+    VertexAIBatchEmbeddingsRequestBody,
+    VertexAIBatchEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding, Usage
+from litellm.utils import get_formatted_prompt, token_counter
+
+from ..common_utils import VertexAIError
+
+
+def transform_openai_input_gemini_content(
+    input: List[str], model: str, optional_params: dict
+) -> VertexAIBatchEmbeddingsRequestBody:
+    """
+    The content to embed. Only the parts.text fields will be counted.
+    """
+    gemini_model_name = "models/{}".format(model)
+    requests: List[EmbedContentRequest] = []
+    for i in input:
+        request = EmbedContentRequest(
+            model=gemini_model_name,
+            content=ContentType(parts=[PartType(text=i)]),
+            **optional_params
+        )
+        requests.append(request)
+
+    return VertexAIBatchEmbeddingsRequestBody(requests=requests)
+
+
+def process_response(
+    input: EmbeddingInput,
+    model_response: EmbeddingResponse,
+    model: str,
+    _predictions: VertexAIBatchEmbeddingsResponseObject,
+) -> EmbeddingResponse:
+    openai_embeddings: List[Embedding] = []
+    for embedding in _predictions["embeddings"]:
+        openai_embedding = Embedding(
+            embedding=embedding["values"],
+            index=0,
+            object="embedding",
+        )
+        openai_embeddings.append(openai_embedding)
+
+    model_response.data = openai_embeddings
+    model_response.model = model
+
+    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+    prompt_tokens = token_counter(model=model, text=input_text)
+    model_response.usage = Usage(
+        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+    )
+
+    return model_response

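To make the mapping concrete, a sketch of what process_response does with a two-item batch (vector values abbreviated):

    # Gemini /batchEmbedContents response, as parsed above:
    _predictions = {"embeddings": [{"values": [0.1, 0.2]}, {"values": [0.3, 0.4]}]}

    # model_response.data then becomes the OpenAI-style list:
    # [
    #     Embedding(object="embedding", index=0, embedding=[0.1, 0.2]),
    #     Embedding(object="embedding", index=0, embedding=[0.3, 0.4]),
    # ]
    # note that index is set to 0 for every entry, and usage counts only
    # prompt tokens (total_tokens == prompt_tokens), matching the code above.
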
@@ -1,5 +1,5 @@
 """
-Google AI Studio Embeddings Endpoint
+Google AI Studio /embedContent Embeddings Endpoint
 """

 import json

@@ -7,7 +7,6 @@ from typing import Literal, Optional, Union

 import httpx

-import litellm
 from litellm import EmbeddingResponse
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.openai import EmbeddingInput

@@ -15,21 +14,19 @@ from litellm.types.llms.vertex_ai import (
     VertexAITextEmbeddingsRequestBody,
     VertexAITextEmbeddingsResponseObject,
 )
-from litellm.types.utils import Embedding
-from litellm.utils import get_formatted_prompt

-from .embeddings_transformation import (
+from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from .embed_content_transformation import (
     process_response,
     transform_openai_input_gemini_content,
 )
-from .vertex_and_google_ai_studio_gemini import VertexLLM


 class GoogleEmbeddings(VertexLLM):
     def text_embeddings(
         self,
         model: str,
-        input: Union[list, str],
+        input: str,
         print_verbose,
         model_response: EmbeddingResponse,
         custom_llm_provider: Literal["gemini", "vertex_ai"],

@@ -4,8 +4,6 @@ Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embe
 Why separate file? Make it easy to see how transformation works
 """

-from typing import List
-
 from litellm import EmbeddingResponse
 from litellm.types.llms.openai import EmbeddingInput
 from litellm.types.llms.vertex_ai import (

@@ -19,19 +17,11 @@ from litellm.utils import get_formatted_prompt, token_counter
 from ..common_utils import VertexAIError


-def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
+def transform_openai_input_gemini_content(input: str) -> ContentType:
     """
     The content to embed. Only the parts.text fields will be counted.
     """
-    if isinstance(input, str):
-        return ContentType(parts=[PartType(text=input)])
-    elif isinstance(input, list) and len(input) == 1:
-        return ContentType(parts=[PartType(text=input[0])])
-    else:
-        raise VertexAIError(
-            status_code=422,
-            message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
-        )
+    return ContentType(parts=[PartType(text=input)])


 def process_response(

@@ -126,7 +126,10 @@ from .llms.vertex_ai_and_google_ai_studio import (
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
-from .llms.vertex_ai_and_google_ai_studio.gemini.embeddings_handler import (
+from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
+    GoogleBatchEmbeddings,
+)
+from .llms.vertex_ai_and_google_ai_studio.embeddings.embed_content_handler import (
     GoogleEmbeddings,
 )
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (

@@ -176,6 +179,7 @@ bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 google_embeddings = GoogleEmbeddings()
+google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
 watsonxai = IBMWatsonXAI()

@@ -3537,21 +3541,38 @@ def embedding(

         gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key

-        response = google_embeddings.text_embeddings(  # type: ignore
-            model=model,
-            input=input,
-            encoding=encoding,
-            logging_obj=logging,
-            optional_params=optional_params,
-            model_response=EmbeddingResponse(),
-            vertex_project=None,
-            vertex_location=None,
-            vertex_credentials=None,
-            aembedding=aembedding,
-            print_verbose=print_verbose,
-            custom_llm_provider="gemini",
-            api_key=gemini_api_key,
-        )
+        if isinstance(input, str):
+            response = google_embeddings.text_embeddings(  # type: ignore
+                model=model,
+                input=input,
+                encoding=encoding,
+                logging_obj=logging,
+                optional_params=optional_params,
+                model_response=EmbeddingResponse(),
+                vertex_project=None,
+                vertex_location=None,
+                vertex_credentials=None,
+                aembedding=aembedding,
+                print_verbose=print_verbose,
+                custom_llm_provider="gemini",
+                api_key=gemini_api_key,
+            )
+        else:
+            response = google_batch_embeddings.batch_embeddings(  # type: ignore
+                model=model,
+                input=input,
+                encoding=encoding,
+                logging_obj=logging,
+                optional_params=optional_params,
+                model_response=EmbeddingResponse(),
+                vertex_project=None,
+                vertex_location=None,
+                vertex_credentials=None,
+                aembedding=aembedding,
+                print_verbose=print_verbose,
+                custom_llm_provider="gemini",
+                api_key=gemini_api_key,
+            )

     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (

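The routing rule this introduces: a plain str keeps going through GoogleEmbeddings.text_embeddings (/embedContent), while any other input, i.e. a list, now goes through GoogleBatchEmbeddings.batch_embeddings (/batchEmbedContents). Previously a multi-element list raised a 422 VertexAIError in the /embedContent transformation (see the branch removed above); after this commit a call like the following succeeds (a sketch, strings illustrative):

    import litellm

    response = litellm.embedding(
        model="gemini/text-embedding-004",
        input=["first document", "second document", "third document"],
    )
    # response.data holds one embedding per input string
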
@@ -687,19 +687,22 @@ async def test_triton_embeddings():


 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "input", ["good morning from litellm", ["good morning from litellm"]]  #
+)
 @pytest.mark.asyncio
-async def test_gemini_embeddings(sync_mode):
+async def test_gemini_embeddings(sync_mode, input):
     try:
         litellm.set_verbose = True
         if sync_mode:
             response = litellm.embedding(
                 model="gemini/text-embedding-004",
-                input=["good morning from litellm"],
+                input=input,
             )
         else:
             response = await litellm.aembedding(
                 model="gemini/text-embedding-004",
-                input=["good morning from litellm"],
+                input=input,
             )
         print(f"response: {response}")

@@ -362,3 +362,15 @@ class ContentEmbeddings(TypedDict):

 class VertexAITextEmbeddingsResponseObject(TypedDict):
     embedding: ContentEmbeddings
+
+
+class EmbedContentRequest(VertexAITextEmbeddingsRequestBody):
+    model: Required[str]
+
+
+class VertexAIBatchEmbeddingsRequestBody(TypedDict, total=False):
+    requests: List[EmbedContentRequest]
+
+
+class VertexAIBatchEmbeddingsResponseObject(TypedDict):
+    embeddings: List[ContentEmbeddings]

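A sketch of dicts that satisfy the new TypedDicts (values illustrative; EmbedContentRequest inherits its content field from VertexAITextEmbeddingsRequestBody, and ContentEmbeddings carries the values list read by process_response):

    from litellm.types.llms.vertex_ai import (
        VertexAIBatchEmbeddingsRequestBody,
        VertexAIBatchEmbeddingsResponseObject,
    )

    request_body: VertexAIBatchEmbeddingsRequestBody = {
        "requests": [
            {
                "model": "models/text-embedding-004",
                "content": {"parts": [{"text": "hello world"}]},
            }
        ]
    }

    response_body: VertexAIBatchEmbeddingsResponseObject = {
        "embeddings": [{"values": [0.01, -0.02, 0.03]}]
    }
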