(performance improvement - vertex embeddings) ~111.11% faster (#6000)

* use vertex llm as base class for embeddings

* use correct vertex class in main.py

* set_headers in vertex llm base

* add types for vertex embedding requests

* add embedding handler for vertex

* use async mode for vertex embedding tests

* use vertexAI textEmbeddingConfig

* fix linting

* add sync and async mode testing for vertex ai embeddings
This commit is contained in:
Ishaan Jaff 2024-10-01 14:16:21 -07:00 committed by GitHub
parent 18a28ef977
commit eef9bad9a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 497 additions and 300 deletions

View file

@@ -134,8 +134,8 @@ from .llms.vertex_ai_and_google_ai_studio.text_to_speech.text_to_speech_handler
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexAIPartnerModels,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings import (
embedding_handler as vertex_ai_embedding_handler,
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
VertexEmbedding,
)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
@@ -185,6 +185,7 @@ bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
vertex_chat_completion = VertexLLM()
vertex_embedding = VertexEmbedding()
vertex_multimodal_embedding = VertexMultimodalEmbedding()
vertex_image_generation = VertexImageGeneration()
google_batch_embeddings = GoogleBatchEmbeddings()
@@ -2980,7 +2981,7 @@ def batch_completion(
deployment_id=None,
request_timeout: Optional[int] = None,
timeout: Optional[int] = 600,
max_workers:Optional[int]= 100,
max_workers: Optional[int] = 100,
# Optional liteLLM function params
**kwargs,
):
@@ -3711,21 +3712,21 @@ def embedding(
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
or get_secret("VERTEX_PROJECT")
or get_secret_str("VERTEXAI_PROJECT")
or get_secret_str("VERTEX_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret("VERTEX_LOCATION")
or get_secret_str("VERTEXAI_LOCATION")
or get_secret_str("VERTEX_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
or get_secret("VERTEX_CREDENTIALS")
or get_secret_str("VERTEXAI_CREDENTIALS")
or get_secret_str("VERTEX_CREDENTIALS")
)
if (
@@ -3750,7 +3751,7 @@ def embedding(
custom_llm_provider="vertex_ai",
)
else:
response = vertex_ai_embedding_handler.embedding(
response = vertex_embedding.embedding(
model=model,
input=input,
encoding=encoding,
@@ -3760,6 +3761,8 @@ def embedding(
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
custom_llm_provider="vertex_ai",
timeout=timeout,
aembedding=aembedding,
print_verbose=print_verbose,
)