fix: add token to the openai request

OpenAIMixin expects an API key and creates its own AsyncOpenAI
client, so our code now authenticates with the Google service,
retrieves a token, and passes it to the OpenAI client.
Falls back to an empty string if credentials can't be obtained (letting
LiteLLM handle ADC directly).
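
For reference, a minimal sketch of that flow (not the adapter code
itself), assuming the google-auth and openai packages are installed;
the project and location values below are placeholders:

    import google.auth
    import google.auth.transport.requests
    from openai import AsyncOpenAI

    # Obtain Application Default Credentials and mint a short-lived access token.
    credentials, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    credentials.refresh(google.auth.transport.requests.Request())

    # Hand the token to the OpenAI client where an API key is normally
    # expected, pointed at the Vertex AI OpenAI-compatible endpoint.
    client = AsyncOpenAI(
        api_key=credentials.token,
        base_url="https://us-central1-aiplatform.googleapis.com/v1"
        "/projects/my-project/locations/us-central1/endpoints/openapi",
    )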

Signed-off-by: Sébastien Han <seb@redhat.com>

@@ -6,6 +6,9 @@
 from typing import Any

+import google.auth.transport.requests
+from google.auth import default
+
 from llama_stack.apis.inference import ChatCompletionRequest
 from llama_stack.providers.utils.inference.litellm_openai_mixin import (
     LiteLLMOpenAIMixin,
@@ -28,12 +31,29 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
         self.config = config

     def get_api_key(self) -> str:
-        # Vertex AI doesn't use API keys, it uses Application Default Credentials
-        # Return empty string to let litellm handle authentication via ADC
-        return ""
+        """
+        Get an access token for Vertex AI using Application Default Credentials.
+
+        Vertex AI uses ADC instead of API keys. This method obtains an access token
+        from the default credentials and returns it for use with the OpenAI-compatible client.
+        """
+        try:
+            # Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
+            credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
+            credentials.refresh(google.auth.transport.requests.Request())
+            return credentials.token
+        except Exception:
+            # If we can't get credentials, return empty string to let LiteLLM handle it
+            # This allows the LiteLLM mixin to work with ADC directly
+            return ""

-    def get_base_url(self):
-        # source - https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+    def get_base_url(self) -> str:
+        """
+        Get the Vertex AI OpenAI-compatible API base URL.
+
+        Returns the Vertex AI OpenAI-compatible endpoint URL.
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+        """
         return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"

     async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
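
A quick usage sketch of the new methods; the VertexAIInferenceAdapter
constructor call and the config fields are assumed from the diff context:

    from types import SimpleNamespace

    # Stand-in config exposing only the fields get_base_url() reads.
    config = SimpleNamespace(project="my-project", location="us-central1")
    adapter = VertexAIInferenceAdapter(config)  # hypothetical instantiation

    adapter.get_base_url()
    # -> "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/endpoints/openapi"
    adapter.get_api_key()
    # -> a fresh ADC access token, or "" if credentials are unavailable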


@@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         "remote::gemini",
         # https://docs.anthropic.com/en/api/openai-sdk#simple-fields
         "remote::anthropic",
+        "remote::vertexai",
+        # Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
+        # the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")