mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00

Merge branch 'main' into auto_instrument_1

Commit d8b883be41

4 changed files with 150 additions and 15 deletions

.github/CODEOWNERS (vendored, 2 changes)

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @leseb @bbrowning @reluctantfuturist @mattf @slekkala1 @franciscojavierarceo
+* @ashwinb @raghotham @ehhuang @leseb @bbrowning @mattf @franciscojavierarceo

@@ -20,6 +20,7 @@ TEST_PATTERN=""
 INFERENCE_MODE="replay"
 EXTRA_PARAMS=""
 COLLECT_ONLY=false
+TYPESCRIPT_ONLY=false
 
 # Function to display usage
 usage() {

@@ -34,6 +35,7 @@ Options:
     --subdirs STRING      Comma-separated list of test subdirectories to run (overrides suite)
     --pattern STRING      Regex pattern to pass to pytest -k
     --collect-only        Collect tests only without running them (skips server startup)
+    --typescript-only     Skip Python tests and run only TypeScript client tests
     --help                Show this help message
 
 Suites are defined in tests/integration/suites.py and define which tests to run.

@@ -90,6 +92,10 @@ while [[ $# -gt 0 ]]; do
         COLLECT_ONLY=true
         shift
         ;;
+    --typescript-only)
+        TYPESCRIPT_ONLY=true
+        shift
+        ;;
     --help)
         usage
         exit 0

@@ -552,7 +558,9 @@ if [[ -n "$STACK_CONFIG" ]]; then
     STACK_CONFIG_ARG="--stack-config=$STACK_CONFIG"
 fi
 
-pytest -s -v $PYTEST_TARGET \
+# Run Python tests unless typescript-only mode
+if [[ "$TYPESCRIPT_ONLY" == "false" ]]; then
+pytest -s -v $PYTEST_TARGET \
     $STACK_CONFIG_ARG \
     --inference-mode="$INFERENCE_MODE" \
     -k "$PYTEST_PATTERN" \

@@ -561,7 +569,12 @@ pytest -s -v $PYTEST_TARGET \
     --embedding-model=sentence-transformers/nomic-ai/nomic-embed-text-v1.5 \
     --color=yes $EXTRA_PARAMS \
     --capture=tee-sys
 exit_code=$?
+else
+echo "Skipping Python tests (--typescript-only mode)"
+exit_code=0
+fi
 
 set +x
 set -e

@@ -213,6 +213,19 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return api_key
 
+    def _validate_model_allowed(self, provider_model_id: str) -> None:
+        """
+        Validate that the model is in the allowed_models list if configured.
+
+        :param provider_model_id: The provider-specific model ID to validate
+        :raises ValueError: If the model is not in the allowed_models list
+        """
+        if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
+            raise ValueError(
+                f"Model '{provider_model_id}' is not in the allowed models list. "
+                f"Allowed models: {self.config.allowed_models}"
+            )
+
     async def _get_provider_model_id(self, model: str) -> str:
         """
         Get the provider-specific model ID from the model store.
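
A note on the semantics of this check: allowed_models set to None disables filtering entirely, while an empty list rejects every model. A minimal standalone sketch of just this predicate (the _StubConfig class and the model IDs here are hypothetical, for illustration only):

from dataclasses import dataclass


@dataclass
class _StubConfig:
    # Hypothetical stand-in for the provider config consulted by the mixin
    allowed_models: list[str] | None = None


def validate_model_allowed(config: _StubConfig, provider_model_id: str) -> None:
    # Same predicate as the method added above: None means unrestricted,
    # an empty list matches nothing, otherwise membership is required.
    if config.allowed_models is not None and provider_model_id not in config.allowed_models:
        raise ValueError(
            f"Model '{provider_model_id}' is not in the allowed models list. "
            f"Allowed models: {config.allowed_models}"
        )


validate_model_allowed(_StubConfig(allowed_models=None), "gpt-4")       # ok: no restriction
validate_model_allowed(_StubConfig(allowed_models=["gpt-4"]), "gpt-4")  # ok: explicitly listed
# validate_model_allowed(_StubConfig(allowed_models=[]), "gpt-4")       # would raise ValueError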

@@ -259,8 +272,11 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Direct OpenAI completion API call.
         """
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         completion_kwargs = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,

@@ -292,6 +308,9 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI chat completion API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         messages = params.messages
 
         if self.download_images:

@@ -313,7 +332,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             messages = [await _localize_image_url(m) for m in messages]
 
         request_params = await prepare_openai_completion_params(
-            model=await self._get_provider_model_id(params.model),
+            model=provider_model_id,
             messages=messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,

@@ -351,10 +370,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
         """
+        provider_model_id = await self._get_provider_model_id(params.model)
+        self._validate_model_allowed(provider_model_id)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
         request_params: dict[str, Any] = {
-            "model": await self._get_provider_model_id(params.model),
+            "model": provider_model_id,
             "input": params.input,
         }
         if params.encoding_format is not None:
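
The three hunks above apply one pattern to openai_completion, openai_chat_completion, and openai_embeddings: resolve the provider model ID once, validate it before anything is sent upstream, and reuse the resolved value instead of making a second _get_provider_model_id call inside the request builder. A runnable toy model of that ordering (the _FakeMixin class is hypothetical; only the method names mirror the diff):

import asyncio


class _FakeMixin:
    # Hypothetical miniature of the call pattern above; the method names
    # mirror the diff, but this class is a stand-in for illustration only.
    def __init__(self, allowed_models):
        self.allowed_models = allowed_models
        self.requests_sent = 0

    async def _get_provider_model_id(self, model: str) -> str:
        # Pretend alias resolution is the identity function.
        return model

    def _validate_model_allowed(self, provider_model_id: str) -> None:
        if self.allowed_models is not None and provider_model_id not in self.allowed_models:
            raise ValueError(
                f"Model '{provider_model_id}' is not in the allowed models list. "
                f"Allowed models: {self.allowed_models}"
            )

    async def openai_completion(self, model: str) -> None:
        provider_model_id = await self._get_provider_model_id(model)
        self._validate_model_allowed(provider_model_id)  # raises before any request
        self.requests_sent += 1  # stands in for the upstream client call


async def main() -> None:
    mixin = _FakeMixin(allowed_models=["gpt-4"])
    await mixin.openai_completion("gpt-4")
    try:
        await mixin.openai_completion("gpt-4-turbo")
    except ValueError as exc:
        print(exc)
    assert mixin.requests_sent == 1  # the disallowed call never reached the client


asyncio.run(main())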

@@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
+from llama_stack_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
 
 
 class OpenAIMixinImpl(OpenAIMixin):

@@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
         error_message = str(exc_info.value)
         assert "test_api_key" in error_message
         assert "x-llamastack-provider-data" in error_message
+
+
+class TestOpenAIMixinAllowedModelsInference:
+    """Test cases for allowed_models enforcement during inference requests"""
+
+    async def test_inference_with_allowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods succeed with allowed models"""
+        mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+        mock_embedding_response = MagicMock()
+        mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
+        mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
+        mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+            # Test completion
+            await mixin.openai_completion(
+                OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
+            )
+            mock_client.completions.create.assert_called_once()
+
+            # Test embeddings
+            await mixin.openai_embeddings(
+                OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
+            )
+            mock_client.embeddings.create.assert_called_once()
+
+    async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
+        """Test that all inference methods fail with disallowed models"""
+        mixin.config.allowed_models = ["gpt-4"]
+
+        mock_client = MagicMock()
+
+        with mock_client_context(mixin, mock_client):
+            # Test chat completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
+
+            # Test completion with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
+                await mixin.openai_completion(
+                    OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
+                )
+
+            # Test embeddings with disallowed model
+            with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
+                await mixin.openai_embeddings(
+                    OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
+                )
+
+            mock_client.chat.completions.create.assert_not_called()
+            mock_client.completions.create.assert_not_called()
+            mock_client.embeddings.create.assert_not_called()
+
+    async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
+        """Test that inference succeeds when allowed_models is None and fails when an empty list blocks all models"""
+        # Test with None (no restrictions)
+        assert mixin.config.allowed_models is None
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            await mixin.openai_chat_completion(
+                OpenAIChatCompletionRequestWithExtraBody(
+                    model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                )
+            )
+            mock_client.chat.completions.create.assert_called_once()
+
+        # Test with empty list (blocks all models)
+        mixin.config.allowed_models = []
+        with mock_client_context(mixin, mock_client):
+            with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
+                await mixin.openai_chat_completion(
+                    OpenAIChatCompletionRequestWithExtraBody(
+                        model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
+                    )
+                )
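
These tests lean on the mixin and mock_client_context fixtures defined earlier in the test module, outside this diff. One plausible shape for such a helper, sketched as an assumption since the real fixture is not shown, patches the mixin's client property for the duration of the block:

from contextlib import contextmanager
from unittest.mock import PropertyMock, patch


@contextmanager
def mock_client_context(mixin, mock_client):
    # Hypothetical sketch: temporarily replace the mixin's OpenAI client
    # property with a mock. The fixture actually used by the tests may
    # differ; only its call sites are visible in this diff.
    with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
        yield mock_client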