feat: Add Google Vertex AI inference provider support

- Add new Vertex AI remote inference provider with litellm integration
- Support for Gemini models through Google Cloud Vertex AI platform
- Uses Google Cloud Application Default Credentials (ADC) for authentication
- Added VertexAI models: gemini-2.5-flash, gemini-2.5-pro, gemini-2.0-flash.
- Updated provider registry to include vertexai provider
- Updated starter template to support Vertex AI configuration
- Added comprehensive documentation and sample configuration

Signed-off-by: Eran Cohen <eranco@redhat.com>
This commit is contained in:
Eran Cohen 2025-07-22 15:27:33 +03:00
parent c0563c0560
commit 1f421238b8
12 changed files with 311 additions and 0 deletions

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from .config import VertexAIConfig
from .models import MODEL_ENTRIES
class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
# Set environment variables for litellm to use
os.environ["VERTEX_AI_PROJECT"] = config.project
os.environ["VERTEX_AI_LOCATION"] = config.location
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
api_key_from_config=None, # Vertex AI uses ADC, not API keys
provider_data_api_key_field="vertex_project", # Use project for validation
)
self.config = config
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
def get_api_key(self) -> str:
# Vertex AI doesn't use API keys, it uses Application Default Credentials
# Return empty string to let litellm handle authentication via ADC
return ""
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Vertex AI specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "vertex_project", None):
params["vertex_project"] = provider_data.vertex_project
if getattr(provider_data, "vertex_location", None):
params["vertex_location"] = provider_data.vertex_location
else:
params["vertex_project"] = self.config.project
params["vertex_location"] = self.config.location
# Remove api_key since Vertex AI uses ADC
params.pop("api_key", None)
return params