From e165c0f0db9eefb419582c65e1b1a0a2f4b05433 Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Tue, 22 Jul 2025 15:27:33 +0300 Subject: [PATCH] feat: Add Google Vertex AI inference provider support - Add new Vertex AI remote inference provider with litellm integration - Support for Gemini models through Google Cloud Vertex AI platform - Uses Google Cloud Application Default Credentials (ADC) for authentication - Added VertexAI models: gemini-2.5-flash, gemini-2.5-pro, gemini-2.0-flash. - Updated provider registry to include vertexai provider - Updated starter template to support Vertex AI configuration - Added comprehensive documentation and sample configuration Signed-off-by: Eran Cohen --- docs/source/providers/inference/index.md | 1 + .../providers/inference/remote_vertexai.md | 40 ++++++++++++ llama_stack/providers/registry/inference.py | 30 +++++++++ .../remote/inference/vertexai/__init__.py | 15 +++++ .../remote/inference/vertexai/config.py | 45 +++++++++++++ .../remote/inference/vertexai/models.py | 20 ++++++ .../remote/inference/vertexai/vertexai.py | 63 +++++++++++++++++++ llama_stack/templates/starter/starter.py | 8 +++ .../inference/test_text_inference.py | 1 + 9 files changed, 223 insertions(+) create mode 100644 docs/source/providers/inference/remote_vertexai.md create mode 100644 llama_stack/providers/remote/inference/vertexai/__init__.py create mode 100644 llama_stack/providers/remote/inference/vertexai/config.py create mode 100644 llama_stack/providers/remote/inference/vertexai/models.py create mode 100644 llama_stack/providers/remote/inference/vertexai/vertexai.py diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index dcc6da5b5..a1340a45d 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -22,5 +22,6 @@ This section contains documentation for all available providers for the **infere - [remote::sambanova](remote_sambanova.md) - [remote::tgi](remote_tgi.md) - [remote::together](remote_together.md) +- [remote::vertexai](remote_vertexai.md) - [remote::vllm](remote_vllm.md) - [remote::watsonx](remote_watsonx.md) \ No newline at end of file diff --git a/docs/source/providers/inference/remote_vertexai.md b/docs/source/providers/inference/remote_vertexai.md new file mode 100644 index 000000000..497b8693d --- /dev/null +++ b/docs/source/providers/inference/remote_vertexai.md @@ -0,0 +1,40 @@ +# remote::vertexai + +## Description + +Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + +• Enterprise-grade security: Uses Google Cloud's security controls and IAM +• Better integration: Seamless integration with other Google Cloud services +• Advanced features: Access to additional Vertex AI features like model tuning and monitoring +• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys + +Configuration: +- Set VERTEX_AI_PROJECT environment variable (required) +- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1) +- Use Google Cloud Application Default Credentials or service account key + +Authentication Setup: +Option 1 (Recommended): gcloud auth application-default login +Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path + +Available Models: +- vertex_ai/gemini-2.0-flash +- vertex_ai/gemini-2.5-flash +- vertex_ai/gemini-2.5-pro + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `project` | `` | No | PydanticUndefined | Google Cloud project ID for Vertex AI | +| `location` | `` | No | us-central1 | Google Cloud location for Vertex AI | + +## Sample Configuration + +```yaml +project: ${env.VERTEX_AI_PROJECT} +location: ${env.VERTEX_AI_LOCATION:=us-central1} + +``` + diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index a8bc96a77..1801cdcad 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]: description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="vertexai", + pip_packages=["litellm", "google-cloud-aiplatform"], + module="llama_stack.providers.remote.inference.vertexai", + config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", + description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + +• Enterprise-grade security: Uses Google Cloud's security controls and IAM +• Better integration: Seamless integration with other Google Cloud services +• Advanced features: Access to additional Vertex AI features like model tuning and monitoring +• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys + +Configuration: +- Set VERTEX_AI_PROJECT environment variable (required) +- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1) +- Use Google Cloud Application Default Credentials or service account key + +Authentication Setup: +Option 1 (Recommended): gcloud auth application-default login +Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path + +Available Models: +- vertex_ai/gemini-2.0-flash +- vertex_ai/gemini-2.5-flash +- vertex_ai/gemini-2.5-pro""", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/vertexai/__init__.py b/llama_stack/providers/remote/inference/vertexai/__init__.py new file mode 100644 index 000000000..d9e9419be --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import VertexAIConfig + + +async def get_adapter_impl(config: VertexAIConfig, _deps): + from .vertexai import VertexAIInferenceAdapter + + impl = VertexAIInferenceAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/vertexai/config.py b/llama_stack/providers/remote/inference/vertexai/config.py new file mode 100644 index 000000000..19a61f23d --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +class VertexAIProviderDataValidator(BaseModel): + vertex_project: str | None = Field( + default=None, + description="Google Cloud project ID for Vertex AI", + ) + vertex_location: str | None = Field( + default=None, + description="Google Cloud location for Vertex AI (e.g., us-central1)", + ) + + +@json_schema_type +class VertexAIConfig(BaseModel): + project: str = Field( + description="Google Cloud project ID for Vertex AI", + ) + location: str = Field( + default="us-central1", + description="Google Cloud location for Vertex AI", + ) + + @classmethod + def sample_run_config( + cls, + project: str = "${env.VERTEX_AI_PROJECT}", + location: str = "${env.VERTEX_AI_LOCATION:=us-central1}", + **kwargs, + ) -> dict[str, Any]: + return { + "project": project, + "location": location, + } diff --git a/llama_stack/providers/remote/inference/vertexai/models.py b/llama_stack/providers/remote/inference/vertexai/models.py new file mode 100644 index 000000000..e72db533d --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/models.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, +) + +# Vertex AI model IDs with vertex_ai/ prefix as required by litellm +LLM_MODEL_IDS = [ + "vertex_ai/gemini-2.0-flash", + "vertex_ai/gemini-2.5-flash", + "vertex_ai/gemini-2.5-pro", +] + +SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]() + +MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py new file mode 100644 index 000000000..15af1b95b --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from typing import Any + +from llama_stack.apis.inference import ChatCompletionRequest +from llama_stack.providers.utils.inference.litellm_openai_mixin import ( + LiteLLMOpenAIMixin, +) + +from .config import VertexAIConfig +from .models import MODEL_ENTRIES + + +class VertexAIInferenceAdapter(LiteLLMOpenAIMixin): + def __init__(self, config: VertexAIConfig) -> None: + # Set environment variables for litellm to use + os.environ["VERTEX_AI_PROJECT"] = config.project + os.environ["VERTEX_AI_LOCATION"] = config.location + + LiteLLMOpenAIMixin.__init__( + self, + MODEL_ENTRIES, + litellm_provider_name="vertexai", + api_key_from_config=None, # Vertex AI uses ADC, not API keys + provider_data_api_key_field="vertex_project", # Use project for validation + ) + self.config = config + + async def initialize(self) -> None: + await super().initialize() + + async def shutdown(self) -> None: + await super().shutdown() + + def get_api_key(self) -> str: + # Vertex AI doesn't use API keys, it uses Application Default Credentials + # Return empty string to let litellm handle authentication via ADC + return "" + + async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: + # Get base parameters from parent + params = await super()._get_params(request) + + # Add Vertex AI specific parameters + provider_data = self.get_request_provider_data() + if provider_data: + if getattr(provider_data, "vertex_project", None): + params["vertex_project"] = provider_data.vertex_project + if getattr(provider_data, "vertex_location", None): + params["vertex_location"] = provider_data.vertex_location + else: + params["vertex_project"] = self.config.project + params["vertex_location"] = self.config.location + + # Remove api_key since Vertex AI uses ADC + params.pop("api_key", None) + + return params diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py index d0782797f..2389613ef 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/templates/starter/starter.py @@ -245,6 +245,14 @@ def get_distribution_template() -> DistributionTemplate: "", "Gemini API Key", ), + "VERTEX_AI_PROJECT": ( + "", + "Google Cloud Project ID for Vertex AI", + ), + "VERTEX_AI_LOCATION": ( + "us-central1", + "Google Cloud Location for Vertex AI", + ), "SAMBANOVA_API_KEY": ( "", "SambaNova API Key", diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index 08e19726e..d7ffe5929 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -29,6 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id): "remote::openai", "remote::anthropic", "remote::gemini", + "remote::vertexai", "remote::groq", "remote::sambanova", )