Merge pull request #5455 from BerriAI/litellm_vtx_add_input_type_mapping
[Feat] Vertex embeddings - map `input_type` to `task_type`
Commit: 3a72197e77
10 changed files with 398 additions and 281 deletions
@@ -1531,28 +1531,103 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02

| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |

### Supported OpenAI (Unified) Params

| [param](../embedding/supported_embedding.md#input-params-for-litellmembedding) | type | [vertex equivalent](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api) |
|-------|-------------|--------------------|
| `input` | **string or List[string]** | `instances` |
| `dimensions` | **int** | `output_dimensionality` |
| `input_type` | **Literal["RETRIEVAL_QUERY", "RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]** | `task_type` |

#### Usage with OpenAI (Unified) Params

<Tabs>
<TabItem value="sdk" label="SDK">

```python
response = litellm.embedding(
    model="vertex_ai/text-embedding-004",
    input=["good morning from litellm", "gm"],
    input_type="RETRIEVAL_DOCUMENT",
    dimensions=1,
)
```

</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY">

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.embeddings.create(
    model="text-embedding-004",
    input=["good morning from litellm", "gm"],
    dimensions=1,
    extra_body={
        "input_type": "RETRIEVAL_QUERY",
    },
)

print(response)
```

</TabItem>
</Tabs>
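
Note: LiteLLM translates `input_type` into Vertex AI's `task_type` before the request is sent. A minimal sketch of the mapping this PR adds (it mirrors `VertexAITextEmbeddingConfig.map_openai_params` in the diff below, not the full LiteLLM call path):

```python
def map_openai_params(non_default_params: dict, optional_params: dict, kwargs: dict):
    # OpenAI `dimensions` becomes Vertex `output_dimensionality`
    for param, value in non_default_params.items():
        if param == "dimensions":
            optional_params["output_dimensionality"] = value
    # OpenAI-style `input_type` becomes Vertex `task_type`
    if "input_type" in kwargs:
        optional_params["task_type"] = kwargs.pop("input_type")
    return optional_params, kwargs
```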

### Supported Vertex Specific Params

| param | type |
|-------|-------------|
| `auto_truncate` | **bool** |
| `task_type` | **Literal["RETRIEVAL_QUERY", "RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]** |
| `title` | **str** |

#### Usage with Vertex Specific Params (Use `task_type` and `title`)

You can pass any Vertex-specific params to the embedding model by passing them directly to the embedding function:

[Relevant Vertex AI doc with all embedding params](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#request_body)

<Tabs>
<TabItem value="sdk" label="SDK">

**Example Usage with LiteLLM**
```python
response = litellm.embedding(
    model="vertex_ai/text-embedding-004",
    input=["good morning from litellm", "gm"],
    task_type="RETRIEVAL_DOCUMENT",
    title="test",
    dimensions=1,
    auto_truncate=True,
)
```

</TabItem>
<TabItem value="proxy" label="LiteLLM PROXY">

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.embeddings.create(
    model="text-embedding-004",
    input=["good morning from litellm", "gm"],
    dimensions=1,
    extra_body={
        "task_type": "RETRIEVAL_QUERY",
        "auto_truncate": True,
        "title": "test",
    },
)

print(response)
```

</TabItem>
</Tabs>
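
Under the hood (see the embedding handler diff below), setting `task_type` or `title` makes the handler wrap each input string in a `TextEmbeddingInput` before calling the Vertex SDK. A minimal sketch, assuming `google-cloud-aiplatform` is installed; the texts and param values are illustrative:

```python
from vertexai.language_models import TextEmbeddingInput

texts = ["good morning from litellm", "gm"]
# Each string is wrapped with the shared task_type/title before the SDK call.
inputs = [
    TextEmbeddingInput(text=t, task_type="RETRIEVAL_DOCUMENT", title="test")
    for t in texts
]
# `inputs` is what ultimately reaches TextEmbeddingModel.get_embeddings(...)
```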

## **Multi-Modal Embeddings**
@@ -861,7 +861,7 @@ from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gem

```diff
     GoogleAIStudioGeminiConfig,
     VertexAIConfig,
 )
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
     VertexAITextEmbeddingConfig,
 )
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
```
@@ -1096,269 +1096,3 @@ async def async_streaming(

```diff
     )

     return streamwrapper
-
-
-class VertexAITextEmbeddingConfig(BaseModel):
-    """
-    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
-
-    Args:
-        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
-        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
-        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
-    """
-
-    auto_truncate: Optional[bool] = None
-    task_type: Optional[
-        Literal[
-            "RETRIEVAL_QUERY",
-            "RETRIEVAL_DOCUMENT",
-            "SEMANTIC_SIMILARITY",
-            "CLASSIFICATION",
-            "CLUSTERING",
-            "QUESTION_ANSWERING",
-            "FACT_VERIFICATION",
-        ]
-    ] = None
-    title: Optional[str] = None
-
-    def __init__(
-        self,
-        auto_truncate: Optional[bool] = None,
-        task_type: Optional[
-            Literal[
-                "RETRIEVAL_QUERY",
-                "RETRIEVAL_DOCUMENT",
-                "SEMANTIC_SIMILARITY",
-                "CLASSIFICATION",
-                "CLUSTERING",
-                "QUESTION_ANSWERING",
-                "FACT_VERIFICATION",
-            ]
-        ] = None,
-        title: Optional[str] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "dimensions",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "dimensions":
-                optional_params["output_dimensionality"] = value
-        return optional_params
-
-    def get_mapped_special_auth_params(self) -> dict:
-        """
-        Common auth params across bedrock/vertex_ai/azure/watsonx
-        """
-        return {"project": "vertex_project", "region_name": "vertex_location"}
-
-    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
-        mapped_params = self.get_mapped_special_auth_params()
-
-        for param, value in non_default_params.items():
-            if param in mapped_params:
-                optional_params[mapped_params[param]] = value
-        return optional_params
-
-
-def embedding(
-    model: str,
-    input: Union[list, str],
-    print_verbose,
-    model_response: litellm.EmbeddingResponse,
-    optional_params: dict,
-    api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    vertex_project=None,
-    vertex_location=None,
-    vertex_credentials=None,
-    aembedding=False,
-):
-    # logic for parsing in - calling - parsing out model embedding calls
-    try:
-        import vertexai
-    except:
-        raise VertexAIError(
-            status_code=400,
-            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
-        )
-
-    import google.auth  # type: ignore
-    from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
-
-    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
-    try:
-        print_verbose(
-            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
-        )
-        if vertex_credentials is not None and isinstance(vertex_credentials, str):
-            import google.oauth2.service_account
-
-            json_obj = json.loads(vertex_credentials)
-
-            creds = google.oauth2.service_account.Credentials.from_service_account_info(
-                json_obj,
-                scopes=["https://www.googleapis.com/auth/cloud-platform"],
-            )
-        else:
-            creds, _ = google.auth.default(quota_project_id=vertex_project)
-        print_verbose(
-            f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
-        )
-        vertexai.init(
-            project=vertex_project, location=vertex_location, credentials=creds
-        )
-    except Exception as e:
-        raise VertexAIError(status_code=401, message=str(e))
-
-    if isinstance(input, str):
-        input = [input]
-
-    if optional_params is not None and isinstance(optional_params, dict):
-        if optional_params.get("task_type") or optional_params.get("title"):
-            # if user passed task_type or title, cast to TextEmbeddingInput
-            _task_type = optional_params.pop("task_type", None)
-            _title = optional_params.pop("title", None)
-            input = [
-                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
-                for x in input
-            ]
-
-    try:
-        llm_model = TextEmbeddingModel.from_pretrained(model)
-    except Exception as e:
-        raise VertexAIError(status_code=422, message=str(e))
-
-    if aembedding == True:
-        return async_embedding(
-            model=model,
-            client=llm_model,
-            input=input,
-            logging_obj=logging_obj,
-            model_response=model_response,
-            optional_params=optional_params,
-            encoding=encoding,
-        )
-
-    _input_dict = {"texts": input, **optional_params}
-    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
-    ## LOGGING PRE-CALL
-    logging_obj.pre_call(
-        input=input,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-            "request_str": request_str,
-        },
-    )
-
-    try:
-        embeddings = llm_model.get_embeddings(**_input_dict)
-    except Exception as e:
-        raise VertexAIError(status_code=500, message=str(e))
-
-    ## LOGGING POST-CALL
-    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
-    ## Populate OpenAI compliant dictionary
-    embedding_response = []
-    input_tokens: int = 0
-    for idx, embedding in enumerate(embeddings):
-        embedding_response.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding.values,
-            }
-        )
-        input_tokens += embedding.statistics.token_count  # type: ignore
-    model_response.object = "list"
-    model_response.data = embedding_response
-    model_response.model = model
-
-    usage = Usage(
-        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-    )
-    setattr(model_response, "usage", usage)
-
-    return model_response
-
-
-async def async_embedding(
-    model: str,
-    input: Union[list, str],
-    model_response: litellm.EmbeddingResponse,
-    logging_obj=None,
-    optional_params=None,
-    encoding=None,
-    client=None,
-):
-    """
-    Async embedding implementation
-    """
-    _input_dict = {"texts": input, **optional_params}
-    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
-    ## LOGGING PRE-CALL
-    logging_obj.pre_call(
-        input=input,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-            "request_str": request_str,
-        },
-    )
-
-    try:
-        embeddings = await client.get_embeddings_async(**_input_dict)
-    except Exception as e:
-        raise VertexAIError(status_code=500, message=str(e))
-
-    ## LOGGING POST-CALL
-    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
-    ## Populate OpenAI compliant dictionary
-    embedding_response = []
-    input_tokens: int = 0
-    for idx, embedding in enumerate(embeddings):
-        embedding_response.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding.values,
-            }
-        )
-        input_tokens += embedding.statistics.token_count
-
-    model_response.object = "list"
-    model_response.data = embedding_response
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
```
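
For reference, the removed handler assembled the SDK request by merging the already-mapped optional params into a single kwargs dict (the new handler below does the same). A tiny illustrative sketch with placeholder values:

```python
# How _input_dict is built before llm_model.get_embeddings(**_input_dict):
texts = ["hi"]
optional_params = {"output_dimensionality": 256, "auto_truncate": True}
_input_dict = {"texts": texts, **optional_params}
# -> {'texts': ['hi'], 'output_dimensionality': 256, 'auto_truncate': True}
```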
@@ -0,0 +1,283 @@

```diff
+import json
+import os
+import types
+from typing import Literal, Optional, Union
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_non_gemini import (
+    VertexAIError,
+)
+from litellm.types.llms.vertex_ai import *
+from litellm.utils import Usage
+
+
+class VertexAITextEmbeddingConfig(BaseModel):
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+    Args:
+        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+    """
+
+    auto_truncate: Optional[bool] = None
+    task_type: Optional[
+        Literal[
+            "RETRIEVAL_QUERY",
+            "RETRIEVAL_DOCUMENT",
+            "SEMANTIC_SIMILARITY",
+            "CLASSIFICATION",
+            "CLUSTERING",
+            "QUESTION_ANSWERING",
+            "FACT_VERIFICATION",
+        ]
+    ] = None
+    title: Optional[str] = None
+
+    def __init__(
+        self,
+        auto_truncate: Optional[bool] = None,
+        task_type: Optional[
+            Literal[
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+                "QUESTION_ANSWERING",
+                "FACT_VERIFICATION",
+            ]
+        ] = None,
+        title: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return ["dimensions"]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict, kwargs: dict
+    ):
+        for param, value in non_default_params.items():
+            if param == "dimensions":
+                optional_params["output_dimensionality"] = value
+
+        if "input_type" in kwargs:
+            optional_params["task_type"] = kwargs.pop("input_type")
+        return optional_params, kwargs
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+
+def embedding(
+    model: str,
+    input: Union[list, str],
+    print_verbose,
+    model_response: litellm.EmbeddingResponse,
+    optional_params: dict,
+    api_key: Optional[str] = None,
+    logging_obj=None,
+    encoding=None,
+    vertex_project=None,
+    vertex_location=None,
+    vertex_credentials=None,
+    aembedding=False,
+):
+    # logic for parsing in - calling - parsing out model embedding calls
+    try:
+        import vertexai
+    except:
+        raise VertexAIError(
+            status_code=400,
+            message="vertexai import failed please run `pip install google-cloud-aiplatform`",
+        )
+
+    import google.auth  # type: ignore
+    from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
+
+    ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+    try:
+        print_verbose(
+            f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
+        )
+        if vertex_credentials is not None and isinstance(vertex_credentials, str):
+            import google.oauth2.service_account
+
+            json_obj = json.loads(vertex_credentials)
+
+            creds = google.oauth2.service_account.Credentials.from_service_account_info(
+                json_obj,
+                scopes=["https://www.googleapis.com/auth/cloud-platform"],
+            )
+        else:
+            creds, _ = google.auth.default(quota_project_id=vertex_project)
+        print_verbose(
+            f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+        )
+        vertexai.init(
+            project=vertex_project, location=vertex_location, credentials=creds
+        )
+    except Exception as e:
+        raise VertexAIError(status_code=401, message=str(e))
+
+    if isinstance(input, str):
+        input = [input]
+
+    if optional_params is not None and isinstance(optional_params, dict):
+        if optional_params.get("task_type") or optional_params.get("title"):
+            # if user passed task_type or title, cast to TextEmbeddingInput
+            _task_type = optional_params.pop("task_type", None)
+            _title = optional_params.pop("title", None)
+            input = [
+                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
+                for x in input
+            ]
+
+    try:
+        llm_model = TextEmbeddingModel.from_pretrained(model)
+    except Exception as e:
+        raise VertexAIError(status_code=422, message=str(e))
+
+    if aembedding == True:
+        return async_embedding(
+            model=model,
+            client=llm_model,
+            input=input,
+            logging_obj=logging_obj,
+            model_response=model_response,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
+    ## LOGGING PRE-CALL
+    logging_obj.pre_call(
+        input=input,
+        api_key=None,
+        additional_args={
+            "complete_input_dict": optional_params,
+            "request_str": request_str,
+        },
+    )
+
+    try:
+        embeddings = llm_model.get_embeddings(**_input_dict)
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+
+    ## LOGGING POST-CALL
+    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    input_tokens: int = 0
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding.values,
+            }
+        )
+        input_tokens += embedding.statistics.token_count  # type: ignore
+    model_response.object = "list"
+    model_response.data = embedding_response
+    model_response.model = model
+
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    setattr(model_response, "usage", usage)
+
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    input: Union[list, str],
+    model_response: litellm.EmbeddingResponse,
+    logging_obj=None,
+    optional_params=None,
+    encoding=None,
+    client=None,
+):
+    """
+    Async embedding implementation
+    """
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
+    ## LOGGING PRE-CALL
+    logging_obj.pre_call(
+        input=input,
+        api_key=None,
+        additional_args={
+            "complete_input_dict": optional_params,
+            "request_str": request_str,
+        },
+    )
+
+    try:
+        embeddings = await client.get_embeddings_async(**_input_dict)
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+
+    ## LOGGING POST-CALL
+    logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
+    ## Populate OpenAI compliant dictionary
+    embedding_response = []
+    input_tokens: int = 0
+    for idx, embedding in enumerate(embeddings):
+        embedding_response.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding.values,
+            }
+        )
+        input_tokens += embedding.statistics.token_count
+
+    model_response.object = "list"
+    model_response.data = embedding_response
+    model_response.model = model
+    usage = Usage(
+        prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+    )
+    setattr(model_response, "usage", usage)
+    return model_response
```
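
The new module keeps both paths: `aembedding=True` returns the `async_embedding` coroutine, which uses the Vertex SDK's `get_embeddings_async`. A hedged usage sketch of the async path through the public API (assumes Vertex credentials are configured in your environment):

```python
import asyncio

import litellm


async def main():
    response = await litellm.aembedding(
        model="vertex_ai/text-embedding-004",
        input=["hello world"],
        input_type="RETRIEVAL_QUERY",  # mapped to Vertex task_type by this PR
    )
    # usage.prompt_tokens is summed from embedding.statistics.token_count
    print(response.usage)


asyncio.run(main())
```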
@@ -125,12 +125,12 @@ from .llms.vertex_ai_and_google_ai_studio import (

```diff
     vertex_ai_anthropic,
     vertex_ai_non_gemini,
 )
-from .llms.vertex_ai_and_google_ai_studio.embeddings.batch_embed_content_handler import (
-    GoogleBatchEmbeddings,
-)
 from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
+from .llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import (
+    GoogleBatchEmbeddings,
+)
 from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
     VertexMultimodalEmbedding,
 )
```
@@ -140,6 +140,9 @@ from .llms.vertex_ai_and_google_ai_studio.text_to_speech.text_to_speech_handler

```diff
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
     VertexAIPartnerModels,
 )
+from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings import (
+    embedding_handler as vertex_ai_embedding_handler,
+)
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
 from .types.utils import (
```
@@ -3608,7 +3611,7 @@ def embedding(

```diff
             custom_llm_provider="vertex_ai",
         )
     else:
-        response = vertex_ai_non_gemini.embedding(
+        response = vertex_ai_embedding_handler.embedding(
             model=model,
             input=input,
             encoding=encoding,
```
@@ -2014,6 +2014,25 @@ def test_vertexai_embedding_embedding_latest():

```diff
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.flaky(retries=3, delay=1)
+def test_vertexai_embedding_embedding_latest_input_type():
+    try:
+        load_vertex_ai_credentials()
+        litellm.set_verbose = True
+
+        response = embedding(
+            model="vertex_ai/text-embedding-004",
+            input=["hi"],
+            input_type="RETRIEVAL_QUERY",
+        )
+        assert response.usage.prompt_tokens > 0
+        print(f"response:", response)
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.asyncio
 @pytest.mark.flaky(retries=3, delay=1)
 async def test_vertexai_aembedding():
```
@@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd

```diff
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt

-# litellm.num_retries=3
+# litellm.num_retries = 3
 litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
```
@@ -2621,8 +2621,11 @@ def get_optional_params_embeddings(

```diff
             request_type="embeddings",
         )
         _check_valid_arg(supported_params=supported_params)
-        optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params={}
+        (
+            optional_params,
+            kwargs,
+        ) = litellm.VertexAITextEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}, kwargs=kwargs
         )
+        final_params = {**optional_params, **kwargs}
+        return final_params
```
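
To see the new mapping in isolation, the config can be called directly; the expected output follows from the `map_openai_params` diff above:

```python
import litellm

config = litellm.VertexAITextEmbeddingConfig()
optional_params, remaining_kwargs = config.map_openai_params(
    non_default_params={"dimensions": 256},
    optional_params={},
    kwargs={"input_type": "RETRIEVAL_DOCUMENT"},
)
print(optional_params)
# {'output_dimensionality': 256, 'task_type': 'RETRIEVAL_DOCUMENT'}
print(remaining_kwargs)  # {} (input_type was consumed)
```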