mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
vertex embeddings optional_params
This commit is contained in:
parent
3f1dfc1661
commit
9aa4b6e98c
1 changed files with 38 additions and 15 deletions
|
@ -4,6 +4,7 @@ from enum import Enum
|
|||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional, Union, List, Literal, Any
|
||||
from pydantic import BaseModel
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
|
||||
import litellm, uuid
|
||||
import httpx, inspect # type: ignore
|
||||
|
@ -1298,20 +1299,45 @@ async def async_streaming(
|
|||
return streamwrapper
|
||||
|
||||
|
||||
class VertexAITextEmbeddingConfig:
|
||||
class VertexAITextEmbeddingConfig(BaseModel):
|
||||
"""
|
||||
Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
|
||||
|
||||
Args:
|
||||
auto_truncate: (bool)If True, will truncate input text to fit within the model's max input length.
|
||||
|
||||
auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
|
||||
task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
|
||||
title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
|
||||
"""
|
||||
|
||||
auto_truncate: Optional[bool] = None
|
||||
task_type: Optional[
|
||||
Literal[
|
||||
"RETRIEVAL_QUERY",
|
||||
"RETRIEVAL_DOCUMENT",
|
||||
"SEMANTIC_SIMILARITY",
|
||||
"CLASSIFICATION",
|
||||
"CLUSTERING",
|
||||
"QUESTION_ANSWERING",
|
||||
"FACT_VERIFICATION",
|
||||
]
|
||||
] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auto_truncate: Optional[bool] = None,
|
||||
task_type: Optional[
|
||||
Literal[
|
||||
"RETRIEVAL_QUERY",
|
||||
"RETRIEVAL_DOCUMENT",
|
||||
"SEMANTIC_SIMILARITY",
|
||||
"CLASSIFICATION",
|
||||
"CLUSTERING",
|
||||
"QUESTION_ANSWERING",
|
||||
"FACT_VERIFICATION",
|
||||
]
|
||||
] = None,
|
||||
title: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
|
@ -1416,18 +1442,15 @@ def embedding(
|
|||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
"""
|
||||
VertexAI supports passing embedding input like this:
|
||||
input=[
|
||||
{
|
||||
"text": "good morning from litellm",
|
||||
"task_type": "RETRIEVAL_DOCUMENT"
|
||||
}
|
||||
],
|
||||
|
||||
In this scenario we cast it to TextEmbeddingInput
|
||||
"""
|
||||
input = [TextEmbeddingInput(**x) for x in input if isinstance(x, dict)]
|
||||
if optional_params is not None and isinstance(optional_params, dict):
|
||||
if optional_params.get("task_type") or optional_params.get("title"):
|
||||
# if user passed task_type or title, cast to TextEmbeddingInput
|
||||
_task_type = optional_params.pop("task_type", None)
|
||||
_title = optional_params.pop("title", None)
|
||||
input = [
|
||||
TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
|
||||
for x in input
|
||||
]
|
||||
|
||||
try:
|
||||
llm_model = TextEmbeddingModel.from_pretrained(model)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue