From 79c9e46582477b419b862c63f1cf99f32601cc3e Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 26 Jul 2025 06:04:00 -0400 Subject: [PATCH] feat(openai): add configurable base_url support with OPENAI_BASE_URL env var - Add base_url field to OpenAIConfig with default "https://api.openai.com/v1" - Update sample_run_config to support OPENAI_BASE_URL environment variable - Modify get_base_url() to return configured base_url instead of hardcoded value - Add comprehensive test suite covering: - Default base URL behavior - Custom base URL from config - Environment variable override - Config precedence over environment variables - Client initialization with configured URL - Model availability checks using configured URL This enables users to configure custom OpenAI-compatible API endpoints via environment variables or configuration files. --- .../providers/inference/remote_openai.md | 2 + .../remote/inference/openai/config.py | 12 +- .../remote/inference/openai/openai.py | 4 +- llama_stack/templates/ci-tests/run.yaml | 1 + llama_stack/templates/open-benchmark/run.yaml | 1 + llama_stack/templates/starter/run.yaml | 1 + .../inference/test_openai_base_url_config.py | 125 ++++++++++++++++++ 7 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 tests/unit/providers/inference/test_openai_base_url_config.py diff --git a/docs/source/providers/inference/remote_openai.md b/docs/source/providers/inference/remote_openai.md index 36e4b5454..18a74caea 100644 --- a/docs/source/providers/inference/remote_openai.md +++ b/docs/source/providers/inference/remote_openai.md @@ -9,11 +9,13 @@ OpenAI inference provider for accessing GPT models and other OpenAI services. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `api_key` | `str \| None` | No | | API key for OpenAI models | +| `base_url` | `` | No | https://api.openai.com/v1 | Base URL for OpenAI API | ## Sample Configuration ```yaml api_key: ${env.OPENAI_API_KEY:=} +base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1} ``` diff --git a/llama_stack/providers/remote/inference/openai/config.py b/llama_stack/providers/remote/inference/openai/config.py index 2768e98d0..ad25cdfa5 100644 --- a/llama_stack/providers/remote/inference/openai/config.py +++ b/llama_stack/providers/remote/inference/openai/config.py @@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel): default=None, description="API key for OpenAI models", ) + base_url: str = Field( + default="https://api.openai.com/v1", + description="Base URL for OpenAI API", + ) @classmethod - def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]: + def sample_run_config( + cls, + api_key: str = "${env.OPENAI_API_KEY:=}", + base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}", + **kwargs, + ) -> dict[str, Any]: return { "api_key": api_key, + "base_url": base_url, } diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index f5d4afe3f..865258559 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -65,9 +65,9 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): """ Get the OpenAI API base URL. - Returns the standard OpenAI API base URL for direct OpenAI API calls. + Returns the OpenAI API base URL from the configuration. """ - return "https://api.openai.com/v1" + return self.config.base_url async def initialize(self) -> None: await super().initialize() diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 2a1270107..84eacae1f 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -56,6 +56,7 @@ providers: provider_type: remote::openai config: api_key: ${env.OPENAI_API_KEY:=} + base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1} - provider_id: anthropic provider_type: remote::anthropic config: diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 4e635d80f..779bca47e 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -16,6 +16,7 @@ providers: provider_type: remote::openai config: api_key: ${env.OPENAI_API_KEY:=} + base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1} - provider_id: anthropic provider_type: remote::anthropic config: diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 40e43cde9..0b7e71a75 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -56,6 +56,7 @@ providers: provider_type: remote::openai config: api_key: ${env.OPENAI_API_KEY:=} + base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1} - provider_id: anthropic provider_type: remote::anthropic config: diff --git a/tests/unit/providers/inference/test_openai_base_url_config.py b/tests/unit/providers/inference/test_openai_base_url_config.py new file mode 100644 index 000000000..453ac9089 --- /dev/null +++ b/tests/unit/providers/inference/test_openai_base_url_config.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +from llama_stack.distribution.stack import replace_env_vars +from llama_stack.providers.remote.inference.openai.config import OpenAIConfig +from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter + + +class TestOpenAIBaseURLConfig: + """Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter.""" + + def test_default_base_url_without_env_var(self): + """Test that the adapter uses the default OpenAI base URL when no environment variable is set.""" + config = OpenAIConfig(api_key="test-key") + adapter = OpenAIInferenceAdapter(config) + + assert adapter.get_base_url() == "https://api.openai.com/v1" + + def test_custom_base_url_from_config(self): + """Test that the adapter uses a custom base URL when provided in config.""" + custom_url = "https://custom.openai.com/v1" + config = OpenAIConfig(api_key="test-key", base_url=custom_url) + adapter = OpenAIInferenceAdapter(config) + + assert adapter.get_base_url() == custom_url + + @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"}) + def test_base_url_from_environment_variable(self): + """Test that the adapter uses base URL from OPENAI_BASE_URL environment variable.""" + # Use sample_run_config which has proper environment variable syntax + config_data = OpenAIConfig.sample_run_config(api_key="test-key") + processed_config = replace_env_vars(config_data) + config = OpenAIConfig.model_validate(processed_config) + adapter = OpenAIInferenceAdapter(config) + + assert adapter.get_base_url() == "https://env.openai.com/v1" + + @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"}) + def test_config_overrides_environment_variable(self): + """Test that explicit config value overrides environment variable.""" + custom_url = "https://config.openai.com/v1" + config = OpenAIConfig(api_key="test-key", base_url=custom_url) + adapter = OpenAIInferenceAdapter(config) + + # Config should take precedence over environment variable + assert adapter.get_base_url() == custom_url + + @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") + def test_client_uses_configured_base_url(self, mock_openai_class): + """Test that the OpenAI client is initialized with the configured base URL.""" + custom_url = "https://test.openai.com/v1" + config = OpenAIConfig(api_key="test-key", base_url=custom_url) + adapter = OpenAIInferenceAdapter(config) + + # Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin + adapter.get_api_key = MagicMock(return_value="test-key") + + # Access the client property to trigger AsyncOpenAI initialization + _ = adapter.client + + # Verify AsyncOpenAI was called with the correct base_url + mock_openai_class.assert_called_once_with( + api_key="test-key", + base_url=custom_url, + ) + + @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") + async def test_check_model_availability_uses_configured_url(self, mock_openai_class): + """Test that check_model_availability uses the configured base URL.""" + custom_url = "https://test.openai.com/v1" + config = OpenAIConfig(api_key="test-key", base_url=custom_url) + adapter = OpenAIInferenceAdapter(config) + + # Mock the get_api_key method + adapter.get_api_key = MagicMock(return_value="test-key") + + # Mock the AsyncOpenAI client and its models.retrieve method + mock_client = MagicMock() + mock_client.models.retrieve = AsyncMock(return_value=MagicMock()) + mock_openai_class.return_value = mock_client + + # Call check_model_availability and verify it returns True + assert await adapter.check_model_availability("gpt-4") + + # Verify the client was created with the custom URL + mock_openai_class.assert_called_with( + api_key="test-key", + base_url=custom_url, + ) + + # Verify the method was called and returned True + mock_client.models.retrieve.assert_called_once_with("gpt-4") + + @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"}) + @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") + async def test_environment_variable_affects_model_availability_check(self, mock_openai_class): + """Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked.""" + # Use sample_run_config which has proper environment variable syntax + config_data = OpenAIConfig.sample_run_config(api_key="test-key") + processed_config = replace_env_vars(config_data) + config = OpenAIConfig.model_validate(processed_config) + adapter = OpenAIInferenceAdapter(config) + + # Mock the get_api_key method + adapter.get_api_key = MagicMock(return_value="test-key") + + # Mock the AsyncOpenAI client + mock_client = MagicMock() + mock_client.models.retrieve = AsyncMock(return_value=MagicMock()) + mock_openai_class.return_value = mock_client + + # Call check_model_availability and verify it returns True + assert await adapter.check_model_availability("gpt-4") + + # Verify the client was created with the environment variable URL + mock_openai_class.assert_called_with( + api_key="test-key", + base_url="https://proxy.openai.com/v1", + )