feat(openai): add configurable base_url support with OPENAI_BASE_URL env var

- Add base_url field to OpenAIConfig with default "https://api.openai.com/v1"
- Update sample_run_config to support OPENAI_BASE_URL environment variable
- Modify get_base_url() to return configured base_url instead of hardcoded value
- Add comprehensive test suite covering:
  - Default base URL behavior
  - Custom base URL from config
  - Environment variable override
  - Config precedence over environment variables
  - Client initialization with configured URL
  - Model availability checks using configured URL

This enables users to configure custom OpenAI-compatible API endpoints
via environment variables or configuration files.
This commit is contained in:
Matthew Farrellee 2025-07-26 06:04:00 -04:00
parent 09abdb0a37
commit 79c9e46582
7 changed files with 143 additions and 3 deletions

View file

@ -9,11 +9,13 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | API key for OpenAI models |
| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
## Sample Configuration
```yaml
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
```

View file

@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
default=None,
description="API key for OpenAI models",
)
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]:
def sample_run_config(
cls,
api_key: str = "${env.OPENAI_API_KEY:=}",
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"base_url": base_url,
}

View file

@ -65,9 +65,9 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
Get the OpenAI API base URL.
Returns the standard OpenAI API base URL for direct OpenAI API calls.
Returns the OpenAI API base URL from the configuration.
"""
return "https://api.openai.com/v1"
return self.config.base_url
async def initialize(self) -> None:
await super().initialize()

View file

@ -56,6 +56,7 @@ providers:
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:

View file

@ -16,6 +16,7 @@ providers:
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:

View file

@ -56,6 +56,7 @@ providers:
provider_type: remote::openai
config:
api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic
provider_type: remote::anthropic
config:

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from unittest.mock import AsyncMock, MagicMock, patch
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
class TestOpenAIBaseURLConfig:
"""Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter."""
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == "https://api.openai.com/v1"
def test_custom_base_url_from_config(self):
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == custom_url
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
def test_base_url_from_environment_variable(self):
"""Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
# Use sample_run_config which has proper environment variable syntax
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == "https://env.openai.com/v1"
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
def test_config_overrides_environment_variable(self):
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Config should take precedence over environment variable
assert adapter.get_base_url() == custom_url
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
def test_client_uses_configured_base_url(self, mock_openai_class):
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
adapter.get_api_key = MagicMock(return_value="test-key")
# Access the client property to trigger AsyncOpenAI initialization
_ = adapter.client
# Verify AsyncOpenAI was called with the correct base_url
mock_openai_class.assert_called_once_with(
api_key="test-key",
base_url=custom_url,
)
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client and its models.retrieve method
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_openai_class.return_value = mock_client
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
# Verify the client was created with the custom URL
mock_openai_class.assert_called_with(
api_key="test-key",
base_url=custom_url,
)
# Verify the method was called and returned True
mock_client.models.retrieve.assert_called_once_with("gpt-4")
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
"""Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
# Use sample_run_config which has proper environment variable syntax
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_openai_class.return_value = mock_client
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
# Verify the client was created with the environment variable URL
mock_openai_class.assert_called_with(
api_key="test-key",
base_url="https://proxy.openai.com/v1",
)