From 1f421238b811b6cbe7e31e44bde7d1fd5dfc65f8 Mon Sep 17 00:00:00 2001
From: Eran Cohen
Date: Tue, 22 Jul 2025 15:27:33 +0300
Subject: [PATCH] feat: Add Google Vertex AI inference provider support

- Add a new Vertex AI remote inference provider with litellm integration
- Support Gemini models through the Google Cloud Vertex AI platform
- Use Google Cloud Application Default Credentials (ADC) for authentication
- Add Vertex AI models: gemini-2.0-flash, gemini-2.5-flash, gemini-2.5-pro,
  plus the text-embedding-004 and text-embedding-005 embedding models
- Update the provider registry to include the vertexai provider
- Update the ci-tests and starter templates to support Vertex AI configuration
- Add documentation and a sample configuration

Signed-off-by: Eran Cohen
---
 docs/source/providers/inference/index.md      |  1 +
 .../providers/inference/remote_vertexai.md    | 40 ++++++++++++
 llama_stack/providers/registry/inference.py   | 30 +++++++++
 .../remote/inference/vertexai/__init__.py     | 15 +++++
 .../remote/inference/vertexai/config.py       | 45 ++++++++++++++
 .../remote/inference/vertexai/models.py       | 36 +++++++++++
 .../remote/inference/vertexai/vertexai.py     | 62 +++++++++++++++++++
 llama_stack/templates/ci-tests/build.yaml     |  1 +
 llama_stack/templates/ci-tests/run.yaml       | 34 ++++++++++
 llama_stack/templates/starter/build.yaml      |  1 +
 llama_stack/templates/starter/run.yaml        | 34 ++++++++++
 llama_stack/templates/starter/starter.py      | 12 ++++
 12 files changed, 311 insertions(+)
 create mode 100644 docs/source/providers/inference/remote_vertexai.md
 create mode 100644 llama_stack/providers/remote/inference/vertexai/__init__.py
 create mode 100644 llama_stack/providers/remote/inference/vertexai/config.py
 create mode 100644 llama_stack/providers/remote/inference/vertexai/models.py
 create mode 100644 llama_stack/providers/remote/inference/vertexai/vertexai.py

diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md
index dcc6da5b5..a1340a45d 100644
--- a/docs/source/providers/inference/index.md
+++ b/docs/source/providers/inference/index.md
@@ -22,5 +22,6 @@ This section contains documentation for all available providers for the **infere
 - [remote::sambanova](remote_sambanova.md)
 - [remote::tgi](remote_tgi.md)
 - [remote::together](remote_together.md)
+- [remote::vertexai](remote_vertexai.md)
 - [remote::vllm](remote_vllm.md)
 - [remote::watsonx](remote_watsonx.md)
\ No newline at end of file
diff --git a/docs/source/providers/inference/remote_vertexai.md b/docs/source/providers/inference/remote_vertexai.md
new file mode 100644
index 000000000..497b8693d
--- /dev/null
+++ b/docs/source/providers/inference/remote_vertexai.md
@@ -0,0 +1,40 @@
+# remote::vertexai
+
+## Description
+
+The Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, which provides several advantages:
+
+• Enterprise-grade security: Uses Google Cloud's security controls and IAM
+• Better integration: Seamless integration with other Google Cloud services
+• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
+• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
+
+Configuration:
+- Set the VERTEX_AI_PROJECT environment variable (required)
+- Set the VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
+- Use Google Cloud Application Default Credentials or a service account key
+
+Authentication Setup:
+Option 1 (Recommended): gcloud auth application-default login
+Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to the path of a service account key
+
+Available Models:
+- vertex_ai/gemini-2.0-flash
+- vertex_ai/gemini-2.5-flash
+- vertex_ai/gemini-2.5-pro
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `project` | `str` | Yes | | Google Cloud project ID for Vertex AI |
+| `location` | `str` | No | us-central1 | Google Cloud location for Vertex AI |
+
+## Sample Configuration
+
+```yaml
+project: ${env.VERTEX_AI_PROJECT}
+location: ${env.VERTEX_AI_LOCATION:=us-central1}
+
+```
+
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index a8bc96a77..1801cdcad 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]:
             description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="vertexai",
+            pip_packages=["litellm", "google-cloud-aiplatform"],
+            module="llama_stack.providers.remote.inference.vertexai",
+            config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
+            description="""The Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, which provides several advantages:
+
+• Enterprise-grade security: Uses Google Cloud's security controls and IAM
+• Better integration: Seamless integration with other Google Cloud services
+• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
+• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
+
+Configuration:
+- Set the VERTEX_AI_PROJECT environment variable (required)
+- Set the VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
+- Use Google Cloud Application Default Credentials or a service account key
+
+Authentication Setup:
+Option 1 (Recommended): gcloud auth application-default login
+Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to the path of a service account key
+
+Available Models:
+- vertex_ai/gemini-2.0-flash
+- vertex_ai/gemini-2.5-flash
+- vertex_ai/gemini-2.5-pro""",
+        ),
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
diff --git a/llama_stack/providers/remote/inference/vertexai/__init__.py b/llama_stack/providers/remote/inference/vertexai/__init__.py
new file mode 100644
index 000000000..d9e9419be
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import VertexAIConfig
+
+
+async def get_adapter_impl(config: VertexAIConfig, _deps):
+    from .vertexai import VertexAIInferenceAdapter
+
+    impl = VertexAIInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/remote/inference/vertexai/config.py b/llama_stack/providers/remote/inference/vertexai/config.py
new file mode 100644
index 000000000..19a61f23d
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/config.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from llama_stack.schema_utils import json_schema_type
+
+
+class VertexAIProviderDataValidator(BaseModel):
+    vertex_project: str | None = Field(
+        default=None,
+        description="Google Cloud project ID for Vertex AI",
+    )
+    vertex_location: str | None = Field(
+        default=None,
+        description="Google Cloud location for Vertex AI (e.g., us-central1)",
+    )
+
+
+@json_schema_type
+class VertexAIConfig(BaseModel):
+    project: str = Field(
+        description="Google Cloud project ID for Vertex AI",
+    )
+    location: str = Field(
+        default="us-central1",
+        description="Google Cloud location for Vertex AI",
+    )
+
+    @classmethod
+    def sample_run_config(
+        cls,
+        project: str = "${env.VERTEX_AI_PROJECT}",
+        location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
+        **kwargs,
+    ) -> dict[str, Any]:
+        return {
+            "project": project,
+            "location": location,
+        }
diff --git a/llama_stack/providers/remote/inference/vertexai/models.py b/llama_stack/providers/remote/inference/vertexai/models.py
new file mode 100644
index 000000000..450069299
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/models.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.models import ModelType
+from llama_stack.providers.utils.inference.model_registry import (
+    ProviderModelEntry,
+)
+
+# Vertex AI model IDs carry the vertex_ai/ prefix required by litellm
+LLM_MODEL_IDS = [
+    "vertex_ai/gemini-2.0-flash",
+    "vertex_ai/gemini-2.5-flash",
+    "vertex_ai/gemini-2.5-pro",
+]
+
+SAFETY_MODELS_ENTRIES: list[ProviderModelEntry] = []
+
+MODEL_ENTRIES = (
+    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+    + [
+        ProviderModelEntry(
+            provider_model_id="vertex_ai/text-embedding-004",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 768, "context_length": 2048},
+        ),
+        ProviderModelEntry(
+            provider_model_id="vertex_ai/text-embedding-005",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 768, "context_length": 2048},
+        ),
+    ]
+    + SAFETY_MODELS_ENTRIES
+)
diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py
new file mode 100644
index 000000000..2e9836537
--- /dev/null
+++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Any
+
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.providers.utils.inference.litellm_openai_mixin import (
+    LiteLLMOpenAIMixin,
+)
+
+from .config import VertexAIConfig
+from .models import MODEL_ENTRIES
+
+
+class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
+    def __init__(self, config: VertexAIConfig) -> None:
+        # Expose the project and location to litellm via environment variables
+        os.environ["VERTEX_AI_PROJECT"] = config.project
+        os.environ["VERTEX_AI_LOCATION"] = config.location
+
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            MODEL_ENTRIES,
+            api_key_from_config=None,  # Vertex AI uses ADC, not API keys
+            provider_data_api_key_field="vertex_project",  # Use the project for validation
+        )
+        self.config = config
+
+    async def initialize(self) -> None:
+        await super().initialize()
+
+    async def shutdown(self) -> None:
+        await super().shutdown()
+
+    def get_api_key(self) -> str:
+        # Vertex AI doesn't use API keys; it uses Application Default Credentials.
+        # Return an empty string and let litellm handle authentication via ADC.
+        return ""
+
+    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        # Get the base parameters from the parent class
+        params = await super()._get_params(request)
+
+        # Add Vertex AI specific parameters, preferring per-request provider data
+        provider_data = self.get_request_provider_data()
+        if provider_data:
+            if getattr(provider_data, "vertex_project", None):
+                params["vertex_project"] = provider_data.vertex_project
+            if getattr(provider_data, "vertex_location", None):
+                params["vertex_location"] = provider_data.vertex_location
+        else:
+            params["vertex_project"] = self.config.project
+            params["vertex_location"] = self.config.location
+
+        # Remove api_key since Vertex AI uses ADC
+        params.pop("api_key", None)
+
+        return params
diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml
index 625e36e4f..4a0e3d9f6 100644
--- a/llama_stack/templates/ci-tests/build.yaml
+++ b/llama_stack/templates/ci-tests/build.yaml
@@ -18,6 +18,7 @@ distribution_spec:
     - remote::openai
     - remote::anthropic
     - remote::gemini
+    - remote::vertexai
     - remote::groq
     - remote::llama-openai-compat
    - remote::sambanova
diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml
index 3757c6e60..aaa15f61b 100644
--- a/llama_stack/templates/ci-tests/run.yaml
+++ b/llama_stack/templates/ci-tests/run.yaml
@@ -85,6 +85,11 @@ providers:
     provider_type: remote::gemini
     config:
       api_key: ${env.GEMINI_API_KEY}
+  - provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+    provider_type: remote::vertexai
+    config:
+      project: ${env.VERTEX_AI_PROJECT}
+      location: ${env.VERTEX_AI_LOCATION:=us-central1}
   - provider_id: ${env.ENABLE_GROQ:=__disabled__}
     provider_type: remote::groq
     config:
@@ -963,6 +968,35 @@ models:
   provider_id: ${env.ENABLE_GEMINI:=__disabled__}
   provider_model_id: gemini/text-embedding-004
   model_type: embedding
+- metadata: {}
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/gemini-2.0-flash
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/gemini-2.0-flash
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/gemini-2.5-flash
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/gemini-2.5-flash
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/gemini-2.5-pro
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/gemini-2.5-pro
+  model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 2048
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/text-embedding-004
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/text-embedding-004
+  model_type: embedding
+- metadata:
+    embedding_dimension: 768
+    context_length: 2048
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/text-embedding-005
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/text-embedding-005
+  model_type: embedding
 - metadata: {}
   model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama3-8b-8192
   provider_id: ${env.ENABLE_GROQ:=__disabled__}
diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml
index 8180124f6..7e796ac60 100644
--- a/llama_stack/templates/starter/build.yaml
+++ b/llama_stack/templates/starter/build.yaml
@@ -18,6 +18,7 @@ distribution_spec:
     - remote::openai
     - remote::anthropic
     - remote::gemini
+    - remote::vertexai
     - remote::groq
     - remote::llama-openai-compat
     - remote::sambanova
diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml
index 62e96d3b5..e7ab31534 100644
--- a/llama_stack/templates/starter/run.yaml
+++ b/llama_stack/templates/starter/run.yaml
@@ -85,6 +85,11 @@ providers:
     provider_type: remote::gemini
     config:
       api_key: ${env.GEMINI_API_KEY}
+  - provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+    provider_type: remote::vertexai
+    config:
+      project: ${env.VERTEX_AI_PROJECT}
+      location: ${env.VERTEX_AI_LOCATION:=us-central1}
   - provider_id: ${env.ENABLE_GROQ:=__disabled__}
     provider_type: remote::groq
     config:
@@ -963,6 +968,35 @@ models:
   provider_id: ${env.ENABLE_GEMINI:=__disabled__}
   provider_model_id: gemini/text-embedding-004
   model_type: embedding
+- metadata: {}
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/gemini-2.0-flash
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/gemini-2.0-flash
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/gemini-2.5-flash
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/gemini-2.5-flash
+  model_type: llm
+- metadata: {}
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/gemini-2.5-pro
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/gemini-2.5-pro
+  model_type: llm
+- metadata:
+    embedding_dimension: 768
+    context_length: 2048
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/text-embedding-004
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/text-embedding-004
+  model_type: embedding
+- metadata:
+    embedding_dimension: 768
+    context_length: 2048
+  model_id: ${env.ENABLE_VERTEXAI:=__disabled__}/vertex_ai/text-embedding-005
+  provider_id: ${env.ENABLE_VERTEXAI:=__disabled__}
+  provider_model_id: vertex_ai/text-embedding-005
+  model_type: embedding
 - metadata: {}
   model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama3-8b-8192
   provider_id: ${env.ENABLE_GROQ:=__disabled__}
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index ec6e8fdce..7ae64d02c 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -64,6 +64,9 @@ from llama_stack.providers.remote.inference.sambanova.models import (
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
+from llama_stack.providers.remote.inference.vertexai.models import (
+    MODEL_ENTRIES as VERTEXAI_MODEL_ENTRIES,
+)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -93,6 +96,7 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
         "databricks": DATABRICKS_MODEL_ENTRIES,
         "nvidia": NVIDIA_MODEL_ENTRIES,
         "runpod": RUNPOD_MODEL_ENTRIES,
+        "vertexai": VERTEXAI_MODEL_ENTRIES,
     }
 
     # Special handling for providers with dynamic model entries
@@ -354,6 +358,14 @@ def get_distribution_template() -> DistributionTemplate:
             "",
             "Gemini API Key",
         ),
+        "VERTEX_AI_PROJECT": (
+            "",
+            "Google Cloud Project ID for Vertex AI",
+        ),
+        "VERTEX_AI_LOCATION": (
+            "us-central1",
+            "Google Cloud Location for Vertex AI",
+        ),
         "SAMBANOVA_API_KEY": (
             "",
             "SambaNova API Key",
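Example usage (not part of the patch): a minimal Python sketch of how litellm consumes the `vertex_project` and `vertex_location` parameters that the adapter's `_get_params()` injects. It assumes litellm is installed and ADC is already configured (for example via `gcloud auth application-default login`); the project ID shown is a placeholder.

```python
# Minimal sketch, not part of the patch: call a Vertex AI Gemini model through
# litellm with the same parameters the adapter forwards from _get_params().
# Assumes ADC is configured; "my-gcp-project" is a placeholder project ID.
import litellm

response = litellm.completion(
    model="vertex_ai/gemini-2.5-flash",
    messages=[{"role": "user", "content": "Say hello from Vertex AI"}],
    vertex_project="my-gcp-project",
    vertex_location="us-central1",
)
print(response.choices[0].message.content)
```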