diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py
new file mode 100644
index 000000000..59cbddcea
--- /dev/null
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py
@@ -0,0 +1,155 @@
+# What is this?
+## API Handler for calling Vertex AI Model Garden Models
+import types
+from enum import Enum
+from typing import Callable, Literal, Optional, Union
+
+import httpx  # type: ignore
+
+import litellm
+from litellm.utils import ModelResponse
+
+from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
+
+
+def create_vertex_url(
+    vertex_location: str,
+    vertex_project: str,
+    stream: Optional[bool],
+    model: str,
+    api_base: Optional[str] = None,
+) -> str:
+    """Return the default endpoint URL for a Vertex AI Model Garden deployment.
+
+    `model` fills the `endpoints/{model}` path segment. `stream` and
+    `api_base` are accepted but are not used when building the URL.
+    """
+    return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}"
+
+
+class VertexAIModelGardenModels(VertexBase):
+    """Chat-completion handler for models served from a Vertex AI endpoint."""
+
+    def __init__(self) -> None:
+        pass
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        api_base: Optional[str],
+        optional_params: dict,
+        custom_prompt_dict: dict,
+        headers: Optional[dict],
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        logger_fn=None,
+        acompletion: bool = False,
+        client=None,
+    ):
+        """Send a chat completion request to a Vertex AI Model Garden endpoint.
+
+        Fetches an access token, builds the endpoint URL for `model`, and
+        forwards the request to `DatabricksChatCompletion.completion` with
+        `custom_llm_provider="vertex_ai"`.
+
+        `vertex_location` falls back to "us-central1" and `vertex_project` to
+        the project id returned alongside the access token. `headers` is
+        accepted but not used here. `optional_params["stream"]` is normalized
+        to a bool in place.
+
+        Raises:
+            VertexAIError: status 400 when the Vertex AI SDK cannot be imported
+                or does not expose `vertexai.preview`; status 500 wrapping any
+                exception raised while preparing or sending the request.
+        """
+        try:
+            import vertexai
+
+            from litellm.llms.databricks.chat import DatabricksChatCompletion
+            from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
+                VertexLLM,
+            )
+        except Exception as e:
+            raise VertexAIError(
+                status_code=400,
+                message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
+            ) from e
+
+        # Check only for `preview`: probing attributes of `vertexai.preview`
+        # when it is absent would raise AttributeError rather than the
+        # VertexAIError callers expect.
+        if not hasattr(vertexai, "preview"):
+            raise VertexAIError(
+                status_code=400,
+                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
+            )
+
+        try:
+            vertex_httpx_logic = VertexLLM()
+
+            access_token, project_id = vertex_httpx_logic._ensure_access_token(
+                credentials=vertex_credentials,
+                project_id=vertex_project,
+                custom_llm_provider="vertex_ai",
+            )
+
+            openai_like_chat_completions = DatabricksChatCompletion()
+
+            ## CONSTRUCT API BASE
+            stream: bool = optional_params.get("stream", False) or False
+            optional_params["stream"] = stream
+            default_api_base = create_vertex_url(
+                vertex_location=vertex_location or "us-central1",
+                vertex_project=vertex_project or project_id,
+                stream=stream,
+                model=model,
+            )
+
+            # NOTE(review): the default URL always contains "https:", so this
+            # yields everything after the scheme rather than a ":<method>"
+            # suffix -- confirm this is what `_check_custom_proxy` expects.
+            if len(default_api_base.split(":")) > 1:
+                endpoint = default_api_base.split(":")[-1]
+            else:
+                endpoint = ""
+
+            _, api_base = self._check_custom_proxy(
+                api_base=api_base,
+                custom_llm_provider="vertex_ai",
+                gemini_api_key=None,
+                endpoint=endpoint,
+                stream=stream,
+                auth_header=None,
+                url=default_api_base,
+            )
+
+            return openai_like_chat_completions.completion(
+                model=model,
+                messages=messages,
+                api_base=api_base,
+                api_key=access_token,
+                custom_prompt_dict=custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                acompletion=acompletion,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                client=client,
+                timeout=timeout,
+                encoding=encoding,
+                custom_llm_provider="vertex_ai",
+            )
+
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e)) from e
diff --git a/litellm/main.py b/litellm/main.py
index 543a93eea..d3bc3f2c5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -158,6 +158,9 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
 from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
     VertexEmbedding,
 )
+from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
+    VertexAIModelGardenModels,
+)
 from .llms.watsonx.chat.handler import WatsonXChatHandler
 from .llms.watsonx.completion.handler import IBMWatsonXAI
 from .types.llms.openai import (
@@ -221,6 +224,7 @@ vertex_multimodal_embedding = VertexMultimodalEmbedding()
 vertex_image_generation = VertexImageGeneration()
 google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
+vertex_model_garden_chat_completion = VertexAIModelGardenModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
 watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
@@ -2355,6 +2359,27 @@ def completion(  # type: ignore # noqa: PLR0915
                     api_base=api_base,
                     extra_headers=extra_headers,
                 )
+            elif model.isdigit():
+                model_response = vertex_model_garden_chat_completion.completion(
+                    model=model,
+                    messages=messages,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=new_params,
+                    litellm_params=litellm_params,  # type: ignore
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    api_base=api_base,
+                    vertex_location=vertex_ai_location,
+                    vertex_project=vertex_ai_project,
+                    vertex_credentials=vertex_credentials,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    headers=headers,
+                    custom_prompt_dict=custom_prompt_dict,
+                    timeout=timeout,
+                    client=client,
+                )
             else:
                 model_response = vertex_ai_non_gemini.completion(
                     model=model,