Merge pull request #4152 from BerriAI/litellm_support_vertex_text_input
[Feat] Support `task_type`, `auto_truncate` params
Commit `3254cf50b7`: 5 changed files with 170 additions and 25 deletions
**Docs: Vertex AI embedding models**

````diff
@@ -558,6 +558,29 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
+| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
+| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
+
+### Advanced Use `task_type` and `title` (Vertex Specific Params)
+
+👉 `task_type` and `title` are Vertex-specific params.
+
+LiteLLM supported Vertex-specific params:
+
+```python
+auto_truncate: Optional[bool] = None
+task_type: Optional[Literal["RETRIEVAL_QUERY", "RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]] = None
+title: Optional[str] = None  # The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+```
+
+**Example Usage with LiteLLM**
+
+```python
+response = litellm.embedding(
+    model="vertex_ai/text-embedding-004",
+    input=["good morning from litellm", "gm"],
+    task_type="RETRIEVAL_DOCUMENT",
+    dimensions=1,
+    auto_truncate=True,
+)
+```
+
 ## Image Generation Models
 
 Usage
````
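For reference, a minimal sketch of inspecting the OpenAI-compatible response returned by the docs example above; the field names follow the handler and test changes later in this diff:

```python
# `response` is the return value of the litellm.embedding() example above.
print(len(response.data))                  # one entry per input string
print(len(response.data[0]["embedding"]))  # vector length == dimensions (1 here)
print(response.usage.prompt_tokens)        # token count reported by Vertex AI
```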
**Top-level config imports**

```diff
@@ -765,7 +765,7 @@ from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
-from .llms.vertex_ai import VertexAIConfig
+from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
```
**`llms/vertex_ai` handler**

```diff
@@ -4,6 +4,7 @@ from enum import Enum
 import requests  # type: ignore
 import time
 from typing import Callable, Optional, Union, List, Literal, Any
+from pydantic import BaseModel
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx, inspect  # type: ignore
```
```diff
@@ -12,7 +13,12 @@ from litellm.llms.prompt_templates.factory import (
     convert_to_gemini_tool_call_result,
     convert_to_gemini_tool_call_invoke,
 )
-from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+    is_video_file_type,
+)
 
 
 class VertexAIError(Exception):
```
```diff
@@ -301,8 +307,8 @@ def _process_gemini_image(image_url: str) -> PartType:
     # GCS URIs
     if "gs://" in image_url:
         # Figure out file type
         extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
         extension = extension_with_dot[1:]  # Ex: "png"
 
         file_type = get_file_type_from_extension(extension)
```
```diff
@@ -1293,6 +1299,95 @@ async def async_streaming(
     return streamwrapper
 
 
+class VertexAITextEmbeddingConfig(BaseModel):
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+    Args:
+        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+    """
+
+    auto_truncate: Optional[bool] = None
+    task_type: Optional[
+        Literal[
+            "RETRIEVAL_QUERY",
+            "RETRIEVAL_DOCUMENT",
+            "SEMANTIC_SIMILARITY",
+            "CLASSIFICATION",
+            "CLUSTERING",
+            "QUESTION_ANSWERING",
+            "FACT_VERIFICATION",
+        ]
+    ] = None
+    title: Optional[str] = None
+
+    def __init__(
+        self,
+        auto_truncate: Optional[bool] = None,
+        task_type: Optional[
+            Literal[
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+                "QUESTION_ANSWERING",
+                "FACT_VERIFICATION",
+            ]
+        ] = None,
+        title: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "dimensions",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "dimensions":
+                optional_params["output_dimensionality"] = value
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+
 def embedding(
     model: str,
     input: Union[list, str],
```
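A minimal usage sketch of the new config class; the behavior shown is read directly from the methods above, and `VertexAITextEmbeddingConfig` is exported at the package top level by the import change earlier in this diff:

```python
import litellm

config = litellm.VertexAITextEmbeddingConfig()
print(config.get_supported_openai_params())  # ["dimensions"]

# OpenAI's `dimensions` is renamed to Vertex AI's `output_dimensionality`:
mapped = config.map_openai_params(
    non_default_params={"dimensions": 256}, optional_params={}
)
print(mapped)  # {"output_dimensionality": 256}
```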
```diff
@@ -1316,7 +1411,7 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )
 
-    from vertexai.language_models import TextEmbeddingModel
+    from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
     import google.auth  # type: ignore
 
     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
```
```diff
@@ -1347,6 +1442,16 @@ def embedding(
     if isinstance(input, str):
         input = [input]
 
+    if optional_params is not None and isinstance(optional_params, dict):
+        if optional_params.get("task_type") or optional_params.get("title"):
+            # if user passed task_type or title, cast to TextEmbeddingInput
+            _task_type = optional_params.pop("task_type", None)
+            _title = optional_params.pop("title", None)
+            input = [
+                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
+                for x in input
+            ]
+
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
     except Exception as e:
```
```diff
@@ -1363,7 +1468,8 @@ def embedding(
             encoding=encoding,
         )
 
-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
```
```diff
@@ -1375,7 +1481,7 @@ def embedding(
     )
 
     try:
-        embeddings = llm_model.get_embeddings(input)
+        embeddings = llm_model.get_embeddings(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
```
```diff
@@ -1383,6 +1489,7 @@ def embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
```
```diff
@@ -1391,14 +1498,10 @@ def embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
+
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
-
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
```
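Taken together, the synchronous handler now builds a call along these lines. A hedged sketch against the `vertexai` SDK, assuming `google-cloud-aiplatform` is installed and application-default credentials are configured:

```python
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

llm_model = TextEmbeddingModel.from_pretrained("text-embedding-004")

# task_type/title travel inside each TextEmbeddingInput; auto_truncate and
# output_dimensionality arrive as keyword args via _input_dict.
embeddings = llm_model.get_embeddings(
    texts=[TextEmbeddingInput(text="hi", task_type="RETRIEVAL_QUERY")],
    auto_truncate=True,
    output_dimensionality=1,
)

# Usage is now read from Vertex's own statistics rather than re-encoded locally:
input_tokens = sum(e.statistics.token_count for e in embeddings)
```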
```diff
@@ -1420,7 +1523,8 @@ async def async_embedding(
     """
     Async embedding implementation
     """
-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
```
```diff
@@ -1432,7 +1536,7 @@ async def async_embedding(
     )
 
     try:
-        embeddings = await client.get_embeddings_async(input)
+        embeddings = await client.get_embeddings_async(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
```
```diff
@@ -1440,6 +1544,7 @@ async def async_embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
```
```diff
@@ -1448,18 +1553,13 @@ async def async_embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-
-    input_tokens += len(encoding.encode(input_str))
-
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
     model_response.usage = usage
 
     return model_response
```
**Embedding test**

```diff
@@ -814,10 +814,17 @@ def test_vertexai_embedding_embedding_latest():
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
 
         response = embedding(
             model="vertex_ai/text-embedding-004",
-            input=["good morning from litellm", "this is another item"],
+            input=["hi"],
+            dimensions=1,
+            auto_truncate=True,
+            task_type="RETRIEVAL_QUERY",
         )
 
+        assert len(response.data[0]["embedding"]) == 1
+        assert response.usage.prompt_tokens > 0
         print(f"response:", response)
     except litellm.RateLimitError as e:
         pass
```
**Param handling in `get_optional_params_embeddings` / `get_supported_openai_params`**

```diff
@@ -4898,6 +4898,18 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
+    if custom_llm_provider == "vertex_ai":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="vertex_ai",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     if custom_llm_provider == "vertex_ai":
         if len(non_default_params.keys()) > 0:
             if litellm.drop_params is True:  # drop the unsupported non-default values
```
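A hedged end-to-end sketch of what the new branch enables: an OpenAI-style `dimensions` argument is validated against the supported params, then mapped before the Vertex request is built:

```python
import litellm

response = litellm.embedding(
    model="vertex_ai/text-embedding-004",
    input=["hi"],
    dimensions=256,  # mapped to output_dimensionality=256 by the config above
)
```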
```diff
@@ -6382,7 +6394,10 @@ def get_supported_openai_params(
     elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
         return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
     elif custom_llm_provider == "vertex_ai":
-        return litellm.VertexAIConfig().get_supported_openai_params()
+        if request_type == "chat_completion":
+            return litellm.VertexAIConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
     elif custom_llm_provider == "aleph_alpha":
```
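A minimal sketch of the new `request_type` dispatch, assuming the function is importable from `litellm.utils` (the hunk context shows only its definition):

```python
from litellm.utils import get_supported_openai_params

print(
    get_supported_openai_params(
        model="text-embedding-004",
        custom_llm_provider="vertex_ai",
        request_type="embeddings",
    )
)  # ["dimensions"]
```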