mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat - support vertex ai dimensions
This commit is contained in:
parent
dd22f6aca0
commit
e4b36d71cf
3 changed files with 97 additions and 11 deletions
|
@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import (
|
|||
convert_to_gemini_tool_call_result,
|
||||
convert_to_gemini_tool_call_invoke,
|
||||
)
|
||||
from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type
|
||||
from litellm.types.files import (
|
||||
get_file_mime_type_for_file_type,
|
||||
get_file_type_from_extension,
|
||||
is_gemini_1_5_accepted_file_type,
|
||||
is_video_file_type,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
@ -301,15 +306,15 @@ def _process_gemini_image(image_url: str) -> PartType:
|
|||
# GCS URIs
|
||||
if "gs://" in image_url:
|
||||
# Figure out file type
|
||||
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
|
||||
extension = extension_with_dot[1:] # Ex: "png"
|
||||
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
|
||||
extension = extension_with_dot[1:] # Ex: "png"
|
||||
|
||||
file_type = get_file_type_from_extension(extension)
|
||||
|
||||
# Validate the file type is supported by Gemini
|
||||
if not is_gemini_1_5_accepted_file_type(file_type):
|
||||
raise Exception(f"File type not supported by gemini - {file_type}")
|
||||
|
||||
|
||||
mime_type = get_file_mime_type_for_file_type(file_type)
|
||||
file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
|
||||
|
||||
|
@ -320,7 +325,7 @@ def _process_gemini_image(image_url: str) -> PartType:
|
|||
image = _load_image_from_url(image_url)
|
||||
_blob = BlobType(data=image.data, mime_type=image._mime_type)
|
||||
return PartType(inline_data=_blob)
|
||||
|
||||
|
||||
# Base64 encoding
|
||||
elif "base64" in image_url:
|
||||
import base64, re
|
||||
|
@ -1293,6 +1298,70 @@ async def async_streaming(
|
|||
return streamwrapper
|
||||
|
||||
|
||||
class VertexAITextEmbeddingConfig:
    """
    Configuration holder and parameter mapper for Vertex AI text embeddings.

    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput

    Args:
        auto_truncate: (bool) If True, will truncate input text to fit within the model's max input length.

    Note:
        Values passed to the constructor are stored on the *class*, not on the
        instance, so ``get_config()`` (a classmethod) can read them back.
    """

    auto_truncate: Optional[bool] = None

    def __init__(
        self,
        auto_truncate: Optional[bool] = None,
    ) -> None:
        # Only explicitly provided (non-None) values overwrite the class-level
        # defaults; a None argument leaves the current class value untouched.
        provided = {"auto_truncate": auto_truncate}
        for attr_name, attr_value in provided.items():
            if attr_value is None:
                continue
            setattr(self.__class__, attr_name, attr_value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level settings that currently hold a non-None value.

        Dunder-prefixed names, plain/builtin functions, classmethods and
        staticmethods found in the class ``__dict__`` are excluded.
        """
        non_config_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__"):
                continue
            if isinstance(attr_value, non_config_kinds):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def get_supported_openai_params(self):
        """Return the OpenAI-style parameter names this config can translate."""
        supported = ["dimensions"]
        return supported

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy supported OpenAI-style params into ``optional_params`` under the
        key names used for the Vertex AI call.

        ``dimensions`` is written as ``output_dimensionality``; every other
        key in ``non_default_params`` is ignored. ``optional_params`` is
        mutated in place and also returned.
        """
        if "dimensions" in non_default_params:
            optional_params["output_dimensionality"] = non_default_params[
                "dimensions"
            ]
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        auth_param_names = {
            "project": "vertex_project",
            "region_name": "vertex_location",
        }
        return auth_param_names

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """
        Rename the generic auth params found in ``non_default_params`` to their
        Vertex-specific keys and store them in ``optional_params``.

        Keys without an entry in ``get_mapped_special_auth_params()`` are
        skipped. ``optional_params`` is mutated in place and also returned.
        """
        name_map = self.get_mapped_special_auth_params()
        translated = {
            name_map[source_key]: source_value
            for source_key, source_value in non_default_params.items()
            if source_key in name_map
        }
        optional_params.update(translated)
        return optional_params
|
||||
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: Union[list, str],
|
||||
|
@ -1363,7 +1432,8 @@ def embedding(
|
|||
encoding=encoding,
|
||||
)
|
||||
|
||||
request_str = f"""embeddings = llm_model.get_embeddings({input})"""
|
||||
_input_dict = {"texts": input, **optional_params}
|
||||
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
|
||||
## LOGGING PRE-CALL
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -1375,7 +1445,7 @@ def embedding(
|
|||
)
|
||||
|
||||
try:
|
||||
embeddings = llm_model.get_embeddings(input)
|
||||
embeddings = llm_model.get_embeddings(**_input_dict)
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
||||
|
@ -1420,7 +1490,8 @@ async def async_embedding(
|
|||
"""
|
||||
Async embedding implementation
|
||||
"""
|
||||
request_str = f"""embeddings = llm_model.get_embeddings({input})"""
|
||||
_input_dict = {"texts": input, **optional_params}
|
||||
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
|
||||
## LOGGING PRE-CALL
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -1432,7 +1503,7 @@ async def async_embedding(
|
|||
)
|
||||
|
||||
try:
|
||||
embeddings = await client.get_embeddings_async(input)
|
||||
embeddings = await client.get_embeddings_async(**_input_dict)
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue