feat(embeddings_handler.py): support async gemini embeddings

Author: Krrish Dholakia
Date:   2024-08-27 18:31:57 -07:00
Parent: 6a483a1908
Commit: 4bb59b7b2c
3 changed files with 109 additions and 24 deletions

embeddings_handler.py

@@ -9,7 +9,8 @@ import httpx
 import litellm
 from litellm import EmbeddingResponse
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.openai import EmbeddingInput
 from litellm.types.llms.vertex_ai import (
     VertexAITextEmbeddingsRequestBody,
     VertexAITextEmbeddingsResponseObject,
@@ -17,7 +18,10 @@ from litellm.types.llms.vertex_ai import (
 from litellm.types.utils import Embedding
 from litellm.utils import get_formatted_prompt
 
-from .embeddings_transformation import transform_openai_input_gemini_content
+from .embeddings_transformation import (
+    process_response,
+    transform_openai_input_gemini_content,
+)
 from .vertex_and_google_ai_studio_gemini import VertexLLM
@@ -94,7 +98,16 @@ class GoogleEmbeddings(VertexLLM):
         )
 
         if aembedding is True:
-            pass
+            return self.async_text_embeddings(  # type: ignore
+                model=model,
+                api_base=api_base,
+                url=url,
+                data=request_data,
+                model_response=model_response,
+                timeout=timeout,
+                headers=headers,
+                input=input,
+            )
 
         response = sync_handler.post(
             url=url,
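Note: returning self.async_text_embeddings(...) from the synchronous embedding() entrypoint does not run the request there; it hands back an un-awaited coroutine for the async caller to await (litellm.aembedding in this code path), which is what the # type: ignore on the declared EmbeddingResponse return type papers over. A minimal sketch of that dispatch pattern, with made-up names rather than litellm source:

import asyncio


async def _fake_async_text_embeddings() -> str:
    # stands in for GoogleEmbeddings.async_text_embeddings
    return "embedding-response"


def fake_embedding(aembedding: bool):
    if aembedding is True:
        # calling the async def does not execute it; it returns a coroutine object
        return _fake_async_text_embeddings()
    return "sync-response"


print(asyncio.run(fake_embedding(aembedding=True)))  # embedding-response
print(fake_embedding(aembedding=False))  # sync-response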
@@ -108,20 +121,53 @@ class GoogleEmbeddings(VertexLLM):
         _json_response = response.json()
         _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
 
-        model_response.data = [
-            Embedding(
-                embedding=_predictions["embedding"]["values"],
-                index=0,
-                object="embedding",
-            )
-        ]
-        model_response.model = model
-
-        input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
-        prompt_tokens = litellm.token_counter(model=model, text=input_text)
-        model_response.usage = litellm.Usage(
-            prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
-        )
-
-        return model_response
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
+
+    async def async_text_embeddings(
+        self,
+        model: str,
+        api_base: Optional[str],
+        url: str,
+        data: VertexAITextEmbeddingsRequestBody,
+        model_response: EmbeddingResponse,
+        input: EmbeddingInput,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        response = await async_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
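With the async path above in place, Gemini embedding requests can be awaited concurrently through litellm.aembedding; the optional client argument also lets a caller supply a shared AsyncHTTPHandler so connections are reused. A usage sketch, assuming a Google AI Studio key is configured in the environment (e.g. GEMINI_API_KEY) and the model name matches your access:

import asyncio

import litellm


async def embed_many(texts):
    # issue all embedding calls concurrently instead of one at a time
    tasks = [
        litellm.aembedding(model="gemini/text-embedding-004", input=[t])
        for t in texts
    ]
    return await asyncio.gather(*tasks)


# responses = asyncio.run(embed_many(["good morning from litellm", "another doc"]))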

embeddings_transformation.py

@@ -6,8 +6,15 @@ Why separate file? Make it easy to see how transformation works
 from typing import List
 
+from litellm import EmbeddingResponse
 from litellm.types.llms.openai import EmbeddingInput
-from litellm.types.llms.vertex_ai import ContentType, PartType
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    PartType,
+    VertexAITextEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding, Usage
+from litellm.utils import get_formatted_prompt, token_counter
 
 from ..common_utils import VertexAIError
@@ -25,3 +32,28 @@ def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
             status_code=422,
             message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
         )
+
+
+def process_response(
+    input: EmbeddingInput,
+    model_response: EmbeddingResponse,
+    model: str,
+    _predictions: VertexAITextEmbeddingsResponseObject,
+) -> EmbeddingResponse:
+    model_response.data = [
+        Embedding(
+            embedding=_predictions["embedding"]["values"],
+            index=0,
+            object="embedding",
+        )
+    ]
+
+    model_response.model = model
+
+    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+    prompt_tokens = token_counter(model=model, text=input_text)
+    model_response.usage = Usage(
+        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+    )
+
+    return model_response
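For context, process_response expects the parsed body of a successful /embedContent call; judging from the subscripting above, that body is shaped like {"embedding": {"values": [...]}}. A small illustrative walkthrough with a made-up vector (not a real API response):

# hypothetical parsed response body, shaped like VertexAITextEmbeddingsResponseObject
_json_response = {"embedding": {"values": [0.01, -0.02, 0.03]}}

# process_response copies this vector into an OpenAI-style EmbeddingResponse:
#   model_response.data  -> [Embedding(embedding=[0.01, -0.02, 0.03], index=0, object="embedding")]
#   model_response.model -> "gemini/text-embedding-004"
#   model_response.usage -> Usage(prompt_tokens=N, total_tokens=N), N from token_counter()
vector = _json_response["embedding"]["values"]
assert len(vector) == 3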

embedding tests

@@ -686,14 +686,21 @@ async def test_triton_embeddings():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_gemini_embeddings():
+async def test_gemini_embeddings(sync_mode):
     try:
         litellm.set_verbose = True
-        response = await litellm.aembedding(
-            model="gemini/text-embedding-004",
-            input=["good morning from litellm"],
-        )
+        if sync_mode:
+            response = litellm.embedding(
+                model="gemini/text-embedding-004",
+                input=["good morning from litellm"],
+            )
+        else:
+            response = await litellm.aembedding(
+                model="gemini/text-embedding-004",
+                input=["good morning from litellm"],
+            )
         print(f"response: {response}")
 
         # stubbed endpoint is setup to return this