From e165c0f0db9eefb419582c65e1b1a0a2f4b05433 Mon Sep 17 00:00:00 2001
From: Eran Cohen
Date: Tue, 22 Jul 2025 15:27:33 +0300
Subject: [PATCH 1/2] feat: Add Google Vertex AI inference provider support

- Add new Vertex AI remote inference provider with litellm integration
- Support Gemini models through the Google Cloud Vertex AI platform
- Use Google Cloud Application Default Credentials (ADC) for authentication
- Add Vertex AI models: gemini-2.5-flash, gemini-2.5-pro, gemini-2.0-flash
- Update the provider registry to include the vertexai provider
- Update the starter template to support Vertex AI configuration
- Add documentation and a sample configuration

Signed-off-by: Eran Cohen
---
 docs/source/providers/inference/index.md      |  1 +
 .../providers/inference/remote_vertexai.md    | 40 ++++++++++++
 llama_stack/providers/registry/inference.py   | 30 +++++++++
 .../remote/inference/vertexai/__init__.py     | 15 +++++
 .../remote/inference/vertexai/config.py       | 45 +++++++++++++
 .../remote/inference/vertexai/models.py       | 20 ++++++
 .../remote/inference/vertexai/vertexai.py     | 63 +++++++++++++++++++
 llama_stack/templates/starter/starter.py      |  8 +++
 .../inference/test_text_inference.py          |  1 +
 9 files changed, 223 insertions(+)
 create mode 100644 docs/source/providers/inference/remote_vertexai.md
 create mode 100644 llama_stack/providers/remote/inference/vertexai/__init__.py
 create mode 100644 llama_stack/providers/remote/inference/vertexai/config.py
 create mode 100644 llama_stack/providers/remote/inference/vertexai/models.py
 create mode 100644 llama_stack/providers/remote/inference/vertexai/vertexai.py

diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md
index dcc6da5b5..a1340a45d 100644
--- a/docs/source/providers/inference/index.md
+++ b/docs/source/providers/inference/index.md
@@ -22,5 +22,6 @@ This section contains documentation for all available providers for the **infere
 - [remote::sambanova](remote_sambanova.md)
 - [remote::tgi](remote_tgi.md)
 - [remote::together](remote_together.md)
+- [remote::vertexai](remote_vertexai.md)
 - [remote::vllm](remote_vllm.md)
 - [remote::watsonx](remote_watsonx.md)
\ No newline at end of file
diff --git a/docs/source/providers/inference/remote_vertexai.md b/docs/source/providers/inference/remote_vertexai.md
new file mode 100644
index 000000000..497b8693d
--- /dev/null
+++ b/docs/source/providers/inference/remote_vertexai.md
@@ -0,0 +1,40 @@
+# remote::vertexai
+
+## Description
+
+Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
+
+• Enterprise-grade security: Uses Google Cloud's security controls and IAM
+• Better integration: Seamless integration with other Google Cloud services
+• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
+• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
+
+Configuration:
+- Set VERTEX_AI_PROJECT environment variable (required)
+- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
+- Use Google Cloud Application Default Credentials or a service account key
+
+Authentication Setup:
+Option 1 (Recommended): gcloud auth application-default login
+Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to the service account key path
+
+Available Models:
+- vertex_ai/gemini-2.0-flash
+- vertex_ai/gemini-2.5-flash
+- vertex_ai/gemini-2.5-pro
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `project` | `str` | Yes | | Google Cloud project ID for Vertex AI |
+| `location` | `str` | No | us-central1 | Google Cloud location for Vertex AI |
+
+## Sample Configuration
+
+```yaml
+project: ${env.VERTEX_AI_PROJECT}
+location: ${env.VERTEX_AI_LOCATION:=us-central1}
+
+```
+
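To try the provider end to end, here is a minimal sketch using the llama-stack client; it assumes a stack server is already running on the default port with this provider enabled and ADC configured:

```python
# Minimal sketch: chat completion through the Vertex AI provider.
# Assumes a running Llama Stack server on localhost:8321 with this
# provider enabled, VERTEX_AI_PROJECT set, and ADC configured
# (gcloud auth application-default login).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.chat_completion(
    model_id="vertex_ai/gemini-2.5-flash",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```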
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index a8bc96a77..1801cdcad 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]:
             description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="vertexai",
+            pip_packages=["litellm", "google-cloud-aiplatform"],
+            module="llama_stack.providers.remote.inference.vertexai",
+            config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
+            description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
+
+• Enterprise-grade security: Uses Google Cloud's security controls and IAM
+• Better integration: Seamless integration with other Google Cloud services
+• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
+• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
+
+Configuration:
+- Set VERTEX_AI_PROJECT environment variable (required)
+- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
+- Use Google Cloud Application Default Credentials or a service account key
+
+Authentication Setup:
+Option 1 (Recommended): gcloud auth application-default login
+Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to the service account key path
+
+Available Models:
+- vertex_ai/gemini-2.0-flash
+- vertex_ai/gemini-2.5-flash
+- vertex_ai/gemini-2.5-pro""",
+        ),
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
diff --git a/llama_stack/providers/remote/inference/vertexai/__init__.py b/llama_stack/providers/remote/inference/vertexai/__init__.py
new file mode 100644
index 000000000..d9e9419be
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import VertexAIConfig
+
+
+async def get_adapter_impl(config: VertexAIConfig, _deps):
+    from .vertexai import VertexAIInferenceAdapter
+
+    impl = VertexAIInferenceAdapter(config)
+    await impl.initialize()
+    return impl
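For local debugging it can help to drive the factory above directly instead of going through a full stack build; a minimal sketch, assuming the provider's pip packages are installed, ADC is configured, and using a hypothetical project ID:

```python
# Sketch: drive the adapter factory directly, outside a stack build.
# Assumes litellm and google-cloud-aiplatform are installed and ADC is
# configured. "my-gcp-project" is a hypothetical project ID.
import asyncio

from llama_stack.providers.remote.inference.vertexai import get_adapter_impl
from llama_stack.providers.remote.inference.vertexai.config import VertexAIConfig


async def main() -> None:
    config = VertexAIConfig(project="my-gcp-project")
    adapter = await get_adapter_impl(config, {})
    # The adapter is now initialized with the Gemini model entries from models.py.
    await adapter.shutdown()


asyncio.run(main())
```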
diff --git a/llama_stack/providers/remote/inference/vertexai/config.py b/llama_stack/providers/remote/inference/vertexai/config.py
new file mode 100644
index 000000000..19a61f23d
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/config.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from llama_stack.schema_utils import json_schema_type
+
+
+class VertexAIProviderDataValidator(BaseModel):
+    vertex_project: str | None = Field(
+        default=None,
+        description="Google Cloud project ID for Vertex AI",
+    )
+    vertex_location: str | None = Field(
+        default=None,
+        description="Google Cloud location for Vertex AI (e.g., us-central1)",
+    )
+
+
+@json_schema_type
+class VertexAIConfig(BaseModel):
+    project: str = Field(
+        description="Google Cloud project ID for Vertex AI",
+    )
+    location: str = Field(
+        default="us-central1",
+        description="Google Cloud location for Vertex AI",
+    )
+
+    @classmethod
+    def sample_run_config(
+        cls,
+        project: str = "${env.VERTEX_AI_PROJECT}",
+        location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
+        **kwargs,
+    ) -> dict[str, Any]:
+        return {
+            "project": project,
+            "location": location,
+        }
diff --git a/llama_stack/providers/remote/inference/vertexai/models.py b/llama_stack/providers/remote/inference/vertexai/models.py
new file mode 100644
index 000000000..e72db533d
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/models.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
+)
+
+# Vertex AI model IDs with the vertex_ai/ prefix required by litellm
+LLM_MODEL_IDS = [
+    "vertex_ai/gemini-2.0-flash",
+    "vertex_ai/gemini-2.5-flash",
+    "vertex_ai/gemini-2.5-pro",
+]
+
+SAFETY_MODELS_ENTRIES: list[ProviderModelEntry] = []  # none registered for Vertex AI yet
+
+MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES
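The config's sample_run_config relies on the stack's ${env.VAR:=default} substitution; a short sketch of the template it returns and of the resolved values, where the project ID is a hypothetical placeholder:

```python
# Sketch: the template VertexAIConfig.sample_run_config() returns, and an
# example of the values the stack would resolve it to at run time.
from llama_stack.providers.remote.inference.vertexai.config import VertexAIConfig

template = VertexAIConfig.sample_run_config()
# {'project': '${env.VERTEX_AI_PROJECT}', 'location': '${env.VERTEX_AI_LOCATION:=us-central1}'}
print(template)

# With VERTEX_AI_PROJECT=my-gcp-project (hypothetical ID) and
# VERTEX_AI_LOCATION unset, the stack constructs the equivalent of:
resolved = VertexAIConfig(project="my-gcp-project")
print(resolved.location)  # "us-central1"
```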
diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py
new file mode 100644
index 000000000..15af1b95b
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Any
+
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.providers.utils.inference.litellm_openai_mixin import (
+    LiteLLMOpenAIMixin,
+)
+
+from .config import VertexAIConfig
+from .models import MODEL_ENTRIES
+
+
+class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
+    def __init__(self, config: VertexAIConfig) -> None:
+        # Set environment variables for litellm to use
+        os.environ["VERTEX_AI_PROJECT"] = config.project
+        os.environ["VERTEX_AI_LOCATION"] = config.location
+
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            MODEL_ENTRIES,
+            litellm_provider_name="vertexai",
+            api_key_from_config=None,  # Vertex AI uses ADC, not API keys
+            provider_data_api_key_field="vertex_project",  # Use project for validation
+        )
+        self.config = config
+
+    async def initialize(self) -> None:
+        await super().initialize()
+
+    async def shutdown(self) -> None:
+        await super().shutdown()
+
+    def get_api_key(self) -> str:
+        # Vertex AI doesn't use API keys; it uses Application Default Credentials.
+        # Return an empty string to let litellm handle authentication via ADC.
+        return ""
+
+    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        # Get base parameters from the parent mixin
+        params = await super()._get_params(request)
+
+        # Add Vertex AI params, preferring per-request provider data over the config
+        provider_data = self.get_request_provider_data()
+        if provider_data:
+            if getattr(provider_data, "vertex_project", None):
+                params["vertex_project"] = provider_data.vertex_project
+            if getattr(provider_data, "vertex_location", None):
+                params["vertex_location"] = provider_data.vertex_location
+        else:
+            params["vertex_project"] = self.config.project
+            params["vertex_location"] = self.config.location
+
+        # Remove api_key since Vertex AI uses ADC
+        params.pop("api_key", None)
+
+        return params
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index d0782797f..2389613ef 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -245,6 +245,14 @@ def get_distribution_template() -> DistributionTemplate:
             "",
             "Gemini API Key",
         ),
+        "VERTEX_AI_PROJECT": (
+            "",
+            "Google Cloud Project ID for Vertex AI",
+        ),
+        "VERTEX_AI_LOCATION": (
+            "us-central1",
+            "Google Cloud Location for Vertex AI",
+        ),
         "SAMBANOVA_API_KEY": (
             "",
             "SambaNova API Key",
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py
index 08e19726e..d7ffe5929 100644
--- a/tests/integration/inference/test_text_inference.py
+++ b/tests/integration/inference/test_text_inference.py
@@ -29,6 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
         "remote::openai",
         "remote::anthropic",
         "remote::gemini",
+        "remote::vertexai",
         "remote::groq",
         "remote::sambanova",
     )
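Since get_api_key returns an empty string and the adapter leans entirely on Application Default Credentials, it is worth confirming that ADC resolves before starting the stack; a small pre-flight sketch using google-auth, which google-cloud-aiplatform pulls in:

```python
# Pre-flight sketch: confirm Application Default Credentials resolve
# before starting the stack with the vertexai provider.
import google.auth
from google.auth.exceptions import DefaultCredentialsError

try:
    credentials, project = google.auth.default()
    print(f"ADC OK; default project: {project}")
except DefaultCredentialsError:
    print(
        "No ADC found; run `gcloud auth application-default login` "
        "or set GOOGLE_APPLICATION_CREDENTIALS to a service account key."
    )
```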
From 84b56aa8b87c028f8537ff9b5e5996a0d7da2763 Mon Sep 17 00:00:00 2001
From: Eran Cohen
Date: Sun, 27 Jul 2025 21:44:54 +0300
Subject: [PATCH 2/2] Update test message content to be more specific

With the former message, Gemini returned "Which San Francisco are you referring to?
If it is in the US, please specify the state"

Signed-off-by: Eran Cohen
---
 tests/integration/test_cases/inference/chat_completion.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json
index 1ae018397..36de3f361 100644
--- a/tests/integration/test_cases/inference/chat_completion.json
+++ b/tests/integration/test_cases/inference/chat_completion.json
@@ -78,7 +78,7 @@
             },
             {
                 "role": "user",
-                "content": "What's the weather like in San Francisco?"
+                "content": "What's the weather like in San Francisco, CA?"
             }
         ],
         "tools": [
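The test case above exercises tool calling; with the more specific prompt, the model should emit a weather tool call instead of a clarifying question. A sketch of the shape being checked, where the tool name and argument are assumptions since the test case JSON is truncated in the diff:

```python
# Hypothetical expected shape after the prompt change. The "get_weather"
# tool name and "location" argument are assumptions; the test case JSON
# is truncated in the diff above.
expected_tool_call = {
    "tool_name": "get_weather",
    "arguments": {"location": "San Francisco, CA"},
}
```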