diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 4176f85a6..541fbb432 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="vertexai", - pip_packages=["litellm", "google-cloud-aiplatform"], + pip_packages=["litellm", "google-cloud-aiplatform", "openai"], module="llama_stack.providers.remote.inference.vertexai", config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py index 8807fd0e6..bf8ec59a0 100644 --- a/llama_stack/providers/remote/inference/vertexai/vertexai.py +++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py @@ -10,12 +10,13 @@ from llama_stack.apis.inference import ChatCompletionRequest from llama_stack.providers.utils.inference.litellm_openai_mixin import ( LiteLLMOpenAIMixin, ) +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import VertexAIConfig from .models import MODEL_ENTRIES -class VertexAIInferenceAdapter(LiteLLMOpenAIMixin): +class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: VertexAIConfig) -> None: LiteLLMOpenAIMixin.__init__( self, @@ -31,6 +32,10 @@ class VertexAIInferenceAdapter(LiteLLMOpenAIMixin): # Return empty string to let litellm handle authentication via ADC return "" + def get_base_url(self): + # source - https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai + return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi" + async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: # Get base parameters from parent params = await super()._get_params(request)