Merge pull request #4152 from BerriAI/litellm_support_vertex_text_input

[Feat] Support `task_type`, `auto_truncate` params
Ishaan Jaff 2024-06-12 13:25:45 -07:00 committed by GitHub
commit 3254cf50b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 170 additions and 25 deletions

View file

@@ -558,6 +558,29 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
 | text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
 | text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |
+
+### Advanced: Use `task_type` and `title` (Vertex-Specific Params)
+
+👉 `task_type` and `title` are Vertex-specific params.
+
+LiteLLM supports the following Vertex-specific params:
+```python
+auto_truncate: Optional[bool] = None
+task_type: Optional[Literal["RETRIEVAL_QUERY", "RETRIEVAL_DOCUMENT", "SEMANTIC_SIMILARITY", "CLASSIFICATION", "CLUSTERING", "QUESTION_ANSWERING", "FACT_VERIFICATION"]] = None
+title: Optional[str] = None  # Title of the document to be embedded (only valid with task_type=RETRIEVAL_DOCUMENT).
+```
+
+**Example Usage with LiteLLM**
+```python
+response = litellm.embedding(
+    model="vertex_ai/text-embedding-004",
+    input=["good morning from litellm", "gm"],
+    task_type="RETRIEVAL_DOCUMENT",
+    dimensions=1,
+    auto_truncate=True,
+)
+```
 ## Image Generation Models
 Usage
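
Editor's note, not part of the committed docs: a hedged companion example showing `title`, which per the params block above is only valid with `task_type="RETRIEVAL_DOCUMENT"`. The model name, input, and title below are illustrative.

```python
# Hedged sketch: embed a document with a title attached.
# `title` is only honored when task_type="RETRIEVAL_DOCUMENT".
import litellm

response = litellm.embedding(
    model="vertex_ai/text-embedding-004",
    input=["LiteLLM supports Vertex AI text-embedding models"],
    task_type="RETRIEVAL_DOCUMENT",
    title="LiteLLM Vertex AI docs",  # illustrative document title
)
print(response.usage.prompt_tokens)  # populated from Vertex's token statistics
```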

View file

@@ -765,7 +765,7 @@ from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
-from .llms.vertex_ai import VertexAIConfig
+from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig

View file

@@ -4,6 +4,7 @@ from enum import Enum
 import requests # type: ignore
 import time
 from typing import Callable, Optional, Union, List, Literal, Any
+from pydantic import BaseModel
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx, inspect # type: ignore
@@ -12,7 +13,12 @@ from litellm.llms.prompt_templates.factory import (
     convert_to_gemini_tool_call_result,
     convert_to_gemini_tool_call_invoke,
 )
-from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+    is_video_file_type,
+)
 
 
 class VertexAIError(Exception):
@@ -1293,6 +1299,95 @@ async def async_streaming(
     return streamwrapper
 
 
+class VertexAITextEmbeddingConfig(BaseModel):
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+    Args:
+        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+    """
+
+    auto_truncate: Optional[bool] = None
+    task_type: Optional[
+        Literal[
+            "RETRIEVAL_QUERY",
+            "RETRIEVAL_DOCUMENT",
+            "SEMANTIC_SIMILARITY",
+            "CLASSIFICATION",
+            "CLUSTERING",
+            "QUESTION_ANSWERING",
+            "FACT_VERIFICATION",
+        ]
+    ] = None
+    title: Optional[str] = None
+
+    def __init__(
+        self,
+        auto_truncate: Optional[bool] = None,
+        task_type: Optional[
+            Literal[
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+                "QUESTION_ANSWERING",
+                "FACT_VERIFICATION",
+            ]
+        ] = None,
+        title: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "dimensions",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "dimensions":
+                optional_params["output_dimensionality"] = value
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
 
 
 def embedding(
     model: str,
     input: Union[list, str],
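
Editor's aside: a minimal sketch of the new config in action, assuming it is used exactly as exported via `litellm` in this commit's `__init__.py` change.

```python
# Minimal sketch: the config advertises only "dimensions" as a supported
# OpenAI param, and map_openai_params renames it to Vertex AI's
# "output_dimensionality".
import litellm

config = litellm.VertexAITextEmbeddingConfig()
assert config.get_supported_openai_params() == ["dimensions"]

mapped = config.map_openai_params(
    non_default_params={"dimensions": 256},  # OpenAI-style kwarg
    optional_params={},
)
assert mapped == {"output_dimensionality": 256}
```
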
@@ -1316,7 +1411,7 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )
 
-    from vertexai.language_models import TextEmbeddingModel
+    from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
     import google.auth # type: ignore
 
     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
@@ -1347,6 +1442,16 @@ def embedding(
     if isinstance(input, str):
         input = [input]
 
+    if optional_params is not None and isinstance(optional_params, dict):
+        if optional_params.get("task_type") or optional_params.get("title"):
+            # if user passed task_type or title, cast to TextEmbeddingInput
+            _task_type = optional_params.pop("task_type", None)
+            _title = optional_params.pop("title", None)
+            input = [
+                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
+                for x in input
+            ]
+
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
     except Exception as e:
@@ -1363,7 +1468,8 @@ def embedding(
             encoding=encoding,
         )
 
-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
@@ -1375,7 +1481,7 @@ def embedding(
     )
 
     try:
-        embeddings = llm_model.get_embeddings(input)
+        embeddings = llm_model.get_embeddings(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
@@ -1383,6 +1489,7 @@ def embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1391,14 +1498,10 @@ def embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-    input_tokens += len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
@@ -1420,7 +1523,8 @@ async def async_embedding(
     """
     Async embedding implementation
     """
-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
@@ -1432,7 +1536,7 @@ async def async_embedding(
     )
 
     try:
-        embeddings = await client.get_embeddings_async(input)
+        embeddings = await client.get_embeddings_async(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
 
@@ -1440,6 +1544,7 @@ async def async_embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1448,18 +1553,13 @@ async def async_embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-    input_tokens += len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
     model_response.usage = usage
     return model_response
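
Editor's note on the usage accounting above: the commit replaces local token counting (`len(encoding.encode("".join(input)))`) with the per-embedding statistics returned by Vertex. One reason this matters: once `task_type` or `title` is set, `input` holds `TextEmbeddingInput` objects rather than plain strings, so `"".join(input)` would fail. A hedged sketch of the new accumulation:

```python
# Sketch: sum the token counts reported by the Vertex SDK itself.
# `embeddings` is the list returned by llm_model.get_embeddings(...);
# each item carries .statistics.token_count, as used in the diff above.
input_tokens = sum(int(e.statistics.token_count) for e in embeddings)
```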

View file

@@ -814,10 +814,17 @@ def test_vertexai_embedding_embedding_latest():
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
         response = embedding(
             model="vertex_ai/text-embedding-004",
-            input=["good morning from litellm", "this is another item"],
+            input=["hi"],
+            dimensions=1,
+            auto_truncate=True,
+            task_type="RETRIEVAL_QUERY",
         )
+
+        assert len(response.data[0]["embedding"]) == 1
+        assert response.usage.prompt_tokens > 0
         print(f"response:", response)
     except litellm.RateLimitError as e:
         pass
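
Editor's note: the updated test doubles as an end-to-end check of the new mapping; with `dimensions=1`, the returned vector must have length 1, and `usage.prompt_tokens` must now be non-zero because it comes from Vertex's token statistics rather than a local tokenizer.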

View file

@@ -4898,6 +4898,18 @@ def get_optional_params_embeddings(
         )
         final_params = {**optional_params, **kwargs}
         return final_params
+    if custom_llm_provider == "vertex_ai":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="vertex_ai",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params={}
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     if custom_llm_provider == "vertex_ai":
         if len(non_default_params.keys()) > 0:
             if litellm.drop_params is True:  # drop the unsupported non-default values
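
Worth flagging for reviewers: because the new `vertex_ai` branch returns unconditionally, the pre-existing `vertex_ai` drop-params block directly below it appears to become unreachable for embedding calls; presumably it is left behind as dead code rather than removed in this commit.
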
@@ -6382,7 +6394,10 @@ def get_supported_openai_params(
     elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
         return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
     elif custom_llm_provider == "vertex_ai":
-        return litellm.VertexAIConfig().get_supported_openai_params()
+        if request_type == "chat_completion":
+            return litellm.VertexAIConfig().get_supported_openai_params()
+        elif request_type == "embeddings":
+            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
     elif custom_llm_provider == "aleph_alpha":
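
Editor's aside to close: a hedged sketch of the new `request_type` dispatch, using the same call shape the `get_optional_params_embeddings` branch above uses (model names are illustrative).

```python
# Sketch: request_type now selects between the chat and embedding configs.
from litellm.utils import get_supported_openai_params

chat_params = get_supported_openai_params(
    model="gemini-pro",  # illustrative
    custom_llm_provider="vertex_ai",
    request_type="chat_completion",
)  # -> VertexAIConfig().get_supported_openai_params()

embed_params = get_supported_openai_params(
    model="text-embedding-004",  # illustrative
    custom_llm_provider="vertex_ai",
    request_type="embeddings",
)
assert embed_params == ["dimensions"]
```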