forked from phoenix/litellm-mirror
add VertexAIModelGardenModels
This commit is contained in:
parent
3f8a9167ae
commit
8d28003099
2 changed files with 159 additions and 0 deletions
|
@ -0,0 +1,134 @@
|
|||
# What is this?
|
||||
## API Handler for calling models deployed on Vertex AI Model Garden endpoints
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
from ..common_utils import VertexAIError
|
||||
from ..vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
def create_vertex_url(
    vertex_location: str,
    vertex_project: str,
    stream: Optional[bool],
    model: str,
    api_base: Optional[str] = None,
) -> str:
    """Build the Vertex AI endpoint URL for a Model Garden deployment.

    Args:
        vertex_location: Location string, used both as the hostname prefix
            and inside the resource path.
        vertex_project: Project identifier placed in the resource path.
        stream: Accepted but not used when building the URL.
        model: Identifier placed after ``/endpoints/`` in the resource path.
        api_base: Accepted but not used when building the URL.

    Returns:
        The ``v1beta1`` endpoint URL hosted on
        ``{vertex_location}-aiplatform.googleapis.com``.
    """
    host = f"https://{vertex_location}-aiplatform.googleapis.com"
    resource_path = (
        f"/v1beta1/projects/{vertex_project}"
        f"/locations/{vertex_location}"
        f"/endpoints/{model}"
    )
    return host + resource_path
|
||||
|
||||
|
||||
class VertexAIModelGardenModels(VertexBase):
    """Chat-completion handler for models served from a Vertex AI endpoint.

    The request itself is delegated to ``DatabricksChatCompletion``; this
    class only resolves the access token and the endpoint URL before
    handing off.
    """

    def __init__(self) -> None:
        pass

    def completion(
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        vertex_project=None,
        vertex_location=None,
        vertex_credentials=None,
        logger_fn=None,
        acompletion: bool = False,
        client=None,
    ):
        """Send a chat completion request to a Vertex AI endpoint.

        Args:
            model: Endpoint identifier; it is placed in the request URL and
                also forwarded as the model name.
            messages: Chat messages, forwarded unchanged.
            model_response: Response object forwarded to the delegate.
            print_verbose: Verbose-logging callable, forwarded unchanged.
            encoding: Forwarded unchanged to the delegate.
            logging_obj: Logging object, forwarded unchanged.
            api_base: Optional custom base URL; resolved through
                ``self._check_custom_proxy``.
            optional_params: Request parameters. NOTE: mutated in place,
                ``optional_params["stream"]`` is always set to a bool.
            custom_prompt_dict: Forwarded unchanged to the delegate.
            headers: Accepted for signature parity; not used here.
            timeout: Request timeout, forwarded unchanged.
            litellm_params: Forwarded unchanged to the delegate.
            vertex_project: Project for the URL; falls back to the project
                returned together with the access token.
            vertex_location: Location for the URL; defaults to
                ``"us-central1"``.
            vertex_credentials: Credentials used to obtain the access token.
            logger_fn: Forwarded unchanged to the delegate.
            acompletion: Forwarded unchanged to the delegate.
            client: Forwarded unchanged to the delegate.

        Returns:
            Whatever ``DatabricksChatCompletion.completion`` returns.

        Raises:
            VertexAIError: status 400 if the required imports fail or the
                installed ``vertexai`` has no ``preview`` attribute;
                otherwise the integer ``status_code`` of the underlying
                failure when it has one, else 500.
        """
        try:
            import vertexai
            from google.cloud import aiplatform  # noqa: F401  (presence check only)

            from litellm.llms.databricks.chat import DatabricksChatCompletion
            from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
                VertexLLM,
            )
        except Exception as import_err:
            raise VertexAIError(
                status_code=400,
                message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
            ) from import_err

        # Only `preview` is checked. The former second operand,
        # `hasattr(vertexai.preview, "language_models")`, was evaluated only
        # when `preview` was missing, so it could never pass and raised
        # AttributeError instead of the error below.
        if not hasattr(vertexai, "preview"):
            raise VertexAIError(
                status_code=400,
                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
            )

        try:
            vertex_httpx_logic = VertexLLM()

            access_token, project_id = vertex_httpx_logic._ensure_access_token(
                credentials=vertex_credentials,
                project_id=vertex_project,
                custom_llm_provider="vertex_ai",
            )

            openai_like_chat_completions = DatabricksChatCompletion()

            ## CONSTRUCT API BASE
            # Normalise a missing/None "stream" to False and write it back.
            stream: bool = optional_params.get("stream", False) or False
            optional_params["stream"] = stream
            default_api_base = create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
                stream=stream,
                model=model,
            )

            # Text after the last ":" of the default URL, or "" without one.
            # NOTE(review): the default URL starts with "https:", so this is
            # always everything after the scheme colon - confirm that is the
            # `endpoint` value `_check_custom_proxy` expects.
            _, separator, tail = default_api_base.rpartition(":")
            endpoint = tail if separator else ""

            _, api_base = self._check_custom_proxy(
                api_base=api_base,
                custom_llm_provider="vertex_ai",
                gemini_api_key=None,
                endpoint=endpoint,
                stream=stream,
                auth_header=None,
                url=default_api_base,
            )

            return openai_like_chat_completions.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=access_token,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                logging_obj=logging_obj,
                optional_params=optional_params,
                acompletion=acompletion,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                client=client,
                timeout=timeout,
                encoding=encoding,
                custom_llm_provider="vertex_ai",
            )

        except Exception as e:
            # Keep the underlying status (e.g. a 4xx from the endpoint)
            # instead of reporting every failure as a 500.
            upstream_status = getattr(e, "status_code", None)
            raise VertexAIError(
                status_code=upstream_status if isinstance(upstream_status, int) else 500,
                message=str(e),
            ) from e
|
|
@ -158,6 +158,9 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
|
|||
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
|
||||
VertexEmbedding,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
|
||||
VertexAIModelGardenModels,
|
||||
)
|
||||
from .llms.watsonx.chat.handler import WatsonXChatHandler
|
||||
from .llms.watsonx.completion.handler import IBMWatsonXAI
|
||||
from .types.llms.openai import (
|
||||
|
@ -221,6 +224,7 @@ vertex_multimodal_embedding = VertexMultimodalEmbedding()
|
|||
vertex_image_generation = VertexImageGeneration()
|
||||
google_batch_embeddings = GoogleBatchEmbeddings()
|
||||
vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
||||
vertex_model_garden_chat_completion = VertexAIModelGardenModels()
|
||||
vertex_text_to_speech = VertexTextToSpeechAPI()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
sagemaker_llm = SagemakerLLM()
|
||||
|
@ -2355,6 +2359,27 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
elif model.isdigit():
|
||||
model_response = vertex_model_garden_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=new_params,
|
||||
litellm_params=litellm_params, # type: ignore
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
api_base=api_base,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_credentials=vertex_credentials,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
headers=headers,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
model_response = vertex_ai_non_gemini.completion(
|
||||
model=model,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue