vertex embeddings optional_params

This commit is contained in:
Ishaan Jaff 2024-06-12 10:56:36 -07:00
parent 8eacc99cfc
commit ccfc988d71

View file

@@ -4,6 +4,7 @@ from enum import Enum
import requests # type: ignore import requests # type: ignore
import time import time
from typing import Callable, Optional, Union, List, Literal, Any from typing import Callable, Optional, Union, List, Literal, Any
from pydantic import BaseModel
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid import litellm, uuid
import httpx, inspect # type: ignore import httpx, inspect # type: ignore
@@ -1298,20 +1299,45 @@ async def async_streaming(
return streamwrapper return streamwrapper
class VertexAITextEmbeddingConfig: class VertexAITextEmbeddingConfig(BaseModel):
""" """
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
Args: Args:
auto_truncate: (bool)If True, will truncate input text to fit within the model's max input length. auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
""" """
auto_truncate: Optional[bool] = None auto_truncate: Optional[bool] = None
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None
title: Optional[str] = None
def __init__( def __init__(
self, self,
auto_truncate: Optional[bool] = None, auto_truncate: Optional[bool] = None,
task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
title: Optional[str] = None,
) -> None: ) -> None:
locals_ = locals() locals_ = locals()
for key, value in locals_.items(): for key, value in locals_.items():
@@ -1416,18 +1442,15 @@ def embedding(
if isinstance(input, str): if isinstance(input, str):
input = [input] input = [input]
""" if optional_params is not None and isinstance(optional_params, dict):
VertexAI supports passing embedding input like this: if optional_params.get("task_type") or optional_params.get("title"):
input=[ # if user passed task_type or title, cast to TextEmbeddingInput
{ _task_type = optional_params.pop("task_type", None)
"text": "good morning from litellm", _title = optional_params.pop("title", None)
"task_type": "RETRIEVAL_DOCUMENT" input = [
} TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
], for x in input
]
In this scenario we cast it to TextEmbeddingInput
"""
input = [TextEmbeddingInput(**x) for x in input if isinstance(x, dict)]
try: try:
llm_model = TextEmbeddingModel.from_pretrained(model) llm_model = TextEmbeddingModel.from_pretrained(model)