mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
chore: update the vertexai inference impl to use openai-python for openai-compat functions (#3377)
# What does this PR do?

Update the VertexAI inference provider to use openai-python for openai-compat functions.

## Test Plan

```
$ VERTEX_AI_PROJECT=... uv run llama stack build --image-type venv --providers inference=remote::vertexai --run
...
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model vertexai/vertex_ai/gemini-2.5-flash tests/integration/inference/test_openai_completion.py
...
```

I don't have an account to test this. `get_api_key` may also need to be updated per https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
Co-authored-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent c836fa29e3
commit 0e27016cf2
3 changed files with 33 additions and 5 deletions
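For orientation before the diff: a minimal sketch of the call path this change moves to — openai-python pointed at Vertex AI's OpenAI-compatible endpoint, authenticated with an ADC access token as described in the Google doc linked above. This is not part of the PR; the project, location, and model values are placeholders.

```python
# Minimal sketch: openai-python against the Vertex AI OpenAI-compatible endpoint.
# Assumes ADC is configured (e.g. via `gcloud auth application-default login`).
import google.auth.transport.requests
from google.auth import default
from openai import OpenAI

# Fetch a short-lived access token via Application Default Credentials
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())

project, location = "my-project", "us-central1"  # placeholders
client = OpenAI(
    api_key=credentials.token,
    base_url=(
        f"https://{location}-aiplatform.googleapis.com/v1"
        f"/projects/{project}/locations/{location}/endpoints/openapi"
    ),
)
response = client.chat.completions.create(
    model="google/gemini-2.5-flash",  # Gemini models are namespaced under "google/"
    messages=[{"role": "user", "content": "Hello from llama-stack"}],
)
print(response.choices[0].message.content)
```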
llama_stack/providers/registry/inference.py

```diff
@@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="vertexai",
-            pip_packages=["litellm", "google-cloud-aiplatform"],
+            pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
             module="llama_stack.providers.remote.inference.vertexai",
             config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
             provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
```
llama_stack/providers/remote/inference/vertexai/vertexai.py

```diff
@@ -6,16 +6,20 @@
 
 from typing import Any
 
+import google.auth.transport.requests
+from google.auth import default
+
 from llama_stack.apis.inference import ChatCompletionRequest
 from llama_stack.providers.utils.inference.litellm_openai_mixin import (
     LiteLLMOpenAIMixin,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import VertexAIConfig
 from .models import MODEL_ENTRIES
 
 
-class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
+class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     def __init__(self, config: VertexAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -27,9 +31,30 @@ class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
         self.config = config
 
     def get_api_key(self) -> str:
-        # Vertex AI doesn't use API keys, it uses Application Default Credentials
-        # Return empty string to let litellm handle authentication via ADC
-        return ""
+        """
+        Get an access token for Vertex AI using Application Default Credentials.
+
+        Vertex AI uses ADC instead of API keys. This method obtains an access token
+        from the default credentials and returns it for use with the OpenAI-compatible client.
+        """
+        try:
+            # Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
+            credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
+            credentials.refresh(google.auth.transport.requests.Request())
+            return credentials.token
+        except Exception:
+            # If we can't get credentials, return empty string to let LiteLLM handle it
+            # This allows the LiteLLM mixin to work with ADC directly
+            return ""
+
+    def get_base_url(self) -> str:
+        """
+        Get the Vertex AI OpenAI-compatible API base URL.
+
+        Returns the Vertex AI OpenAI-compatible endpoint URL.
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+        """
+        return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
         # Get base parameters from parent
```
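One detail worth calling out in the class-line change above: with `OpenAIMixin` listed first, Python's MRO resolves any method both mixins define to the openai-python side, while everything else is still inherited from `LiteLLMOpenAIMixin`. A toy sketch of that resolution order — the method name here is hypothetical, not the actual mixin API:

```python
# Toy MRO illustration; openai_chat_completion is a stand-in method name.
class LiteLLMOpenAIMixin:
    def openai_chat_completion(self) -> str:
        return "litellm implementation"

class OpenAIMixin:
    def openai_chat_completion(self) -> str:
        return "openai-python implementation"

# Mirrors VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin)
class Adapter(OpenAIMixin, LiteLLMOpenAIMixin):
    pass

print([c.__name__ for c in Adapter.__mro__])
# ['Adapter', 'OpenAIMixin', 'LiteLLMOpenAIMixin', 'object']
print(Adapter().openai_chat_completion())  # "openai-python implementation"
```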
tests/integration/inference/test_openai_completion.py

```diff
@@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         "remote::gemini",
         # https://docs.anthropic.com/en/api/openai-sdk#simple-fields
         "remote::anthropic",
+        "remote::vertexai",
+        # Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
+        # the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
 
```
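On the new skip entry: the quoted 400 suggests Vertex AI maps the OpenAI `n` parameter onto `candidateCount`, which it requires to be 1. A hypothetical reproduction, with the client configured as in the sketch near the top (the credential and endpoint strings are placeholders):

```python
# Hypothetical repro of the quoted 400; fill in a real token and endpoint to run.
from openai import BadRequestError, OpenAI

client = OpenAI(api_key="<ADC access token>", base_url="<vertex openai-compat endpoint>")
try:
    client.chat.completions.create(
        model="google/gemini-2.5-flash",
        messages=[{"role": "user", "content": "Hi"}],
        n=2,  # Vertex AI maps this to candidateCount, which must be 1
    )
except BadRequestError as exc:
    print(exc)  # Error code: 400 ... 'status': 'INVALID_ARGUMENT'
```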