Merge branch 'main' into litellm_vertex_completion_httpx

commit 05e21441a6
Krish Dholakia, 2024-06-12 21:19:22 -07:00 (committed by GitHub)
56 changed files with 568 additions and 145 deletions

litellm/llms/vertex_ai.py

@@ -4,6 +4,7 @@ from enum import Enum
 import requests  # type: ignore
 import time
 from typing import Callable, Optional, Union, List, Literal, Any
+from pydantic import BaseModel
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx, inspect  # type: ignore
@@ -1298,6 +1299,95 @@ async def async_streaming(
     return streamwrapper


+class VertexAITextEmbeddingConfig(BaseModel):
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+    Args:
+        auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+        task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+        title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+    """
+
+    auto_truncate: Optional[bool] = None
+    task_type: Optional[
+        Literal[
+            "RETRIEVAL_QUERY",
+            "RETRIEVAL_DOCUMENT",
+            "SEMANTIC_SIMILARITY",
+            "CLASSIFICATION",
+            "CLUSTERING",
+            "QUESTION_ANSWERING",
+            "FACT_VERIFICATION",
+        ]
+    ] = None
+    title: Optional[str] = None
+
+    def __init__(
+        self,
+        auto_truncate: Optional[bool] = None,
+        task_type: Optional[
+            Literal[
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+                "QUESTION_ANSWERING",
+                "FACT_VERIFICATION",
+            ]
+        ] = None,
+        title: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "dimensions",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "dimensions":
+                optional_params["output_dimensionality"] = value
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+
 def embedding(
     model: str,
     input: Union[list, str],
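
The class above wires Vertex AI text embeddings into litellm's OpenAI-compatible parameter handling: get_supported_openai_params advertises only "dimensions", and map_openai_params rewrites it to Vertex AI's output_dimensionality. A minimal sketch of that translation, assuming the class is importable from litellm.llms.vertex_ai (the file this diff modifies):

    # Sketch only; import path assumes this diff's file layout.
    from litellm.llms.vertex_ai import VertexAITextEmbeddingConfig

    config = VertexAITextEmbeddingConfig()
    assert config.get_supported_openai_params() == ["dimensions"]

    # OpenAI-style `dimensions` becomes Vertex AI's `output_dimensionality`.
    optional_params = config.map_openai_params(
        non_default_params={"dimensions": 256},
        optional_params={},
    )
    print(optional_params)  # {'output_dimensionality': 256}
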
@@ -1321,7 +1411,7 @@ def embedding(
             message="vertexai import failed please run `pip install google-cloud-aiplatform`",
         )

-    from vertexai.language_models import TextEmbeddingModel
+    from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput
     import google.auth  # type: ignore

     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
@@ -1352,6 +1442,16 @@ def embedding(
     if isinstance(input, str):
         input = [input]

+    if optional_params is not None and isinstance(optional_params, dict):
+        if optional_params.get("task_type") or optional_params.get("title"):
+            # if user passed task_type or title, cast to TextEmbeddingInput
+            _task_type = optional_params.pop("task_type", None)
+            _title = optional_params.pop("title", None)
+            input = [
+                TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
+                for x in input
+            ]
+
     try:
         llm_model = TextEmbeddingModel.from_pretrained(model)
     except Exception as e:
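
The added block wraps each raw string in the Vertex SDK's TextEmbeddingInput, so the popped task_type and title travel with every text instead of being passed as loose kwargs. A standalone sketch of what the list comprehension produces (requires google-cloud-aiplatform; the sample texts and title are placeholders):

    from vertexai.language_models import TextEmbeddingInput

    texts = ["hello world", "litellm supports vertex ai"]
    _task_type = "RETRIEVAL_DOCUMENT"
    _title = "greeting docs"  # only meaningful with RETRIEVAL_DOCUMENT

    inputs = [
        TextEmbeddingInput(text=x, task_type=_task_type, title=_title)
        for x in texts
    ]
    # Each element now carries the text plus task metadata, not a bare str.
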
@@ -1368,7 +1468,8 @@ def embedding(
             encoding=encoding,
         )

-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
@@ -1380,7 +1481,7 @@ def embedding(
     )

     try:
-        embeddings = llm_model.get_embeddings(input)
+        embeddings = llm_model.get_embeddings(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
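
Both call sites now build _input_dict and splat it, so anything left in optional_params (for example output_dimensionality, mapped from dimensions) reaches the SDK as a keyword argument. A hedged sketch of the equivalent direct SDK call; project, location, and model name are placeholders, and it assumes get_embeddings accepts output_dimensionality as in mid-2024 SDK releases:

    import vertexai
    from vertexai.language_models import TextEmbeddingModel

    vertexai.init(project="my-project", location="us-central1")  # placeholders
    llm_model = TextEmbeddingModel.from_pretrained("text-embedding-004")  # placeholder model

    _input_dict = {"texts": ["hello world"], "output_dimensionality": 256}
    # Same shape as the diff's call: kwargs are splatted, so mapped
    # optional_params reach the SDK directly.
    embeddings = llm_model.get_embeddings(**_input_dict)
    print(len(embeddings[0].values))  # 256 when the model supports shortened outputs
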
@@ -1388,6 +1489,7 @@ def embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1396,14 +1498,10 @@ def embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
+
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-    input_tokens += len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
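
Usage accounting changes with this hunk: instead of re-tokenizing the joined input locally with `encoding`, prompt tokens are summed from the per-result statistics Vertex returns. A runnable sketch with stand-in objects mirroring the shape of the SDK's TextEmbedding/TextEmbeddingStatistics:

    # Stand-in objects; the real ones come back from get_embeddings(...).
    from types import SimpleNamespace

    embeddings = [
        SimpleNamespace(values=[0.1] * 3, statistics=SimpleNamespace(token_count=5)),
        SimpleNamespace(values=[0.2] * 3, statistics=SimpleNamespace(token_count=7)),
    ]

    input_tokens = sum(e.statistics.token_count for e in embeddings)
    print(input_tokens)  # 12, as reported by the API rather than re-tokenized locally
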
@@ -1425,7 +1523,8 @@ async def async_embedding(
     """
     Async embedding implementation
     """
-    request_str = f"""embeddings = llm_model.get_embeddings({input})"""
+    _input_dict = {"texts": input, **optional_params}
+    request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
     ## LOGGING PRE-CALL
     logging_obj.pre_call(
         input=input,
@@ -1437,7 +1536,7 @@ async def async_embedding(
     )

     try:
-        embeddings = await client.get_embeddings_async(input)
+        embeddings = await client.get_embeddings_async(**_input_dict)
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
@@ -1445,6 +1544,7 @@ async def async_embedding(
     logging_obj.post_call(input=input, api_key=None, original_response=embeddings)
     ## Populate OpenAI compliant dictionary
     embedding_response = []
+    input_tokens: int = 0
     for idx, embedding in enumerate(embeddings):
         embedding_response.append(
             {
@@ -1453,18 +1553,13 @@ async def async_embedding(
                 "embedding": embedding.values,
             }
         )
+        input_tokens += embedding.statistics.token_count
+
     model_response["object"] = "list"
     model_response["data"] = embedding_response
     model_response["model"] = model
-    input_tokens = 0
-
-    input_str = "".join(input)
-    input_tokens += len(encoding.encode(input_str))
     usage = Usage(
         prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
     )
     model_response.usage = usage

     return model_response
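
End to end, the diff lets an OpenAI-style embedding call carry dimensions, task_type, and title through to Vertex AI. A hedged usage sketch; the model name, project, and location are placeholders, credentials must already be configured, and it assumes litellm forwards these extra kwargs into optional_params the way this code reads them:

    import litellm

    response = litellm.embedding(
        model="vertex_ai/text-embedding-004",  # placeholder model name
        input=["hello world"],
        dimensions=256,                  # mapped to output_dimensionality
        task_type="RETRIEVAL_DOCUMENT",  # cast into TextEmbeddingInput
        title="greeting docs",           # only valid with RETRIEVAL_DOCUMENT
        vertex_project="my-project",     # placeholder
        vertex_location="us-central1",   # placeholder
    )
    print(response.usage.prompt_tokens)  # token count reported by Vertex
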