Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 01:48:05 +00:00

Merge branch 'main' into auto_instrument_1

This commit is contained in: commit d8b883be41

4 changed files with 150 additions and 15 deletions
.github/CODEOWNERS (vendored): 2 changes
@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @leseb @bbrowning @reluctantfuturist @mattf @slekkala1 @franciscojavierarceo
+* @ashwinb @raghotham @ehhuang @leseb @bbrowning @mattf @franciscojavierarceo

@@ -20,6 +20,7 @@ TEST_PATTERN=""
 INFERENCE_MODE="replay"
 EXTRA_PARAMS=""
 COLLECT_ONLY=false
+TYPESCRIPT_ONLY=false

 # Function to display usage
 usage() {
@@ -34,6 +35,7 @@ Options:
   --subdirs STRING        Comma-separated list of test subdirectories to run (overrides suite)
   --pattern STRING        Regex pattern to pass to pytest -k
   --collect-only          Collect tests only without running them (skips server startup)
+  --typescript-only       Skip Python tests and run only TypeScript client tests
   --help                  Show this help message

 Suites are defined in tests/integration/suites.py and define which tests to run.
@@ -90,6 +92,10 @@ while [[ $# -gt 0 ]]; do
             COLLECT_ONLY=true
             shift
             ;;
+        --typescript-only)
+            TYPESCRIPT_ONLY=true
+            shift
+            ;;
         --help)
             usage
             exit 0
@@ -552,6 +558,8 @@ if [[ -n "$STACK_CONFIG" ]]; then
     STACK_CONFIG_ARG="--stack-config=$STACK_CONFIG"
 fi

+# Run Python tests unless typescript-only mode
+if [[ "$TYPESCRIPT_ONLY" == "false" ]]; then
 pytest -s -v $PYTEST_TARGET \
     $STACK_CONFIG_ARG \
     --inference-mode="$INFERENCE_MODE" \
@@ -562,6 +570,11 @@ pytest -s -v $PYTEST_TARGET \
     --color=yes $EXTRA_PARAMS \
     --capture=tee-sys
+exit_code=$?
+else
+    echo "Skipping Python tests (--typescript-only mode)"
+    exit_code=0
+fi

 set +x
 set -e

llama_stack/providers/utils/inference/openai_mixin.py

@@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):

         return api_key

+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
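The gate above has three regimes: allowed_models=None permits every model, a non-empty list permits only its members, and an empty list permits none. A minimal standalone sketch of the same check, where AllowedModelsConfig and validate_model_allowed are illustrative stand-ins rather than the repo's actual types:

from dataclasses import dataclass


@dataclass
class AllowedModelsConfig:
    # Stand-in for the provider config's allowed_models field.
    # None means "no restriction"; a list restricts to exactly its members.
    allowed_models: list[str] | None = None


def validate_model_allowed(config: AllowedModelsConfig, provider_model_id: str) -> None:
    # Mirrors the check in the diff: only a non-None list is enforced,
    # which is why an empty list rejects every model.
    if config.allowed_models is not None and provider_model_id not in config.allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. "
            f"Allowed models: {config.allowed_models}"
        )


validate_model_allowed(AllowedModelsConfig(), "gpt-4")           # ok: no restriction
validate_model_allowed(AllowedModelsConfig(["gpt-4"]), "gpt-4")  # ok: listed
try:
    validate_model_allowed(AllowedModelsConfig([]), "gpt-4")     # empty list blocks all
except ValueError as e:
    print(e)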
@@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages

         if self.download_images:
@@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             messages = [await _localize_image_url(m) for m in messages]

         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
         request_params: dict[str, Any] = {
-            "model": await self._get_provider_model_id(params.model),
+            "model": provider_model_id,
             "input": params.input,
         }
         if params.encoding_format is not None:

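The comment in the embeddings hunk points at a real SDK wrinkle: the openai package's NOT_GIVEN sentinel has type NotGiven, while some parameter signatures use Omit, so forwarding the sentinel can trip type checkers. Building the kwargs dict conditionally sidesteps the sentinel entirely. A dependency-free sketch of the pattern, with encoding_format and dimensions standing in for the optional fields:

from typing import Any


def build_request_params(
    model: str,
    input_text: str,
    encoding_format: str | None = None,
    dimensions: int | None = None,
) -> dict[str, Any]:
    # Start with the always-required fields.
    params: dict[str, Any] = {"model": model, "input": input_text}
    # Add optional fields only when the caller actually set them, so no
    # NOT_GIVEN/Omit sentinel ever needs to flow through the call.
    if encoding_format is not None:
        params["encoding_format"] = encoding_format
    if dimensions is not None:
        params["dimensions"] = dimensions
    return params


print(build_request_params("text-embedding-ada-002", "hello"))
print(build_request_params("text-embedding-ada-002", "hello", encoding_format="float"))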
@@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
+from llama_stack_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)


 class OpenAIMixinImpl(OpenAIMixin):
@@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
         error_message = str(exc_info.value)
         assert "test_api_key" in error_message
         assert "x-llamastack-provider-data" in error_message
+
+
+class TestOpenAIMixinAllowedModelsInference:
+    """Test cases for allowed_models enforcement during inference requests"""
+
+    async def test_inference_with_allowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods succeed with allowed models"""
+        mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+        mock_embedding_response = MagicMock()
+        mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
+        mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
+        mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+            # Test completion
+            await mixin.openai_completion(
+                OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
+            )
+            mock_client.completions.create.assert_called_once()
+
+            # Test embeddings
+            await mixin.openai_embeddings(
+                OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
+            )
+            mock_client.embeddings.create.assert_called_once()
+
+    async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods fail with disallowed models"""
+        mixin.config.allowed_models = ["gpt-4"]
+
+        mock_client = MagicMock()
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
+
+            # Test completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
+                )
+
+            # Test embeddings with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
+                await mixin.openai_embeddings(
+                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
+                )
+
+            mock_client.chat.completions.create.assert_not_called()
+            mock_client.completions.create.assert_not_called()
+            mock_client.embeddings.create.assert_not_called()
+
+    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
+        """Test that inference succeeds when allowed_models is None, and that an empty list blocks all models"""
+        # Test with None (no restrictions)
+        assert mixin.config.allowed_models is None
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+        # Test with empty list (blocks all models)
+        mixin.config.allowed_models = []
+        with mock_client_context(mixin, mock_client):
+            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
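The mocking pattern these tests rely on: MagicMock materializes the client.chat.completions attribute chain automatically, and AsyncMock makes the final create awaitable, which is what the mixin's async code paths need. A self-contained illustration of just that mechanism (the return payload here is made up for demonstration):

import asyncio
from unittest.mock import AsyncMock, MagicMock

# MagicMock builds intermediate attributes (client.chat.completions) on the fly;
# AsyncMock replaces the leaf .create so it can be awaited like the real SDK call.
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value={"id": "chatcmpl-123"})


async def main() -> None:
    response = await mock_client.chat.completions.create(model="gpt-4", messages=[])
    print(response)  # {'id': 'chatcmpl-123'}
    mock_client.chat.completions.create.assert_awaited_once()


asyncio.run(main())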