Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
feat(embeddings_handler.py): support async gemini embeddings
parent 6a483a1908
commit 4bb59b7b2c

3 changed files with 109 additions and 24 deletions
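In practice this means gemini embeddings can now be called through litellm's async entry point as well as the sync one. A minimal usage sketch, mirroring the parametrized test added below (it assumes a Google AI Studio key is configured for the gemini/ provider, e.g. via the GEMINI_API_KEY environment variable):

    import asyncio
    import litellm

    def sync_call():
        # existing synchronous path
        return litellm.embedding(
            model="gemini/text-embedding-004",
            input=["good morning from litellm"],
        )

    async def async_call():
        # asynchronous path enabled by this commit
        return await litellm.aembedding(
            model="gemini/text-embedding-004",
            input=["good morning from litellm"],
        )

    print(sync_call())
    print(asyncio.run(async_call()))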

embeddings_handler.py

@@ -9,7 +9,8 @@ import httpx
 import litellm
 from litellm import EmbeddingResponse
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.openai import EmbeddingInput
 from litellm.types.llms.vertex_ai import (
     VertexAITextEmbeddingsRequestBody,
     VertexAITextEmbeddingsResponseObject,
@@ -17,7 +18,10 @@ from litellm.types.llms.vertex_ai import (
 from litellm.types.utils import Embedding
 from litellm.utils import get_formatted_prompt

-from .embeddings_transformation import transform_openai_input_gemini_content
+from .embeddings_transformation import (
+    process_response,
+    transform_openai_input_gemini_content,
+)
 from .vertex_and_google_ai_studio_gemini import VertexLLM
@@ -94,7 +98,16 @@ class GoogleEmbeddings(VertexLLM):
         )

         if aembedding is True:
-            pass
+            return self.async_text_embeddings(  # type: ignore
+                model=model,
+                api_base=api_base,
+                url=url,
+                data=request_data,
+                model_response=model_response,
+                timeout=timeout,
+                headers=headers,
+                input=input,
+            )

         response = sync_handler.post(
             url=url,
@@ -108,20 +121,53 @@ class GoogleEmbeddings(VertexLLM):
         _json_response = response.json()
         _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore

-        model_response.data = [
-            Embedding(
-                embedding=_predictions["embedding"]["values"],
-                index=0,
-                object="embedding",
-            )
-        ]
-
-        model_response.model = model
-
-        input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
-        prompt_tokens = litellm.token_counter(model=model, text=input_text)
-        model_response.usage = litellm.Usage(
-            prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
-        )
-
-        return model_response
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
+
+    async def async_text_embeddings(
+        self,
+        model: str,
+        api_base: Optional[str],
+        url: str,
+        data: VertexAITextEmbeddingsRequestBody,
+        model_response: EmbeddingResponse,
+        input: EmbeddingInput,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> EmbeddingResponse:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            async_handler = client  # type: ignore
+
+        response = await async_handler.post(
+            url=url,
+            headers=headers,
+            data=json.dumps(data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
+
+        return process_response(
+            model=model,
+            model_response=model_response,
+            _predictions=_predictions,
+            input=input,
+        )
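Note on the new signature: async_text_embeddings takes an optional client, so a caller can reuse one pre-configured AsyncHTTPHandler instead of having a handler built per request by the timeout logic above. A small sketch of constructing such a client (the 30s timeout is an arbitrary example value, not taken from this diff):

    import httpx
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    # reusable async client; pass it as client=... to skip per-call construction
    shared_client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=30.0, connect=5.0))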

embeddings_transformation.py

@@ -6,8 +6,15 @@ Why separate file? Make it easy to see how transformation works
 from typing import List

+from litellm import EmbeddingResponse
 from litellm.types.llms.openai import EmbeddingInput
-from litellm.types.llms.vertex_ai import ContentType, PartType
+from litellm.types.llms.vertex_ai import (
+    ContentType,
+    PartType,
+    VertexAITextEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding, Usage
+from litellm.utils import get_formatted_prompt, token_counter

 from ..common_utils import VertexAIError
@@ -25,3 +32,28 @@ def transform_openai_input_gemini_content(input: EmbeddingInput) -> ContentType:
         status_code=422,
         message="/embedContent only generates a single text embedding vector. File an issue, to add support for /batchEmbedContent - https://github.com/BerriAI/litellm/issues",
     )
+
+
+def process_response(
+    input: EmbeddingInput,
+    model_response: EmbeddingResponse,
+    model: str,
+    _predictions: VertexAITextEmbeddingsResponseObject,
+) -> EmbeddingResponse:
+    model_response.data = [
+        Embedding(
+            embedding=_predictions["embedding"]["values"],
+            index=0,
+            object="embedding",
+        )
+    ]
+
+    model_response.model = model
+
+    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+    prompt_tokens = token_counter(model=model, text=input_text)
+    model_response.usage = Usage(
+        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+    )
+
+    return model_response
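To make the new helper concrete: process_response copies the embedding vector out of the provider payload and fills in model and usage on the EmbeddingResponse. A rough sketch of its behaviour with a mocked payload (the {"embedding": {"values": [...]}} shape is inferred from the ["embedding"]["values"] access above, and the numbers are made up):

    from litellm import EmbeddingResponse

    mock_predictions = {"embedding": {"values": [0.01, -0.02, 0.03]}}  # hypothetical payload

    resp = process_response(
        input=["good morning from litellm"],
        model_response=EmbeddingResponse(),
        model="text-embedding-004",
        _predictions=mock_predictions,  # type: ignore
    )
    # resp.data[0].embedding == [0.01, -0.02, 0.03]; usage counts prompt tokens only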

embedding tests

@@ -686,14 +686,21 @@ async def test_triton_embeddings():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_gemini_embeddings():
+async def test_gemini_embeddings(sync_mode):
     try:
         litellm.set_verbose = True
-        response = await litellm.aembedding(
-            model="gemini/text-embedding-004",
-            input=["good morning from litellm"],
-        )
+        if sync_mode:
+            response = litellm.embedding(
+                model="gemini/text-embedding-004",
+                input=["good morning from litellm"],
+            )
+        else:
+            response = await litellm.aembedding(
+                model="gemini/text-embedding-004",
+                input=["good morning from litellm"],
+            )
         print(f"response: {response}")

         # stubbed endpoint is setup to return this
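The parametrized test exercises both the sync and async paths. To run just this test from the repository's test suite, a standard pytest selection along these lines should work:

    python -m pytest -k test_gemini_embeddings -s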