vertex embeddings optional_params

This commit is contained in:
Ishaan Jaff 2024-06-12 10:56:36 -07:00
parent 8eacc99cfc
commit ccfc988d71

View file

@ -4,6 +4,7 @@ from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional, Union, List, Literal, Any
from pydantic import BaseModel
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
@ -1298,20 +1299,45 @@ async def async_streaming(
return streamwrapper
class VertexAITextEmbeddingConfig:
class VertexAITextEmbeddingConfig(BaseModel):
"""
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
Args:
auto_truncate: (bool)If True, will truncate input text to fit within the model's max input length.
auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
"""
auto_truncate: Optional[bool] = None
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None
title: Optional[str] = None
def __init__(
self,
auto_truncate: Optional[bool] = None,
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
title: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
@ -1416,18 +1442,15 @@ def embedding(
if isinstance(input, str):
input = [input]
"""
VertexAI supports passing embedding input like this:
input=[
{
"text": "good morning from litellm",
"task_type": "RETRIEVAL_DOCUMENT"
}
],
In this scenario we cast it to TextEmbeddingInput
"""
input = [TextEmbeddingInput(**x) for x in input if isinstance(x, dict)]
if optional_params is not None and isinstance(optional_params, dict):
if optional_params.get("task_type") or optional_params.get("title"):
# if user passed task_type or title, cast to TextEmbeddingInput
_task_type = optional_params.pop("task_type", None)
_title = optional_params.pop("title", None)
input = [
TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
for x in input
]
try:
llm_model = TextEmbeddingModel.from_pretrained(model)