Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
fix(main.py): simplify to just use /batchEmbedContent

parent 947801d3ac
commit 7a9f1798ff

6 changed files with 28 additions and 260 deletions
@@ -26,7 +26,7 @@ class GoogleBatchEmbeddings(VertexLLM):
     def batch_embeddings(
         self,
         model: str,
-        input: List[str],
+        input: EmbeddingInput,
         print_verbose,
         model_response: EmbeddingResponse,
         custom_llm_provider: Literal["gemini", "vertex_ai"],
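Note: the handler signature widens from List[str] to EmbeddingInput, so a bare string no longer needs a separate code path. A minimal sketch of what that alias plausibly covers, inferred from the isinstance checks in this commit; the canonical definition lives in litellm.types.llms.openai and may admit more forms:

from typing import List, Union

# Assumed shape of EmbeddingInput, inferred from this commit's
# isinstance(input, str) checks; see litellm.types.llms.openai for the
# real alias.
EmbeddingInput = Union[str, List[str]]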
@@ -22,13 +22,21 @@ from ..common_utils import VertexAIError


 def transform_openai_input_gemini_content(
-    input: List[str], model: str, optional_params: dict
+    input: EmbeddingInput, model: str, optional_params: dict
 ) -> VertexAIBatchEmbeddingsRequestBody:
     """
     The content to embed. Only the parts.text fields will be counted.
     """
     gemini_model_name = "models/{}".format(model)
     requests: List[EmbedContentRequest] = []
+    if isinstance(input, str):
+        request = EmbedContentRequest(
+            model=gemini_model_name,
+            content=ContentType(parts=[PartType(text=input)]),
+            **optional_params
+        )
+        requests.append(request)
+    else:
+        for i in input:
+            request = EmbedContentRequest(
+                model=gemini_model_name,
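Note: for reference, a sketch of the JSON body this transformation assembles for Google's batchEmbedContents endpoint. Field names follow the EmbedContentRequest/ContentType/PartType types visible in the diff; the model string and texts are illustrative:

# Illustrative request body for POST .../models/{model}:batchEmbedContents,
# as built by transform_openai_input_gemini_content above.
body = {
    "requests": [
        {
            "model": "models/text-embedding-004",
            "content": {"parts": [{"text": "hello world"}]},
        },
        {
            "model": "models/text-embedding-004",
            "content": {"parts": [{"text": "goodbye world"}]},
        },
    ]
}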
@@ -1,170 +0,0 @@
-"""
-Google AI Studio /embedContent Embeddings Endpoint
-"""
-
-import json
-from typing import Literal, Optional, Union
-
-import httpx
-
-from litellm import EmbeddingResponse
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.openai import EmbeddingInput
-from litellm.types.llms.vertex_ai import (
-    VertexAITextEmbeddingsRequestBody,
-    VertexAITextEmbeddingsResponseObject,
-)
-
-from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
-from .embed_content_transformation import (
-    process_response,
-    transform_openai_input_gemini_content,
-)
-
-
-class GoogleEmbeddings(VertexLLM):
-    def text_embeddings(
-        self,
-        model: str,
-        input: str,
-        print_verbose,
-        model_response: EmbeddingResponse,
-        custom_llm_provider: Literal["gemini", "vertex_ai"],
-        optional_params: dict,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        logging_obj=None,
-        encoding=None,
-        vertex_project=None,
-        vertex_location=None,
-        vertex_credentials=None,
-        aembedding=False,
-        timeout=300,
-        client=None,
-    ) -> EmbeddingResponse:
-
-        auth_header, url = self._get_token_and_url(
-            model=model,
-            gemini_api_key=api_key,
-            vertex_project=vertex_project,
-            vertex_location=vertex_location,
-            vertex_credentials=vertex_credentials,
-            stream=None,
-            custom_llm_provider=custom_llm_provider,
-            api_base=api_base,
-            should_use_v1beta1_features=False,
-            mode="embedding",
-        )
-
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
-        else:
-            sync_handler = client  # type: ignore
-
-        optional_params = optional_params or {}
-
-        ### TRANSFORMATION ###
-        content = transform_openai_input_gemini_content(input=input)
-
-        request_data: VertexAITextEmbeddingsRequestBody = {
-            "content": content,
-            **optional_params,
-        }
-
-        headers = {
-            "Content-Type": "application/json; charset=utf-8",
-        }
-
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input,
-            api_key="",
-            additional_args={
-                "complete_input_dict": request_data,
-                "api_base": url,
-                "headers": headers,
-            },
-        )
-
-        if aembedding is True:
-            return self.async_text_embeddings(  # type: ignore
-                model=model,
-                api_base=api_base,
-                url=url,
-                data=request_data,
-                model_response=model_response,
-                timeout=timeout,
-                headers=headers,
-                input=input,
-            )
-
-        response = sync_handler.post(
-            url=url,
-            headers=headers,
-            data=json.dumps(request_data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-
-        _json_response = response.json()
-        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
-
-        return process_response(
-            model=model,
-            model_response=model_response,
-            _predictions=_predictions,
-            input=input,
-        )
-
-    async def async_text_embeddings(
-        self,
-        model: str,
-        api_base: Optional[str],
-        url: str,
-        data: VertexAITextEmbeddingsRequestBody,
-        model_response: EmbeddingResponse,
-        input: EmbeddingInput,
-        timeout: Optional[Union[float, httpx.Timeout]],
-        headers={},
-        client: Optional[AsyncHTTPHandler] = None,
-    ) -> EmbeddingResponse:
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    _httpx_timeout = httpx.Timeout(timeout)
-                    _params["timeout"] = _httpx_timeout
-            else:
-                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
-
-            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            async_handler = client  # type: ignore
-
-        response = await async_handler.post(
-            url=url,
-            headers=headers,
-            data=json.dumps(data),
-        )
-
-        if response.status_code != 200:
-            raise Exception(f"Error: {response.status_code} {response.text}")
-
-        _json_response = response.json()
-        _predictions = VertexAITextEmbeddingsResponseObject(**_json_response)  # type: ignore
-
-        return process_response(
-            model=model,
-            model_response=model_response,
-            _predictions=_predictions,
-            input=input,
-        )
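Note: this deletes the single-input handler (embed_content_handler.py) outright; GoogleBatchEmbeddings now serves both cases. The timeout normalization it used is a recurring pattern in these handlers; a standalone sketch of that pattern, my paraphrase rather than the surviving litellm code:

import httpx

# Numeric timeouts become an httpx.Timeout; anything else falls back to a
# 600s overall / 5s connect default, mirroring the deleted handler above.
def normalize_timeout(timeout):
    if isinstance(timeout, (int, float)):
        return httpx.Timeout(timeout)
    return httpx.Timeout(timeout=600.0, connect=5.0)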
@@ -1,49 +0,0 @@
-"""
-Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /embedContent format.
-
-Why separate file? Make it easy to see how transformation works
-"""
-
-from litellm import EmbeddingResponse
-from litellm.types.llms.openai import EmbeddingInput
-from litellm.types.llms.vertex_ai import (
-    ContentType,
-    PartType,
-    VertexAITextEmbeddingsResponseObject,
-)
-from litellm.types.utils import Embedding, Usage
-from litellm.utils import get_formatted_prompt, token_counter
-
-from ..common_utils import VertexAIError
-
-
-def transform_openai_input_gemini_content(input: str) -> ContentType:
-    """
-    The content to embed. Only the parts.text fields will be counted.
-    """
-    return ContentType(parts=[PartType(text=input)])
-
-
-def process_response(
-    input: EmbeddingInput,
-    model_response: EmbeddingResponse,
-    model: str,
-    _predictions: VertexAITextEmbeddingsResponseObject,
-) -> EmbeddingResponse:
-    model_response.data = [
-        Embedding(
-            embedding=_predictions["embedding"]["values"],
-            index=0,
-            object="embedding",
-        )
-    ]
-
-    model_response.model = model
-
-    input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
-    prompt_tokens = token_counter(model=model, text=input_text)
-    model_response.usage = Usage(
-        prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
-    )
-
-    return model_response
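Note: the deleted process_response also documents how usage was accounted: Google's embedding response carries no token counts, so litellm recomputes them locally. A sketch of that accounting using the same litellm helpers; the model string and text are illustrative:

from litellm.types.utils import Usage
from litellm.utils import token_counter

# Embedding responses include no token counts, so prompt tokens are counted
# locally and reported as total_tokens too (there are no completion tokens).
prompt_tokens = token_counter(model="gemini/text-embedding-004", text="hello world")
usage = Usage(prompt_tokens=prompt_tokens, total_tokens=prompt_tokens)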
@@ -129,9 +129,6 @@ from .llms.vertex_ai_and_google_ai_studio import (
 from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
     GoogleBatchEmbeddings,
 )
-from .llms.vertex_ai_and_google_ai_studio.embeddings.embed_content_handler import (
-    GoogleEmbeddings,
-)
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -178,7 +175,6 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
-google_embeddings = GoogleEmbeddings()
 google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
@@ -3541,23 +3537,6 @@ def embedding(

         gemini_api_key = api_key or get_secret("GEMINI_API_KEY") or litellm.api_key

-        if isinstance(input, str):
-            response = google_embeddings.text_embeddings(  # type: ignore
-                model=model,
-                input=input,
-                encoding=encoding,
-                logging_obj=logging,
-                optional_params=optional_params,
-                model_response=EmbeddingResponse(),
-                vertex_project=None,
-                vertex_location=None,
-                vertex_credentials=None,
-                aembedding=aembedding,
-                print_verbose=print_verbose,
-                custom_llm_provider="gemini",
-                api_key=gemini_api_key,
-            )
-        else:
+        response = google_batch_embeddings.batch_embeddings(  # type: ignore
+            model=model,
+            input=input,
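Note: with the branch gone, callers are unaffected; both input shapes go through the batch handler. A hypothetical usage sketch, with an illustrative model string and GEMINI_API_KEY assumed to be set in the environment:

import litellm

# Both calls now route through GoogleBatchEmbeddings / batchEmbedContent.
single = litellm.embedding(model="gemini/text-embedding-004", input="hello world")
batch = litellm.embedding(
    model="gemini/text-embedding-004",
    input=["hello world", "goodbye world"],
)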
@@ -149,7 +149,7 @@ def init_rds_client(
         # boto3 automatically reads env variables

         client = boto3.client(
-            service_name="bedrock-runtime",
+            service_name="rds",
             region_name=region_name,
             config=config,
         )
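Note: a drive-by fix in init_rds_client: the RDS client was being created against the bedrock-runtime service. A minimal sketch of the corrected call; the region and retry config are illustrative:

import boto3
from botocore.config import Config

# An RDS client must target service_name="rds"; "bedrock-runtime" exposes
# the Bedrock API surface instead (e.g. no generate_db_auth_token).
client = boto3.client(
    service_name="rds",
    region_name="us-west-2",
    config=Config(retries={"max_attempts": 3}),
)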