Merge pull request #5393 from BerriAI/litellm_gemini_embedding_support

feat(vertex_ai_and_google_ai_studio): Support Google AI Studio Embedding Endpoint
This commit is contained in:
Krish Dholakia 2024-08-28 13:46:28 -07:00 committed by GitHub
commit 996c028127
15 changed files with 481 additions and 71 deletions

View file

@ -126,12 +126,15 @@ from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
GoogleBatchEmbeddings,
)
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexAIPartnerModels,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import (
@ -172,6 +175,7 @@ triton_chat_completions = TritonChatCompletion()
# Provider handler objects, each constructed once here with no arguments.
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
# Handler whose batch_embeddings() method serves embedding() requests when
# custom_llm_provider == "gemini" (the Google AI Studio embedding route).
google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
@ -3134,6 +3138,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "gemini"
or custom_llm_provider == "databricks"
or custom_llm_provider == "watsonx"
or custom_llm_provider == "cohere"
@ -3531,6 +3536,26 @@ def embedding(
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "gemini":
gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
response = google_batch_embeddings.batch_embeddings( # type: ignore
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="gemini",
api_key=gemini_api_key,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
@ -3571,6 +3596,7 @@ def embedding(
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
custom_llm_provider="vertex_ai",
)
else:
response = vertex_ai_non_gemini.embedding(