fix(main.py): simplify to just use /batchEmbedContent

commit 7a9f1798ff
parent 947801d3ac
Author: Krrish Dholakia
Date:   2024-08-27 21:46:05 -07:00

6 changed files with 28 additions and 260 deletions


@@ -26,7 +26,7 @@ class GoogleBatchEmbeddings(VertexLLM):
     def batch_embeddings(
         self,
         model: str,
-        input: List[str],
+        input: EmbeddingInput,
         print_verbose,
         model_response: EmbeddingResponse,
         custom_llm_provider: Literal["gemini", "vertex_ai"],
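
For reference: EmbeddingInput is litellm's OpenAI-style embedding input type, and widening the signature from List[str] is what lets a bare string reach the batch handler at all. A sketch of the assumed alias (the real definition lives in litellm.types.llms.openai and may include more members):

    from typing import List, Union

    # Assumed shape of EmbeddingInput, inferred from the isinstance(input, str)
    # branch added in the transformation hunk below; the actual alias may differ.
    EmbeddingInput = Union[str, List[str]]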


@@ -22,20 +22,28 @@ from ..common_utils import VertexAIError
 def transform_openai_input_gemini_content(
-    input: List[str], model: str, optional_params: dict
+    input: EmbeddingInput, model: str, optional_params: dict
 ) -> VertexAIBatchEmbeddingsRequestBody:
     """
     The content to embed. Only the parts.text fields will be counted.
     """
     gemini_model_name = "models/{}".format(model)
     requests: List[EmbedContentRequest] = []
-    for i in input:
+    if isinstance(input, str):
         request = EmbedContentRequest(
             model=gemini_model_name,
-            content=ContentType(parts=[PartType(text=i)]),
+            content=ContentType(parts=[PartType(text=input)]),
             **optional_params
         )
         requests.append(request)
+    else:
+        for i in input:
+            request = EmbedContentRequest(
+                model=gemini_model_name,
+                content=ContentType(parts=[PartType(text=i)]),
+                **optional_params
+            )
+            requests.append(request)
     return VertexAIBatchEmbeddingsRequestBody(requests=requests)
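
Net effect of the new branch: a single string becomes one EmbedContentRequest and a list becomes one request per element, so both input shapes yield a valid batch body. A self-contained sketch of the same normalization, with plain dicts standing in for the litellm/Vertex request types (an assumption, not this module's actual code):

    from typing import List, Union

    EmbeddingInput = Union[str, List[str]]  # assumed alias, as above

    def to_batch_requests(model: str, input: EmbeddingInput) -> List[dict]:
        # Normalize the single-string case to a one-element list, then build
        # one request per text - equivalent to the if/else in the diff above.
        texts = [input] if isinstance(input, str) else input
        return [
            {"model": "models/{}".format(model), "content": {"parts": [{"text": t}]}}
            for t in texts
        ]

    # Both shapes produce well-formed request lists:
    print(to_batch_requests("text-embedding-004", "hello"))
    print(to_batch_requests("text-embedding-004", ["hello", "world"]))

Normalizing first would avoid duplicating the request-construction block; the diff keeps the two branches explicit instead, a readability trade-off rather than a behavioral difference.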


@@ -1,170 +0,0 @@
-"""
-Google AI Studio /embedContent Embeddings Endpoint
-"""
-
-import json
-from typing import Literal, Optional, Union
-
-import httpx
-
-from litellm import EmbeddingResponse
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.openai import EmbeddingInput
-from litellm.types.llms.vertex_ai import (
-    VertexAITextEmbeddingsRequestBody,
-    VertexAITextEmbeddingsResponseObject,
-)
-
-from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
-from .embed_content_transformation import (
-    process_response,
-    transform_openai_input_gemini_content,
-)
-
-
-class GoogleEmbeddings(VertexLLM):
-    def text_embeddings(
-        self,
-        model: str,
-        input: str,
-        print_verbose,
-        model_response: EmbeddingResponse,
-        custom_llm_provider: Literal["gemini", "vertex_ai"],
-        optional_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        logging_obj=None,
-        encoding=None,
-        vertex_project=None,
-        vertex_location=None,
-        vertex_credentials=None,
-        aembedding=False,
-        timeout=300,
-        client=None,
-    ) -> EmbeddingResponse:
-        auth_header, url = self._get_token_and_url(
-            model=model,
-            gemini_api_key=api_key,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            stream=None,
-            custom_llm_provider=custom_llm_provider,
-            api_base=api_base,
-            should_use_v1beta1_features=False,
-            mode="embedding",
-        )
-
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
-        else:
-            sync_handler = client  # type: ignore
-
-        optional_params = optional_params or {}
-
-        ### TRANSFORMATION ###
-        content = transform_openai_input_gemini_content(input=input)
-
-        request_data: VertexAITextEmbeddingsRequestBody = {
-            "content": content,
-            **optional_params,
-        }
-
-        headers = {
-            "Content-Type": "application/json; charset=utf-8",
-        }
-
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input,
-            api_key="",
-            additional_args={
-                "complete_input_dict": request_data,
-                "api_base": url,
-                "headers": headers,
-            },
-        )
-
-        if aembedding is True:
-            return self.async_text_embeddings(  # type: ignore
-                model=model,
-                api_base=api_base,
-                url=url,
-                data=request_data,
-                model_response=model_response,
-                timeout=timeout,
-                headers=headers,
-                input=input,
-            )
-
-        response = sync_handler.post(
-            url=url,
-            headers=headers,
-            data=json.dumps(request_data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-
-        _json_response = response.json()
-        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
-
-        return process_response(
-            model=model,
-            model_response=model_response,
-            _predictions=_predictions,
-            input=input,
-        )
-
-    async def async_text_embeddings(
-        self,
-        model: str,
-        api_base: Optional[str],
-        url: str,
-        data: VertexAITextEmbeddingsRequestBody,
-        model_response: EmbeddingResponse,
-        input: EmbeddingInput,
-        timeout: Optional[Union[float, httpx.Timeout]],
-        headers={},
-        client: Optional[AsyncHTTPHandler] = None,
-    ) -> EmbeddingResponse:
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            async_handler = client  # type: ignore
-
-        response = await async_handler.post(
-            url=url,
-            headers=headers,
-            data=json.dumps(data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-
-        _json_response = response.json()
-        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
-
-        return process_response(
-            model=model,
-            model_response=model_response,
-            _predictions=_predictions,
-            input=input,
-        )


@@ -1,49 +0,0 @@
-"""
-Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embedContent format.
-
-Why separate file? Make it easy to see how transformation works
-"""
-
-from litellm import EmbeddingResponse
-from litellm.types.llms.openai import EmbeddingInput
-from litellm.types.llms.vertex_ai import (
-    ContentType,
-    PartType,
-    VertexAITextEmbeddingsResponseObject,
-)
-from litellm.types.utils import Embedding, Usage
-from litellm.utils import get_formatted_prompt, token_counter
-
-from ..common_utils import VertexAIError
-
-
-def transform_openai_input_gemini_content(input: str) -> ContentType:
-    """
-    The content to embed. Only the parts.text fields will be counted.
-    """
-    return ContentType(parts=[PartType(text=input)])
-
-
-def process_response(
-    input: EmbeddingInput,
-    model_response: EmbeddingResponse,
-    model: str,
-    _predictions: VertexAITextEmbeddingsResponseObject,
-) -> EmbeddingResponse:
-    model_response.data = [
-        Embedding(
-            embedding=_predictions["embedding"]["values"],
-            index=0,
-            object="embedding",
-        )
-    ]
-
-    model_response.model = model
-
-    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
-    prompt_tokens = token_counter(model=model, text=input_text)
-    model_response.usage = Usage(
-        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
-    )
-
-    return model_response


@@ -129,9 +129,6 @@ from .llms.vertex_ai_and_google_ai_studio import (
 from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
     GoogleBatchEmbeddings,
 )
-from .llms.vertex_ai_and_google_ai_studio.embeddings.embed_content_handler import (
-    GoogleEmbeddings,
-)
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -178,7 +175,6 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
-google_embeddings = GoogleEmbeddings()
 google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
@@ -3541,38 +3537,21 @@ def embedding(
         gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key
 
-        if isinstance(input, str):
-            response = google_embeddings.text_embeddings(  # type: ignore
-                model=model,
-                input=input,
-                encoding=encoding,
-                logging_obj=logging,
-                optional_params=optional_params,
-                model_response=EmbeddingResponse(),
-                vertex_project=None,
-                vertex_location=None,
-                vertex_credentials=None,
-                aembedding=aembedding,
-                print_verbose=print_verbose,
-                custom_llm_provider="gemini",
-                api_key=gemini_api_key,
-            )
-        else:
-            response = google_batch_embeddings.batch_embeddings(  # type: ignore
-                model=model,
-                input=input,
-                encoding=encoding,
-                logging_obj=logging,
-                optional_params=optional_params,
-                model_response=EmbeddingResponse(),
-                vertex_project=None,
-                vertex_location=None,
-                vertex_credentials=None,
-                aembedding=aembedding,
-                print_verbose=print_verbose,
-                custom_llm_provider="gemini",
-                api_key=gemini_api_key,
-            )
+        response = google_batch_embeddings.batch_embeddings(  # type: ignore
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+            vertex_project=None,
+            vertex_location=None,
+            vertex_credentials=None,
+            aembedding=aembedding,
+            print_verbose=print_verbose,
+            custom_llm_provider="gemini",
+            api_key=gemini_api_key,
+        )
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = (
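
With the if/else removed, every Gemini embedding call from the public API now goes through the batch handler, string or list alike. A hedged usage sketch (model name and key are placeholders, not taken from this commit):

    import litellm

    # Both calls route through google_batch_embeddings.batch_embeddings;
    # the dedicated single-embed path no longer exists.
    single = litellm.embedding(
        model="gemini/text-embedding-004",  # placeholder model name
        input="hello world",
        api_key="YOUR_GEMINI_API_KEY",      # placeholder key
    )
    batch = litellm.embedding(
        model="gemini/text-embedding-004",
        input=["hello", "world"],
        api_key="YOUR_GEMINI_API_KEY",
    )
    print(len(single.data), len(batch.data))  # expect 1 and 2 entries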


@@ -149,7 +149,7 @@ def init_rds_client(
         # boto3 automatically reads env variables
         client = boto3.client(
-            service_name="bedrock-runtime",
+            service_name="rds",
             region_name=region_name,
             config=config,
         )
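
This last hunk is an unrelated copy-paste fix: the RDS helper had been constructing a Bedrock runtime client. A minimal sketch of the corrected call, assuming stock boto3 (region and retry settings are illustrative):

    import boto3
    from botocore.config import Config

    config = Config(retries={"max_attempts": 3})  # illustrative config
    # service_name must be "rds" here; "bedrock-runtime" would silently
    # return a Bedrock client from a function named init_rds_client.
    client = boto3.client(
        service_name="rds",
        region_name="us-east-1",  # illustrative region
        config=config,
    )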