feat - support vertex ai dimensions

This commit is contained in:
Ishaan Jaff 2024-06-12 09:29:51 -07:00
parent dd22f6aca0
commit e4b36d71cf
3 changed files with 97 additions and 11 deletions

View file

@ -765,7 +765,7 @@ from .llms.gemini import GeminiConfig
from .llms.nlp_cloud import NLPCloudConfig from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig from .llms.petals import PetalsConfig
from .llms.vertex_ai import VertexAIConfig from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.sagemaker import SagemakerConfig from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig from .llms.ollama import OllamaConfig

View file

@ -12,7 +12,12 @@ from litellm.llms.prompt_templates.factory import (
convert_to_gemini_tool_call_result, convert_to_gemini_tool_call_result,
convert_to_gemini_tool_call_invoke, convert_to_gemini_tool_call_invoke,
) )
from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type from litellm.types.files import (
get_file_mime_type_for_file_type,
get_file_type_from_extension,
is_gemini_1_5_accepted_file_type,
is_video_file_type,
)
class VertexAIError(Exception): class VertexAIError(Exception):
@ -301,15 +306,15 @@ def _process_gemini_image(image_url: str) -> PartType:
# GCS URIs # GCS URIs
if "gs://" in image_url: if "gs://" in image_url:
# Figure out file type # Figure out file type
extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png" extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
extension = extension_with_dot[1:] # Ex: "png" extension = extension_with_dot[1:] # Ex: "png"
file_type = get_file_type_from_extension(extension) file_type = get_file_type_from_extension(extension)
# Validate the file type is supported by Gemini # Validate the file type is supported by Gemini
if not is_gemini_1_5_accepted_file_type(file_type): if not is_gemini_1_5_accepted_file_type(file_type):
raise Exception(f"File type not supported by gemini - {file_type}") raise Exception(f"File type not supported by gemini - {file_type}")
mime_type = get_file_mime_type_for_file_type(file_type) mime_type = get_file_mime_type_for_file_type(file_type)
file_data = FileDataType(mime_type=mime_type, file_uri=image_url) file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
@ -320,7 +325,7 @@ def _process_gemini_image(image_url: str) -> PartType:
image = _load_image_from_url(image_url) image = _load_image_from_url(image_url)
_blob = BlobType(data=image.data, mime_type=image._mime_type) _blob = BlobType(data=image.data, mime_type=image._mime_type)
return PartType(inline_data=_blob) return PartType(inline_data=_blob)
# Base64 encoding # Base64 encoding
elif "base64" in image_url: elif "base64" in image_url:
import base64, re import base64, re
@ -1293,6 +1298,70 @@ async def async_streaming(
return streamwrapper return streamwrapper
class VertexAITextEmbeddingConfig:
    """
    Configuration and OpenAI-parameter mapping for Vertex AI text embeddings.

    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput

    Args:
        auto_truncate: (bool) If True, will truncate input text to fit within the model's max input length.
    """

    auto_truncate: Optional[bool] = None

    def __init__(
        self,
        auto_truncate: Optional[bool] = None,
    ) -> None:
        # NOTE: `locals()` must be captured before any other local is bound,
        # otherwise that local would also be copied onto the class below.
        locals_ = locals()
        for key, value in locals_.items():
            # Values are stored on the class (not the instance), so every
            # non-None argument becomes visible to `get_config()`.
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level settings that are currently set.

        Dunder entries, functions / classmethods / staticmethods and
        attributes whose value is None are left out.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        """Return the OpenAI embedding params this config knows how to map."""
        return [
            "dimensions",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy supported OpenAI params into `optional_params` under their Vertex AI names.

        `dimensions` is written as `output_dimensionality`; every other key in
        `non_default_params` is ignored. `optional_params` is updated in place
        and also returned.
        """
        for param, value in non_default_params.items():
            if param == "dimensions":
                optional_params["output_dimensionality"] = value
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {"project": "vertex_project", "region_name": "vertex_location"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """
        Rename the common auth params found in `non_default_params` to their
        vertex-specific keys and store them in `optional_params`.

        Keys that are not part of `get_mapped_special_auth_params()` are
        ignored. `optional_params` is updated in place and also returned.
        """
        mapped_params = self.get_mapped_special_auth_params()
        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params
def embedding( def embedding(
model: str, model: str,
input: Union[list, str], input: Union[list, str],
@ -1363,7 +1432,8 @@ def embedding(
encoding=encoding, encoding=encoding,
) )
request_str = f"""embeddings = llm_model.get_embeddings({input})""" _input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL ## LOGGING PRE-CALL
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
@ -1375,7 +1445,7 @@ def embedding(
) )
try: try:
embeddings = llm_model.get_embeddings(input) embeddings = llm_model.get_embeddings(**_input_dict)
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))
@ -1420,7 +1490,8 @@ async def async_embedding(
""" """
Async embedding implementation Async embedding implementation
""" """
request_str = f"""embeddings = llm_model.get_embeddings({input})""" _input_dict = {"texts": input, **optional_params}
request_str = f"""embeddings = llm_model.get_embeddings({_input_dict})"""
## LOGGING PRE-CALL ## LOGGING PRE-CALL
logging_obj.pre_call( logging_obj.pre_call(
input=input, input=input,
@ -1432,7 +1503,7 @@ async def async_embedding(
) )
try: try:
embeddings = await client.get_embeddings_async(input) embeddings = await client.get_embeddings_async(**_input_dict)
except Exception as e: except Exception as e:
raise VertexAIError(status_code=500, message=str(e)) raise VertexAIError(status_code=500, message=str(e))

View file

@ -4898,6 +4898,18 @@ def get_optional_params_embeddings(
) )
final_params = {**optional_params, **kwargs} final_params = {**optional_params, **kwargs}
return final_params return final_params
if custom_llm_provider == "vertex_ai":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="vertex_ai",
request_type="embeddings",
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params(
non_default_params=non_default_params, optional_params={}
)
final_params = {**optional_params, **kwargs}
return final_params
if custom_llm_provider == "vertex_ai": if custom_llm_provider == "vertex_ai":
if len(non_default_params.keys()) > 0: if len(non_default_params.keys()) > 0:
if litellm.drop_params is True: # drop the unsupported non-default values if litellm.drop_params is True: # drop the unsupported non-default values
@ -6382,7 +6394,10 @@ def get_supported_openai_params(
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini": elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"] return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
elif custom_llm_provider == "vertex_ai": elif custom_llm_provider == "vertex_ai":
return litellm.VertexAIConfig().get_supported_openai_params() if request_type == "chat_completion":
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "sagemaker": elif custom_llm_provider == "sagemaker":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
elif custom_llm_provider == "aleph_alpha": elif custom_llm_provider == "aleph_alpha":