mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 10:10:36 +00:00
Merge branch 'main' into auto_instrument_1
This commit is contained in:
commit
d8b883be41
4 changed files with 150 additions and 15 deletions
|
|
@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack_api import (
|
||||
Model,
|
||||
ModelType,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
|
||||
|
|
@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
|
|||
error_message = str(exc_info.value)
|
||||
assert "test_api_key" in error_message
|
||||
assert "x-llamastack-provider-data" in error_message
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModelsInference:
|
||||
"""Test cases for allowed_models enforcement during inference requests"""
|
||||
|
||||
async def test_inference_with_allowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods succeed with allowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_client.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_embedding_response = MagicMock()
|
||||
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
|
||||
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test completion
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
|
||||
)
|
||||
mock_client.completions.create.assert_called_once()
|
||||
|
||||
# Test embeddings
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
|
||||
)
|
||||
mock_client.embeddings.create.assert_called_once()
|
||||
|
||||
async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods fail with disallowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
|
||||
# Test completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
|
||||
)
|
||||
|
||||
# Test embeddings with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
|
||||
)
|
||||
|
||||
mock_client.chat.completions.create.assert_not_called()
|
||||
mock_client.completions.create.assert_not_called()
|
||||
mock_client.embeddings.create.assert_not_called()
|
||||
|
||||
async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
|
||||
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
|
||||
# Test with None (no restrictions)
|
||||
assert mixin.config.allowed_models is None
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test with empty list (blocks all models)
|
||||
mixin.config.allowed_models = []
|
||||
with mock_client_context(mixin, mock_client):
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue