feat(embeddings_handler.py): support async gemini embeddings

Krrish Dholakia 2024-08-27 18:31:57 -07:00
parent 6a483a1908
commit 4bb59b7b2c
3 changed files with 109 additions and 24 deletions
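The new code path is exercised end-to-end by the test updated in this commit. As a quick orientation, here is a minimal usage sketch of the async call this commit enables; it is not part of the diff and assumes a GEMINI_API_KEY is configured for Google AI Studio:

import asyncio
import litellm

async def main():
    # "gemini/text-embedding-004" routes to the GoogleEmbeddings handler changed below
    response = await litellm.aembedding(
        model="gemini/text-embedding-004",
        input=["good morning from litellm"],
    )
    print(f"vectors returned: {len(response.data)}")
    print(response.usage)  # prompt_tokens == total_tokens (see process_response below)

asyncio.run(main())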


@@ -9,7 +9,8 @@ import httpx
 import litellm
 from litellm import EmbeddingResponse
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.openai import EmbeddingInput
 from litellm.types.llms.vertex_ai import (
     VertexAITextEmbeddingsRequestBody,
     VertexAITextEmbeddingsResponseObject,
@@ -17,7 +18,10 @@ from litellm.types.llms.vertex_ai import (
 from litellm.types.utils import Embedding
 from litellm.utils import get_formatted_prompt
-from .embeddings_transformation import transform_openai_input_gemini_content
+from .embeddings_transformation import (
+    process_response,
+    transform_openai_input_gemini_content,
+)
 from .vertex_and_google_ai_studio_gemini import VertexLLM
@@ -94,7 +98,16 @@ class GoogleEmbeddings(VertexLLM):
         )
         if aembedding is True:
-            pass
+            return self.async_text_embeddings(  # type: ignore
+                model=model,
+                api_base=api_base,
+                url=url,
+                data=request_data,
+                model_response=model_response,
+                timeout=timeout,
+                headers=headers,
+                input=input,
+            )
         response = sync_handler.post(
             url=url,
@@ -108,20 +121,53 @@ class GoogleEmbeddings(VertexLLM):
         _json_response = response.json()
         _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
-        model_response.data = [
-            Embedding(
-                embedding=_predictions["embedding"]["values"],
-                index=0,
-                object="embedding",
-            )
-        ]
-        model_response.model = model
-        input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
-        prompt_tokens = litellm.token_counter(model=model, text=input_text)
-        model_response.usage = litellm.Usage(
-            prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
         )
-        return model_response
+
+    async def async_text_embeddings(
+        self,
+        model: str,
+        api_base: Optional[str],
+        url: str,
+        data: VertexAITextEmbeddingsRequestBody,
+        model_response: EmbeddingResponse,
+        input: EmbeddingInput,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+        response = await async_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(data),
+        )
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+        _json_response = response.json()
+        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
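Both the sync branch and the new async branch parse the same payload: the handler only reads _predictions["embedding"]["values"]. For illustration only (not part of the diff, with made-up values; the real :embedContent response may carry additional keys), the consumed body looks roughly like:

predictions = {
    "embedding": {
        "values": [0.0123, -0.0456, 0.0789],  # one vector per embedContent call
    }
}
vector = predictions["embedding"]["values"]
print(f"dimension: {len(vector)}")

process_response(), added in embeddings_transformation.py below, wraps that single vector into an OpenAI-style EmbeddingResponse entry with index 0.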


@@ -6,8 +6,15 @@ Why separate file? Make it easy to see how transformation works
 from typing import List
+from litellm import EmbeddingResponse
 from litellm.types.llms.openai import EmbeddingInput
-from litellm.types.llms.vertex_ai import ContentType, PartType
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    PartType,
+    VertexAITextEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding, Usage
+from litellm.utils import get_formatted_prompt, token_counter
 from ..common_utils import VertexAIError
@@ -25,3 +32,28 @@ def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
             status_code=422,
             message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
         )
+
+
+def process_response(
+    input: EmbeddingInput,
+    model_response: EmbeddingResponse,
+    model: str,
+    _predictions: VertexAITextEmbeddingsResponseObject,
+) -> EmbeddingResponse:
+    model_response.data = [
+        Embedding(
+            embedding=_predictions["embedding"]["values"],
+            index=0,
+            object="embedding",
+        )
+    ]
+    model_response.model = model
+    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+    prompt_tokens = token_counter(model=model, text=input_text)
+    model_response.usage = Usage(
+        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+    )
+    return model_response
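Note that the usage figures come from litellm's local token counter rather than from the Gemini API. A small sketch of that accounting, outside the commit (the model string is only an example; the count depends on which tokenizer litellm selects):

import litellm
from litellm.utils import get_formatted_prompt

embedding_input = ["good morning from litellm"]
input_text = get_formatted_prompt(data={"input": embedding_input}, call_type="embedding")
prompt_tokens = litellm.token_counter(model="text-embedding-004", text=input_text)
# Embedding calls have no completion, so total_tokens is reported equal to prompt_tokens.
print(prompt_tokens)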


@@ -686,14 +686,21 @@ async def test_triton_embeddings():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_gemini_embeddings():
+async def test_gemini_embeddings(sync_mode):
     try:
         litellm.set_verbose = True
-        response = await litellm.aembedding(
-            model="gemini/text-embedding-004",
-            input=["good morning from litellm"],
-        )
+        if sync_mode:
+            response = litellm.embedding(
+                model="gemini/text-embedding-004",
+                input=["good morning from litellm"],
+            )
+        else:
+            response = await litellm.aembedding(
+                model="gemini/text-embedding-004",
+                input=["good morning from litellm"],
+            )
         print(f"response: {response}")
         # stubbed endpoint is setup to return this