From 8d2800309925fb454338a526a5e568acf900c2a9 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 15 Nov 2024 10:00:23 -0800 Subject: [PATCH 1/4] add VertexAIModelGardenModels --- .../vertex_model_garden/main.py | 134 ++++++++++++++++++ litellm/main.py | 25 ++++ 2 files changed, 159 insertions(+) create mode 100644 litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py new file mode 100644 index 000000000..59cbddcea --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py @@ -0,0 +1,134 @@ +# What is this? +## API Handler for calling Vertex AI Partner Models +import types +from enum import Enum +from typing import Callable, Literal, Optional, Union + +import httpx # type: ignore + +import litellm +from litellm.utils import ModelResponse + +from ..common_utils import VertexAIError +from ..vertex_llm_base import VertexBase + + +def create_vertex_url( + vertex_location: str, + vertex_project: str, + stream: Optional[bool], + model: str, + api_base: Optional[str] = None, +) -> str: + """Return the base url for the vertex partner models""" + # f"https://{self.endpoint.location}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{self.endpoint.location}" + return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}" + + +class VertexAIModelGardenModels(VertexBase): + def __init__(self) -> None: + pass + + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + api_base: Optional[str], + optional_params: dict, + custom_prompt_dict: dict, + headers: Optional[dict], + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + logger_fn=None, + acompletion: bool = False, + client=None, + ): + try: + import vertexai + from google.cloud import aiplatform + + from litellm.llms.anthropic.chat import AnthropicChatCompletion + from litellm.llms.databricks.chat import DatabricksChatCompletion + from litellm.llms.OpenAI.openai import OpenAIChatCompletion + from litellm.llms.text_completion_codestral import CodestralTextCompletion + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( + VertexLLM, + ) + except Exception: + + raise VertexAIError( + status_code=400, + message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""", + ) + + if not ( + hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") + ): + raise VertexAIError( + status_code=400, + message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", + ) + try: + vertex_httpx_logic = VertexLLM() + + access_token, project_id = vertex_httpx_logic._ensure_access_token( + credentials=vertex_credentials, + project_id=vertex_project, + custom_llm_provider="vertex_ai", + ) + + openai_like_chat_completions = DatabricksChatCompletion() + + ## CONSTRUCT API BASE + stream: bool = optional_params.get("stream", False) or False + optional_params["stream"] = stream + default_api_base = create_vertex_url( + vertex_location=vertex_location or "us-central1", + vertex_project=vertex_project or project_id, + stream=stream, + model=model, + ) + + if len(default_api_base.split(":")) > 1: + endpoint = default_api_base.split(":")[-1] + else: + endpoint = "" + + _, api_base = self._check_custom_proxy( + api_base=api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint=endpoint, + stream=stream, + auth_header=None, + url=default_api_base, + ) + + return openai_like_chat_completions.completion( + model=model, + messages=messages, + api_base=api_base, + api_key=access_token, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + logging_obj=logging_obj, + optional_params=optional_params, + acompletion=acompletion, + litellm_params=litellm_params, + logger_fn=logger_fn, + client=client, + timeout=timeout, + encoding=encoding, + custom_llm_provider="vertex_ai", + ) + + except Exception as e: + raise VertexAIError(status_code=500, message=str(e)) diff --git a/litellm/main.py b/litellm/main.py index 543a93eea..d3bc3f2c5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -158,6 +158,9 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import ( from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import ( VertexEmbedding, ) +from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import ( + VertexAIModelGardenModels, +) from .llms.watsonx.chat.handler import WatsonXChatHandler from .llms.watsonx.completion.handler import IBMWatsonXAI from .types.llms.openai import ( @@ -221,6 +224,7 @@ vertex_multimodal_embedding = VertexMultimodalEmbedding() vertex_image_generation = VertexImageGeneration() google_batch_embeddings = GoogleBatchEmbeddings() vertex_partner_models_chat_completion = VertexAIPartnerModels() +vertex_model_garden_chat_completion = VertexAIModelGardenModels() vertex_text_to_speech = VertexTextToSpeechAPI() watsonxai = IBMWatsonXAI() sagemaker_llm = SagemakerLLM() @@ -2355,6 +2359,27 @@ def completion( # type: ignore # noqa: PLR0915 api_base=api_base, extra_headers=extra_headers, ) + elif model.isdigit(): + model_response = vertex_model_garden_chat_completion.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + api_base=api_base, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + headers=headers, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, + client=client, + ) else: model_response = vertex_ai_non_gemini.completion( model=model, From 34c1dc675aa5d99ef9ec33d0ad47228c2f6d9e36 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 15 Nov 2024 12:36:49 -0800 Subject: [PATCH 2/4] VertexAIModelGardenModels --- .../vertex_model_garden/main.py | 30 ++++++++++++++++--- litellm/main.py | 3 +- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py index 59cbddcea..4285c4dcb 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py @@ -1,5 +1,21 @@ -# What is this? -## API Handler for calling Vertex AI Partner Models +""" +API Handler for calling Vertex AI Model Garden Models + +Most Vertex Model Garden Models are OpenAI compatible - so this handler calls `openai_like_chat_completions` + +Usage: + +response = litellm.completion( + model="vertex_ai/openai/5464397967697903616", + messages=[{"role": "user", "content": "Hello, how are you?"}], +) + +Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}` + + +Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb +""" + import types from enum import Enum from typing import Callable, Literal, Optional, Union @@ -20,7 +36,7 @@ def create_vertex_url( model: str, api_base: Optional[str] = None, ) -> str: - """Return the base url for the vertex partner models""" + """Return the base url for the vertex garden models""" # f"https://{self.endpoint.location}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{self.endpoint.location}" return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}" @@ -50,6 +66,11 @@ class VertexAIModelGardenModels(VertexBase): acompletion: bool = False, client=None, ): + """ + Handles calling Vertex AI Model Garden Models in OpenAI compatible format + + Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}` + """ try: import vertexai from google.cloud import aiplatform @@ -76,6 +97,7 @@ class VertexAIModelGardenModels(VertexBase): message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", ) try: + model = model.replace("openai/", "") vertex_httpx_logic = VertexLLM() access_token, project_id = vertex_httpx_logic._ensure_access_token( @@ -110,7 +132,7 @@ class VertexAIModelGardenModels(VertexBase): auth_header=None, url=default_api_base, ) - + model = "" return openai_like_chat_completions.completion( model=model, messages=messages, diff --git a/litellm/main.py b/litellm/main.py index d3bc3f2c5..3b4a99413 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2359,7 +2359,8 @@ def completion( # type: ignore # noqa: PLR0915 api_base=api_base, extra_headers=extra_headers, ) - elif model.isdigit(): + elif "openai" in model: + # Vertex Model Garden - OpenAI compatible models model_response = vertex_model_garden_chat_completion.completion( model=model, messages=messages, From 0b8aa778bf6b83b81ff334068054dd8552826552 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 15 Nov 2024 12:45:15 -0800 Subject: [PATCH 3/4] test_vertexai_model_garden_model_completion --- .../test_amazing_vertex_completion.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 5a07d17b7..e8fb67478 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -3123,3 +3123,85 @@ async def test_vertexai_embedding_finetuned(respx_mock: MockRouter): assert isinstance(embedding["embedding"], list) assert len(embedding["embedding"]) > 0 assert all(isinstance(x, float) for x in embedding["embedding"]) + + +@pytest.mark.asyncio +@pytest.mark.respx +async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter): + """ + Relevant issue: https://github.com/BerriAI/litellm/issues/6480 + + Using OpenAI compatible models from Vertex Model Garden + """ + load_vertex_ai_credentials() + litellm.set_verbose = True + + # Test input + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + ] + + # Expected request/response + expected_url = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/633608382793/locations/us-central1/endpoints/5464397967697903616/chat/completions" + expected_request = {"model": "", "messages": messages, "stream": False} + + mock_response = { + "id": "chat-09940d4e99e3488aa52a6f5e2ecf35b1", + "object": "chat.completion", + "created": 1731702782, + "model": "meta-llama/Llama-3.1-8B-Instruct", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello, my name is Litellm Bot. I'm a helpful assistant here to provide information and answer your questions.\n\nTo check the weather for you, I'll need to know your location. Could you please provide me with your city or zip code? That way, I can give you the most accurate and up-to-date weather information.\n\nIf you don't have your location handy, I can also suggest some popular weather websites or apps that you can use to check the weather for your area.\n\nLet me know how I can assist you!", + "tool_calls": [], + }, + "logprobs": None, + "finish_reason": "stop", + "stop_reason": None, + } + ], + "usage": {"prompt_tokens": 63, "total_tokens": 172, "completion_tokens": 109}, + "prompt_logprobs": None, + } + + # Setup mock request + mock_request = respx_mock.post(expected_url).mock( + return_value=httpx.Response(200, json=mock_response) + ) + + # Make request + response = await litellm.acompletion( + model="vertex_ai/openai/5464397967697903616", + messages=messages, + vertex_project="633608382793", + vertex_location="us-central1", + ) + + # Assert request was made correctly + assert mock_request.called + request_body = json.loads(mock_request.calls[0].request.content) + assert request_body == expected_request + + # Assert response structure + assert response.id == "chat-09940d4e99e3488aa52a6f5e2ecf35b1" + assert response.created == 1731702782 + assert response.model == "vertex_ai/meta-llama/Llama-3.1-8B-Instruct" + assert len(response.choices) == 1 + assert response.choices[0].message.role == "assistant" + assert response.choices[0].message.content.startswith( + "Hello, my name is Litellm Bot" + ) + assert response.choices[0].finish_reason == "stop" + assert response.usage.completion_tokens == 109 + assert response.usage.prompt_tokens == 63 + assert response.usage.total_tokens == 172 From 969dad19a10a011ff90130a419a5332a66dfd4a2 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 15 Nov 2024 12:54:40 -0800 Subject: [PATCH 4/4] docs model garden --- docs/my-website/docs/providers/vertex.md | 95 +++++++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index 921db9e73..605762422 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -1161,12 +1161,96 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ ## Model Garden -| Model Name | Function Call | -|------------------|--------------------------------------| -| llama2 | `completion('vertex_ai/', messages)` | + +:::tip + +All OpenAI compatible models from Vertex Model Garden are supported. + +::: #### Using Model Garden +**Almost all Vertex Model Garden models are OpenAI compatible.** + + + + + +| Property | Details | +|----------|---------| +| Provider Route | `vertex_ai/openai/{MODEL_ID}` | +| Vertex Documentation | [Vertex Model Garden - OpenAI Chat Completions](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gradio_streaming_chat_completions.ipynb), [Vertex Model Garden](https://cloud.google.com/model-garden?hl=en) | +| Supported Operations | `/chat/completions`, `/embeddings` | + + + + +```python +from litellm import completion +import os + +## set ENV variables +os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811" +os.environ["VERTEXAI_LOCATION"] = "us-central1" + +response = completion( + model="vertex_ai/openai/", + messages=[{ "content": "Hello, how are you?","role": "user"}] +) +``` + + + + + + +**1. Add to config** + +```yaml +model_list: + - model_name: llama3-1-8b-instruct + litellm_params: + model: vertex_ai/openai/5464397967697903616 + vertex_ai_project: "my-test-project" + vertex_ai_location: "us-east-1" +``` + +**2. Start proxy** + +```bash +litellm --config /path/to/config.yaml + +# RUNNING at http://0.0.0.0:4000 +``` + +**3. Test it!** + +```bash +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "llama3-1-8b-instruct", # 👈 the 'model_name' in config + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + }' +``` + + + + + + + + + + + + ```python from litellm import completion import os @@ -1181,6 +1265,11 @@ response = completion( ) ``` + + + + + ## Gemini Pro | Model Name | Function Call | |------------------|--------------------------------------|