VertexAIModelGardenModels

This commit is contained in:
Ishaan Jaff 2024-11-15 12:36:49 -08:00
parent 8d28003099
commit 34c1dc675a
2 changed files with 28 additions and 5 deletions

View file

@ -1,5 +1,21 @@
# What is this?
## API Handler for calling Vertex AI Partner Models
"""
API Handler for calling Vertex AI Model Garden Models
Most Vertex Model Garden models are OpenAI-compatible, so this handler calls `openai_like_chat_completions`.
Usage:
response = litellm.completion(
model="vertex_ai/openai/5464397967697903616",
messages=[{"role": "user", "content": "Hello, how are you?"}],
)
Requests are sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`.
Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb
"""
import types
from enum import Enum
from typing import Callable, Literal, Optional, Union
@ -20,7 +36,7 @@ def create_vertex_url(
model: str,
api_base: Optional[str] = None,
) -> str:
"""Return the base url for the vertex partner models"""
"""Return the base url for the vertex garden models"""
# f"https://{self.endpoint.location}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{self.endpoint.location}"
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}"
@ -50,6 +66,11 @@ class VertexAIModelGardenModels(VertexBase):
acompletion: bool = False,
client=None,
):
"""
Handles calling Vertex AI Model Garden models in OpenAI-compatible format.
Requests are sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`.
"""
try:
import vertexai
from google.cloud import aiplatform
@ -76,6 +97,7 @@ class VertexAIModelGardenModels(VertexBase):
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
model = model.replace("openai/", "")
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
@ -110,7 +132,7 @@ class VertexAIModelGardenModels(VertexBase):
auth_header=None,
url=default_api_base,
)
model = ""
return openai_like_chat_completions.completion(
model=model,
messages=messages,

View file

@ -2359,7 +2359,8 @@ def completion( # type: ignore # noqa: PLR0915
api_base=api_base,
extra_headers=extra_headers,
)
elif model.isdigit():
elif "openai" in model:
# Vertex Model Garden - OpenAI compatible models
model_response = vertex_model_garden_chat_completion.completion(
model=model,
messages=messages,