diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index 921db9e73..605762422 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -1161,12 +1161,96 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ ## Model Garden -| Model Name | Function Call | -|------------------|--------------------------------------| -| llama2 | `completion('vertex_ai/', messages)` | + +:::tip + +All OpenAI compatible models from Vertex Model Garden are supported. + +::: #### Using Model Garden +**Almost all Vertex Model Garden models are OpenAI compatible.** + + + + + +| Property | Details | +|----------|---------| +| Provider Route | `vertex_ai/openai/{MODEL_ID}` | +| Vertex Documentation | [Vertex Model Garden - OpenAI Chat Completions](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gradio_streaming_chat_completions.ipynb), [Vertex Model Garden](https://cloud.google.com/model-garden?hl=en) | +| Supported Operations | `/chat/completions`, `/embeddings` | + + + + +```python +from litellm import completion +import os + +## set ENV variables +os.environ["VERTEXAI_PROJECT"] = "hardy-device-38811" +os.environ["VERTEXAI_LOCATION"] = "us-central1" + +response = completion( + model="vertex_ai/openai/", + messages=[{ "content": "Hello, how are you?","role": "user"}] +) +``` + + + + + + +**1. Add to config** + +```yaml +model_list: + - model_name: llama3-1-8b-instruct + litellm_params: + model: vertex_ai/openai/5464397967697903616 + vertex_ai_project: "my-test-project" + vertex_ai_location: "us-east-1" +``` + +**2. Start proxy** + +```bash +litellm --config /path/to/config.yaml + +# RUNNING at http://0.0.0.0:4000 +``` + +**3. Test it!** + +```bash +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "llama3-1-8b-instruct", # 👈 the 'model_name' in config + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + }' +``` + + + + + + + + + + + + ```python from litellm import completion import os @@ -1181,6 +1265,11 @@ response = completion( ) ``` + + + + + ## Gemini Pro | Model Name | Function Call | |------------------|--------------------------------------| diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py new file mode 100644 index 000000000..4285c4dcb --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_model_garden/main.py @@ -0,0 +1,156 @@ +""" +API Handler for calling Vertex AI Model Garden Models + +Most Vertex Model Garden Models are OpenAI compatible - so this handler calls `openai_like_chat_completions` + +Usage: + +response = litellm.completion( + model="vertex_ai/openai/5464397967697903616", + messages=[{"role": "user", "content": "Hello, how are you?"}], +) + +Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}` + + +Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb +""" + +import types +from enum import Enum +from typing import Callable, Literal, Optional, Union + +import httpx # type: ignore + +import litellm +from litellm.utils import ModelResponse + +from ..common_utils import VertexAIError +from ..vertex_llm_base import VertexBase + + +def create_vertex_url( + vertex_location: str, + vertex_project: str, + stream: Optional[bool], + model: str, + api_base: Optional[str] = None, +) -> str: + """Return the base url for the vertex garden models""" + # f"https://{self.endpoint.location}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{self.endpoint.location}" + return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}" + + +class VertexAIModelGardenModels(VertexBase): + def __init__(self) -> None: + pass + + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + api_base: Optional[str], + optional_params: dict, + custom_prompt_dict: dict, + headers: Optional[dict], + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + logger_fn=None, + acompletion: bool = False, + client=None, + ): + """ + Handles calling Vertex AI Model Garden Models in OpenAI compatible format + + Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}` + """ + try: + import vertexai + from google.cloud import aiplatform + + from litellm.llms.anthropic.chat import AnthropicChatCompletion + from litellm.llms.databricks.chat import DatabricksChatCompletion + from litellm.llms.OpenAI.openai import OpenAIChatCompletion + from litellm.llms.text_completion_codestral import CodestralTextCompletion + from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( + VertexLLM, + ) + except Exception: + + raise VertexAIError( + status_code=400, + message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""", + ) + + if not ( + hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") + ): + raise VertexAIError( + status_code=400, + message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", + ) + try: + model = model.replace("openai/", "") + vertex_httpx_logic = VertexLLM() + + access_token, project_id = vertex_httpx_logic._ensure_access_token( + credentials=vertex_credentials, + project_id=vertex_project, + custom_llm_provider="vertex_ai", + ) + + openai_like_chat_completions = DatabricksChatCompletion() + + ## CONSTRUCT API BASE + stream: bool = optional_params.get("stream", False) or False + optional_params["stream"] = stream + default_api_base = create_vertex_url( + vertex_location=vertex_location or "us-central1", + vertex_project=vertex_project or project_id, + stream=stream, + model=model, + ) + + if len(default_api_base.split(":")) > 1: + endpoint = default_api_base.split(":")[-1] + else: + endpoint = "" + + _, api_base = self._check_custom_proxy( + api_base=api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint=endpoint, + stream=stream, + auth_header=None, + url=default_api_base, + ) + model = "" + return openai_like_chat_completions.completion( + model=model, + messages=messages, + api_base=api_base, + api_key=access_token, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + logging_obj=logging_obj, + optional_params=optional_params, + acompletion=acompletion, + litellm_params=litellm_params, + logger_fn=logger_fn, + client=client, + timeout=timeout, + encoding=encoding, + custom_llm_provider="vertex_ai", + ) + + except Exception as e: + raise VertexAIError(status_code=500, message=str(e)) diff --git a/litellm/main.py b/litellm/main.py index 543a93eea..3b4a99413 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -158,6 +158,9 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import ( from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import ( VertexEmbedding, ) +from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import ( + VertexAIModelGardenModels, +) from .llms.watsonx.chat.handler import WatsonXChatHandler from .llms.watsonx.completion.handler import IBMWatsonXAI from .types.llms.openai import ( @@ -221,6 +224,7 @@ vertex_multimodal_embedding = VertexMultimodalEmbedding() vertex_image_generation = VertexImageGeneration() google_batch_embeddings = GoogleBatchEmbeddings() vertex_partner_models_chat_completion = VertexAIPartnerModels() +vertex_model_garden_chat_completion = VertexAIModelGardenModels() vertex_text_to_speech = VertexTextToSpeechAPI() watsonxai = IBMWatsonXAI() sagemaker_llm = SagemakerLLM() @@ -2355,6 +2359,28 @@ def completion( # type: ignore # noqa: PLR0915 api_base=api_base, extra_headers=extra_headers, ) + elif "openai" in model: + # Vertex Model Garden - OpenAI compatible models + model_response = vertex_model_garden_chat_completion.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + api_base=api_base, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + headers=headers, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, + client=client, + ) else: model_response = vertex_ai_non_gemini.completion( model=model, diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py index 5a07d17b7..e8fb67478 100644 --- a/tests/local_testing/test_amazing_vertex_completion.py +++ b/tests/local_testing/test_amazing_vertex_completion.py @@ -3123,3 +3123,85 @@ async def test_vertexai_embedding_finetuned(respx_mock: MockRouter): assert isinstance(embedding["embedding"], list) assert len(embedding["embedding"]) > 0 assert all(isinstance(x, float) for x in embedding["embedding"]) + + +@pytest.mark.asyncio +@pytest.mark.respx +async def test_vertexai_model_garden_model_completion(respx_mock: MockRouter): + """ + Relevant issue: https://github.com/BerriAI/litellm/issues/6480 + + Using OpenAI compatible models from Vertex Model Garden + """ + load_vertex_ai_credentials() + litellm.set_verbose = True + + # Test input + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + ] + + # Expected request/response + expected_url = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/633608382793/locations/us-central1/endpoints/5464397967697903616/chat/completions" + expected_request = {"model": "", "messages": messages, "stream": False} + + mock_response = { + "id": "chat-09940d4e99e3488aa52a6f5e2ecf35b1", + "object": "chat.completion", + "created": 1731702782, + "model": "meta-llama/Llama-3.1-8B-Instruct", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello, my name is Litellm Bot. I'm a helpful assistant here to provide information and answer your questions.\n\nTo check the weather for you, I'll need to know your location. Could you please provide me with your city or zip code? That way, I can give you the most accurate and up-to-date weather information.\n\nIf you don't have your location handy, I can also suggest some popular weather websites or apps that you can use to check the weather for your area.\n\nLet me know how I can assist you!", + "tool_calls": [], + }, + "logprobs": None, + "finish_reason": "stop", + "stop_reason": None, + } + ], + "usage": {"prompt_tokens": 63, "total_tokens": 172, "completion_tokens": 109}, + "prompt_logprobs": None, + } + + # Setup mock request + mock_request = respx_mock.post(expected_url).mock( + return_value=httpx.Response(200, json=mock_response) + ) + + # Make request + response = await litellm.acompletion( + model="vertex_ai/openai/5464397967697903616", + messages=messages, + vertex_project="633608382793", + vertex_location="us-central1", + ) + + # Assert request was made correctly + assert mock_request.called + request_body = json.loads(mock_request.calls[0].request.content) + assert request_body == expected_request + + # Assert response structure + assert response.id == "chat-09940d4e99e3488aa52a6f5e2ecf35b1" + assert response.created == 1731702782 + assert response.model == "vertex_ai/meta-llama/Llama-3.1-8B-Instruct" + assert len(response.choices) == 1 + assert response.choices[0].message.role == "assistant" + assert response.choices[0].message.content.startswith( + "Hello, my name is Litellm Bot" + ) + assert response.choices[0].finish_reason == "stop" + assert response.usage.completion_tokens == 109 + assert response.usage.prompt_tokens == 63 + assert response.usage.total_tokens == 172