From 0e27016cf23eca51a0f025897b44109b1b609b71 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 10 Sep 2025 09:39:29 -0400
Subject: [PATCH] chore: update the vertexai inference impl to use
 openai-python for openai-compat functions (#3377)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

Update the VertexAI inference provider to use openai-python for the
openai-compat functions.

## Test Plan

```
$ VERTEX_AI_PROJECT=... uv run llama stack build --image-type venv --providers inference=remote::vertexai --run
...
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model vertexai/vertex_ai/gemini-2.5-flash tests/integration/inference/test_openai_completion.py
...
```

I don't have an account to test this. `get_api_key` may also need to be
updated per https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai

---------

Signed-off-by: Sébastien Han
Co-authored-by: Sébastien Han
---
 llama_stack/providers/registry/inference.py    |  2 +-
 .../remote/inference/vertexai/vertexai.py      | 33 ++++++++++++++++---
 .../inference/test_openai_completion.py        |  3 ++
 3 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 4176f85a6..541fbb432 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_type="vertexai",
-                pip_packages=["litellm", "google-cloud-aiplatform"],
+                pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
                 module="llama_stack.providers.remote.inference.vertexai",
                 config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
                 provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py
index 8807fd0e6..27f953ab9 100644
--- a/llama_stack/providers/remote/inference/vertexai/vertexai.py
+++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py
@@ -6,16 +6,20 @@
 
 from typing import Any
 
+import google.auth.transport.requests
+from google.auth import default
+
 from llama_stack.apis.inference import ChatCompletionRequest
 from llama_stack.providers.utils.inference.litellm_openai_mixin import (
     LiteLLMOpenAIMixin,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import VertexAIConfig
 from .models import MODEL_ENTRIES
 
 
-class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
+class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     def __init__(self, config: VertexAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -27,9 +31,30 @@ class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
         self.config = config
 
     def get_api_key(self) -> str:
-        # Vertex AI doesn't use API keys, it uses Application Default Credentials
-        # Return empty string to let litellm handle authentication via ADC
-        return ""
+        """
+        Get an access token for Vertex AI using Application Default Credentials.
+
+        Vertex AI uses ADC instead of API keys. This method obtains an access token
+        from the default credentials and returns it for use with the OpenAI-compatible client.
+ """ + try: + # Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS + credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) + credentials.refresh(google.auth.transport.requests.Request()) + return credentials.token + except Exception: + # If we can't get credentials, return empty string to let LiteLLM handle it + # This allows the LiteLLM mixin to work with ADC directly + return "" + + def get_base_url(self) -> str: + """ + Get the Vertex AI OpenAI-compatible API base URL. + + Returns the Vertex AI OpenAI-compatible endpoint URL. + Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai + """ + return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi" async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: # Get base parameters from parent diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index df1184f1c..f9c837ebd 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id): "remote::gemini", # https://docs.anthropic.com/en/api/openai-sdk#simple-fields "remote::anthropic", + "remote::vertexai", + # Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but + # the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'} ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")