Merge branch 'main' into auto_instrument_1

commit d8b883be41
Emilio Garcia, 2025-11-20 16:12:57 -05:00 (committed by GitHub)
4 changed files with 150 additions and 15 deletions

.github/CODEOWNERS
@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @leseb @bbrowning @reluctantfuturist @mattf @slekkala1 @franciscojavierarceo
+* @ashwinb @raghotham @ehhuang @leseb @bbrowning @mattf @franciscojavierarceo

scripts/integration-tests.sh

@@ -20,6 +20,7 @@ TEST_PATTERN=""
 INFERENCE_MODE="replay"
 EXTRA_PARAMS=""
 COLLECT_ONLY=false
+TYPESCRIPT_ONLY=false

 # Function to display usage
 usage() {
@@ -34,6 +35,7 @@ Options:
   --subdirs STRING      Comma-separated list of test subdirectories to run (overrides suite)
   --pattern STRING      Regex pattern to pass to pytest -k
   --collect-only        Collect tests only without running them (skips server startup)
+  --typescript-only     Skip Python tests and run only TypeScript client tests
   --help                Show this help message

 Suites are defined in tests/integration/suites.py and define which tests to run.
@@ -90,6 +92,10 @@ while [[ $# -gt 0 ]]; do
             COLLECT_ONLY=true
             shift
             ;;
+        --typescript-only)
+            TYPESCRIPT_ONLY=true
+            shift
+            ;;
         --help)
             usage
             exit 0
@@ -552,6 +558,8 @@ if [[ -n "$STACK_CONFIG" ]]; then
     STACK_CONFIG_ARG="--stack-config=$STACK_CONFIG"
 fi

+# Run Python tests unless typescript-only mode
+if [[ "$TYPESCRIPT_ONLY" == "false" ]]; then
 pytest -s -v $PYTEST_TARGET \
     $STACK_CONFIG_ARG \
     --inference-mode="$INFERENCE_MODE" \
@@ -562,6 +570,11 @@ pytest -s -v $PYTEST_TARGET \
     --color=yes $EXTRA_PARAMS \
     --capture=tee-sys
 exit_code=$?
+else
+    echo "Skipping Python tests (--typescript-only mode)"
+    exit_code=0
+fi

 set +x
 set -e
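
A TypeScript-only run is then requested with the new flag, e.g. ./scripts/integration-tests.sh --typescript-only (the exact script path is assumed from the repo layout; this diff does not show the file name). In that mode the pytest invocation is skipped entirely and exit_code is forced to 0, so the run's outcome depends only on the TypeScript client tests.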

llama_stack/providers/utils/inference/openai_mixin.py

@@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return api_key

+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list, if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
@@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages

         if self.download_images:
@@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             messages = [await _localize_image_url(m) for m in messages]

         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
         request_params: dict[str, Any] = {
-            "model": await self._get_provider_model_id(params.model),
+            "model": provider_model_id,
             "input": params.input,
         }
         if params.encoding_format is not None:
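
Two edge cases of this guard are easy to miss: allowed_models=None disables the restriction entirely, while an empty list rejects every model, because the check only short-circuits on None. A minimal standalone sketch of the same logic (illustrative only; the real check is the mixin method above, reading self.config.allowed_models):

# Illustrative sketch, not the mixin method itself.
def validate_model_allowed(allowed_models: list[str] | None, provider_model_id: str) -> None:
    # None means "no restriction"; an empty list blocks every model,
    # because the guard only short-circuits on None.
    if allowed_models is not None and provider_model_id not in allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. "
            f"Allowed models: {allowed_models}"
        )

validate_model_allowed(None, "any-model")   # ok: no restriction configured
validate_model_allowed(["gpt-4"], "gpt-4")  # ok: explicitly allowed
# validate_model_allowed([], "gpt-4")       # would raise ValueError: an empty list allows nothing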

tests/unit/providers/utils/inference/test_openai_mixin.py

@@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
+from llama_stack_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)


 class OpenAIMixinImpl(OpenAIMixin):
@@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
         error_message = str(exc_info.value)
         assert "test_api_key" in error_message
         assert "x-llamastack-provider-data" in error_message
+
+
+class TestOpenAIMixinAllowedModelsInference:
+    """Test cases for allowed_models enforcement during inference requests"""
+
+    async def test_inference_with_allowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods succeed with allowed models"""
+        mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+        mock_embedding_response = MagicMock()
+        mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
+        mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
+        mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+            # Test completion
+            await mixin.openai_completion(
+                OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
+            )
+            mock_client.completions.create.assert_called_once()
+
+            # Test embeddings
+            await mixin.openai_embeddings(
+                OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
+            )
+            mock_client.embeddings.create.assert_called_once()
+
+    async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods fail with disallowed models"""
+        mixin.config.allowed_models = ["gpt-4"]
+        mock_client = MagicMock()
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
+
+            # Test completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
+                )
+
+            # Test embeddings with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
+                await mixin.openai_embeddings(
+                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
+                )
+
+        mock_client.chat.completions.create.assert_not_called()
+        mock_client.completions.create.assert_not_called()
+        mock_client.embeddings.create.assert_not_called()
+
+    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
+        """Test that inference succeeds when allowed_models is None, and that an empty list blocks all models"""
+        # Test with None (no restrictions)
+        assert mixin.config.allowed_models is None
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+        # Test with empty list (blocks all models)
+        mixin.config.allowed_models = []
+        with mock_client_context(mixin, mock_client):
+            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
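
For context, the restriction these tests exercise is driven by the allowed_models field on RemoteInferenceProviderConfig (imported at the top of this file). A hedged sketch of opting in, assuming the field takes a plain list of provider model IDs as the assignments above suggest:

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

# Assumed usage: restrict inference to an explicit allow-list.
# allowed_models=None (the default) leaves all models available;
# allowed_models=[] would block every model, as the tests above verify.
config = RemoteInferenceProviderConfig(allowed_models=["gpt-4", "text-embedding-ada-002"])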