fix(main.py): simplify to just use /batchEmbedContent

This commit is contained in:
Krrish Dholakia 2024-08-27 21:46:05 -07:00
parent 947801d3ac
commit 7a9f1798ff
6 changed files with 28 additions and 260 deletions

View file

@@ -129,9 +129,6 @@ from .llms.vertex_ai_and_google_ai_studio import (
from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
GoogleBatchEmbeddings,
)
from .llms.vertex_ai_and_google_ai_studio.embeddings.embed_content_handler import (
GoogleEmbeddings,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@@ -178,7 +175,6 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
google_embeddings = GoogleEmbeddings()
google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
@@ -3541,38 +3537,21 @@ def embedding(
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
if isinstance(input, str):
response = google_embeddings.text_embeddings( # type: ignore
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
)
else:
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
)
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (