Merge branch 'main' into dead_code_removal

Omar Abdelwahab 2025-10-06 13:21:36 -07:00 committed by GitHub
commit 9886520b40
927 changed files with 171924 additions and 102933 deletions

View file

@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationCreateRequest,
ConversationItem,
ConversationItemList,
)
def test_conversation_create_request_defaults():
request = ConversationCreateRequest()
assert request.items == []
assert request.metadata == {}
def test_conversation_model_defaults():
conversation = Conversation(
id="conv_123456789",
created_at=1234567890,
metadata=None,
object="conversation",
)
assert conversation.id == "conv_123456789"
assert conversation.object == "conversation"
assert conversation.metadata is None
def test_openai_client_compatibility():
from openai.types.conversations.message import Message
from pydantic import TypeAdapter
openai_message = Message(
id="msg_123",
content=[{"type": "input_text", "text": "Hello"}],
role="user",
status="in_progress",
type="message",
object="message",
)
adapter = TypeAdapter(ConversationItem)
validated_item = adapter.validate_python(openai_message.model_dump())
assert validated_item.id == "msg_123"
assert validated_item.type == "message"
def test_conversation_item_list():
item_list = ConversationItemList(data=[])
assert item_list.object == "list"
assert item_list.data == []
assert item_list.first_id is None
assert item_list.last_id is None
assert item_list.has_more is False

View file

@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
import pytest
from openai.types.conversations.conversation import Conversation as OpenAIConversation
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentText,
OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.fixture
async def service():
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations.db"
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
service = ConversationServiceImpl(config, {})
await service.initialize()
yield service
async def test_conversation_lifecycle(service):
conversation = await service.create_conversation(metadata={"test": "data"})
assert conversation.id.startswith("conv_")
assert conversation.metadata == {"test": "data"}
retrieved = await service.get_conversation(conversation.id)
assert retrieved.id == conversation.id
deleted = await service.openai_delete_conversation(conversation.id)
assert deleted.id == conversation.id
async def test_conversation_items(service):
conversation = await service.create_conversation()
items = [
OpenAIResponseMessage(
type="message",
role="user",
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
id="msg_test123",
status="completed",
)
]
item_list = await service.add_items(conversation.id, items)
assert len(item_list.data) == 1
assert item_list.data[0].id == "msg_test123"
items = await service.list(conversation.id)
assert len(items.data) == 1
async def test_invalid_conversation_id(service):
with pytest.raises(ValueError, match="Expected an ID that begins with 'conv_'"):
await service._get_validated_conversation("invalid_id")
async def test_empty_parameter_validation(service):
with pytest.raises(ValueError, match="Expected a non-empty value"):
await service.retrieve("", "item_123")
async def test_openai_type_compatibility(service):
conversation = await service.create_conversation(metadata={"test": "value"})
conversation_dict = conversation.model_dump()
openai_conversation = OpenAIConversation.model_validate(conversation_dict)
for attr in ["id", "object", "created_at", "metadata"]:
assert getattr(openai_conversation, attr) == getattr(conversation, attr)
items = [
OpenAIResponseMessage(
type="message",
role="user",
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
id="msg_test456",
status="completed",
)
]
item_list = await service.add_items(conversation.id, items)
for attr in ["object", "data", "first_id", "last_id", "has_more"]:
assert hasattr(item_list, attr)
assert item_list.object == "list"
items = await service.list(conversation.id)
item = await service.retrieve(conversation.id, items.data[0].id)
item_dict = item.model_dump()
openai_item_adapter = TypeAdapter(OpenAIConversationItem)
openai_item_adapter.validate_python(item_dict)
async def test_policy_configuration():
from llama_stack.core.access_control.datatypes import Action, Scope
from llama_stack.core.datatypes import AccessRule
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations_policy.db"
restrictive_policy = [
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
]
config = ConversationServiceConfig(
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
)
service = ConversationServiceImpl(config, {})
await service.initialize()
assert service.policy == restrictive_policy
assert len(service.policy) == 1
assert service.policy[0].forbid is not None
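Taken together, the tests above exercise a short create / add items / list / retrieve / delete flow. The sketch below is only a condensed restatement of those same calls as a reading aid, assuming the ConversationServiceImpl instance and imports from this file; it is not additional API surface.

# Condensed flow covered by the tests above; service is assumed to be an
# initialized ConversationServiceImpl (see the fixture at the top of this file).
async def conversation_round_trip(service):
    conversation = await service.create_conversation(metadata={"topic": "demo"})
    await service.add_items(
        conversation.id,
        [
            OpenAIResponseMessage(
                type="message",
                role="user",
                content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
                id="msg_demo",
                status="completed",
            )
        ],
    )
    listed = await service.list(conversation.id)  # ConversationItemList with object == "list"
    item = await service.retrieve(conversation.id, listed.data[0].id)
    await service.openai_delete_conversation(conversation.id)
    return item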

View file

@@ -390,3 +390,467 @@ pip_packages:
assert provider.is_external is True
# config_class is empty string in partial spec
assert provider.config_class == ""
class TestGetExternalProvidersFromModule:
"""Test suite for installing external providers from module."""
def test_stackrunconfig_provider_without_module(self, mock_providers):
"""Test that providers without module attribute are skipped."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
import_module_side_effect = make_import_module_side_effect()
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="no_module",
provider_type="no_module",
config={},
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should not add anything to registry
assert len(result[Api.inference]) == 0
def test_stackrunconfig_with_version_spec(self, mock_providers):
"""Test provider with module containing version spec (e.g., package==1.0.0)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
fake_spec = ProviderSpec(
api=Api.inference,
provider_type="versioned_test",
config_class="versioned_test.config.VersionedTestConfig",
module="versioned_test==1.0.0",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
def import_side_effect(name):
if name == "versioned_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="versioned",
provider_type="versioned_test",
config={},
module="versioned_test==1.0.0",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "versioned_test" in result[Api.inference]
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
def test_buildconfig_does_not_import_module(self, mock_providers):
"""Test that BuildConfig does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(
provider_type="build_test",
module="build_test==1.0.0",
)
]
},
),
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "build_test" in result[Api.inference]
provider = result[Api.inference]["build_test"]
assert provider.module == "build_test==1.0.0"
assert provider.is_external is True
assert provider.config_class == ""
assert provider.api == Api.inference
def test_buildconfig_multiple_providers(self, mock_providers):
"""Test BuildConfig with multiple providers for the same API."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(provider_type="provider1", module="provider1"),
BuildProvider(provider_type="provider2", module="provider2"),
]
},
),
)
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, build_config, building=True)
mock_import.assert_not_called()
assert "provider1" in result[Api.inference]
assert "provider2" in result[Api.inference]
def test_distributionspec_does_not_import_module(self, mock_providers):
"""Test that DistributionSpec does not import the module (building=True)."""
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
dist_spec = DistributionSpec(
description="test distribution",
providers={
"inference": [
BuildProvider(
provider_type="dist_test",
module="dist_test==2.0.0",
)
]
},
)
# Should not call import_module at all when building
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, dist_spec, building=True)
# Verify module was NOT imported
mock_import.assert_not_called()
# Verify partial spec was created
assert "dist_test" in result[Api.inference]
provider = result[Api.inference]["dist_test"]
assert provider.module == "dist_test==2.0.0"
assert provider.is_external is True
assert provider.config_class == ""
def test_list_return_from_get_provider_spec(self, mock_providers):
"""Test when get_provider_spec returns a list of specs."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="list_test",
config_class="list_test.config.Config1",
module="list_test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="list_test_remote",
config_class="list_test.config.Config2",
module="list_test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "list_test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="list_test",
provider_type="list_test",
config={},
module="list_test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "list_test" in result[Api.inference]
assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
def test_list_return_filters_by_provider_type(self, mock_providers):
"""Test that list return filters specs by provider_type."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
spec1 = ProviderSpec(
api=Api.inference,
provider_type="wanted",
config_class="test.Config1",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="unwanted",
config_class="test.Config2",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="wanted",
provider_type="wanted",
config={},
module="test",
)
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Only the matching provider_type should be added
assert "wanted" in result[Api.inference]
assert "unwanted" not in result[Api.inference]
def test_list_return_adds_multiple_provider_types(self, mock_providers):
"""Test that list return adds multiple different provider_types when config requests them."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
# Module returns both inline and remote variants
spec1 = ProviderSpec(
api=Api.inference,
provider_type="remote::ollama",
config_class="test.RemoteConfig",
module="test",
)
spec2 = ProviderSpec(
api=Api.inference,
provider_type="inline::ollama",
config_class="test.InlineConfig",
module="test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])
def import_side_effect(name):
if name == "test.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="remote_ollama",
provider_type="remote::ollama",
config={},
module="test",
),
Provider(
provider_id="inline_ollama",
provider_type="inline::ollama",
config={},
module="test",
),
]
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Both provider types should be added to registry
assert "remote::ollama" in result[Api.inference]
assert "inline::ollama" in result[Api.inference]
assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"
def test_module_not_found_raises_value_error(self, mock_providers):
"""Test that ModuleNotFoundError raises ValueError with helpful message."""
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def import_side_effect(name):
if name == "missing_module.provider":
raise ModuleNotFoundError(name)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="missing",
provider_type="missing",
config={},
module="missing_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(ValueError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "get_provider_spec not found" in str(exc_info.value)
def test_generic_exception_is_raised(self, mock_providers):
"""Test that generic exceptions are properly raised."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
def bad_spec():
raise RuntimeError("Something went wrong")
fake_module = SimpleNamespace(get_provider_spec=bad_spec)
def import_side_effect(name):
if name == "error_module.provider":
return fake_module
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="error",
provider_type="error",
config={},
module="error_module",
)
]
},
)
registry = {Api.inference: {}}
with pytest.raises(RuntimeError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
assert "Something went wrong" in str(exc_info.value)
def test_empty_provider_list(self, mock_providers):
"""Test with empty provider list."""
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
config = StackRunConfig(
image_name="test_image",
providers={},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
# Should return registry unchanged
assert result == registry
assert len(result[Api.inference]) == 0
def test_multiple_apis_with_providers(self, mock_providers):
"""Test multiple APIs with providers."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_external_providers_from_module
from llama_stack.providers.datatypes import ProviderSpec
inference_spec = ProviderSpec(
api=Api.inference,
provider_type="inf_test",
config_class="inf.Config",
module="inf_test",
)
safety_spec = ProviderSpec(
api=Api.safety,
provider_type="safe_test",
config_class="safe.Config",
module="safe_test",
)
def import_side_effect(name):
if name == "inf_test.provider":
return SimpleNamespace(get_provider_spec=lambda: inference_spec)
elif name == "safe_test.provider":
return SimpleNamespace(get_provider_spec=lambda: safety_spec)
raise ModuleNotFoundError(name)
with patch("importlib.import_module", side_effect=import_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="inf",
provider_type="inf_test",
config={},
module="inf_test",
)
],
"safety": [
Provider(
provider_id="safe",
provider_type="safe_test",
config={},
module="safe_test",
)
],
},
)
registry = {Api.inference: {}, Api.safety: {}}
result = get_external_providers_from_module(registry, config, building=False)
assert "inf_test" in result[Api.inference]
assert "safe_test" in result[Api.safety]

View file

@@ -131,10 +131,6 @@ class TestInferenceRecording:
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test directory creation
assert storage.test_dir.exists()
assert storage.responses_dir.exists()
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
@@ -174,7 +170,8 @@ class TestInferenceRecording:
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
assert storage.responses_dir.exists()
dir = storage._get_test_dir()
assert dir.exists()
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""

View file

@@ -22,7 +22,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseObjectWithInput,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
@@ -45,7 +44,10 @@ from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
@@ -499,13 +501,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
async def test_prepend_previous_response_none(openai_responses_impl):
"""Test prepending no previous response to a new response."""
input = await openai_responses_impl._prepend_previous_response("fake_input", None)
assert input == "fake_input"
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
"""Test prepending a basic previous response to a new response."""
@@ -520,7 +515,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
status="completed",
role="assistant",
)
previous_response = OpenAIResponseObjectWithInput(
previous_response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@@ -528,10 +523,11 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[OpenAIUserMessageParam(content="fake_previous_input")],
)
mock_responses_store.get_response_object.return_value = previous_response
input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123")
input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response)
assert len(input) == 3
# Check for previous input
@@ -562,7 +558,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@@ -570,11 +566,12 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[OpenAIUserMessageParam(content="test input")],
)
mock_responses_store.get_response_object.return_value = response
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
input = await openai_responses_impl._prepend_previous_response(input_messages, response)
assert len(input) == 4
# Check for previous input
@@ -609,7 +606,7 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@@ -617,11 +614,12 @@ async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mo
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[OpenAIUserMessageParam(content="test input")],
)
mock_responses_store.get_response_object.return_value = response
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
input = await openai_responses_impl._prepend_previous_response(input_messages, response)
assert len(input) == 4
# Check for previous input
@@ -725,7 +723,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
@@ -733,6 +731,10 @@ async def test_create_openai_response_with_instructions_and_previous_response(
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[
OpenAIUserMessageParam(content="Name some towns in Ireland"),
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
],
)
mock_responses_store.get_response_object.return_value = response
@@ -818,7 +820,7 @@ async def test_responses_store_list_input_items_logic():
OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"),
]
response_with_input = OpenAIResponseObjectWithInput(
response_with_input = _OpenAIResponseObjectWithInputAndMessages(
id="resp_123",
model="test_model",
created_at=1234567890,
@@ -827,6 +829,7 @@ async def test_responses_store_list_input_items_logic():
output=[],
text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
input=input_items,
messages=[OpenAIUserMessageParam(content="First message")],
)
# Mock the get_response_object method to return our test data
@@ -887,7 +890,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
rather than just the original input when previous_response_id is provided."""
# Setup - Create a previous response that should be included in the stored input
previous_response = OpenAIResponseObjectWithInput(
previous_response = _OpenAIResponseObjectWithInputAndMessages(
id="resp-previous-123",
object="response",
created_at=1234567890,
@@ -906,6 +909,10 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")],
)
],
messages=[
OpenAIUserMessageParam(content="What is 2+2?"),
OpenAIAssistantMessageParam(content="2+2 equals 4."),
],
)
mock_responses_store.get_response_object.return_value = previous_response

View file

@@ -7,6 +7,8 @@
import json
from unittest.mock import MagicMock
import pytest
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
@@ -18,72 +20,41 @@ from llama_stack.providers.remote.inference.together.config import TogetherImplC
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.client.api_key == api_key
def test_openai_provider_openai_client_caching():
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
[
(
GroqConfig,
GroqInferenceAdapter,
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
),
(
OpenAIConfig,
OpenAIInferenceAdapter,
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
),
(
TogetherImplConfig,
TogetherInferenceAdapter,
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
(
LlamaCompatConfig,
LlamaCompatInferenceAdapter,
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
),
],
)
def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter = adapter_cls(config=config_cls())
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
)
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter.client
assert openai_client.api_key == api_key
def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
together_client = inference_adapter._get_client()
assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
assert inference_adapter.client.api_key == api_key

View file

@@ -18,7 +18,7 @@ class TestOpenAIBaseURLConfig:
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://api.openai.com/v1"
@@ -27,7 +27,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == custom_url
@@ -39,7 +39,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
assert adapter.get_base_url() == "https://env.openai.com/v1"
@@ -49,7 +49,7 @@ class TestOpenAIBaseURLConfig:
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Config should take precedence over environment variable
@@ -60,7 +60,7 @@ class TestOpenAIBaseURLConfig:
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
@@ -80,7 +80,7 @@ class TestOpenAIBaseURLConfig:
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method
@@ -122,7 +122,7 @@ class TestOpenAIBaseURLConfig:
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
adapter = OpenAIInferenceAdapter(config=config)
adapter.provider_data_api_key_field = None # Disable provider data for this test
# Mock the get_api_key method

View file

@@ -5,49 +5,21 @@
# the root directory of this source tree.
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChoiceChunk,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
SystemMessage,
ToolChoice,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
_process_vllm_chat_completion_stream_response,
)
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
# These are unit tests for the remote vLLM provider
# implementation. This should only contain tests which are specific to
@@ -60,37 +32,15 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
# -v -s --tb=short --disable-warnings
@pytest.fixture(scope="module")
def mock_openai_models_list():
with patch("openai.resources.models.AsyncModels.list") as mock_list:
yield mock_list
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
# Mock the __provider_spec__ attribute that would normally be set by the resolver
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
await inference_adapter.initialize()
return inference_adapter
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
mock_openai_models_list.return_value = mock_openai_models()
foo_model = Model(identifier="foo", provider_resource_id="foo", provider_id="vllm-inference")
await vllm_inference_adapter.register_model(foo_model)
mock_openai_models_list.assert_called()
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
@@ -99,463 +49,24 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
with patch.object(vllm_inference_adapter, "_nonstream_chat_completion") as mock_nonstream_completion:
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_client_property.return_value = mock_client
# No tools but auto tool choice
await vllm_inference_adapter.chat_completion(
await vllm_inference_adapter.openai_chat_completion(
"mock-model",
[],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
tool_choice=ToolChoice.auto.value,
)
mock_nonstream_completion.assert_called()
request = mock_nonstream_completion.call_args.args[0]
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions
assert request.tool_config.tool_choice == ToolChoice.none
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the client property to avoid instantiating a real AsyncOpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
mock_create_client.return_value = mock_client
# Mock the model to return a proper provider_resource_id
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="foo",
tool_name="knowledge_search",
arguments='{"query": "How many?"}',
)
],
),
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
]
await vllm_inference_adapter.chat_completion(
"mock-model",
messages,
stream=False,
tools=[],
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
}
]
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
tool call response that didn't call any tools. Previously we would
emit chunks with spurious ToolCallParseStatus.succeeded or
ToolCallParseStatus.failed when processing chunks that didn't
actually make any tool calls.
"""
async def mock_stream():
delta = OpenAIChoiceDelta(content="", tool_calls=None)
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "complete"
assert chunks[1].event.stop_reason == StopReason.end_of_turn
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments="",
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="tc_1",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
)
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "complete"
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=1,
function=OpenAIChoiceDeltaToolCallFunction(
name="power",
arguments='{"number": 28, "power": 3}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_2 = OpenAIChatCompletionChunk(
id="chunk-2",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(
content="",
tool_calls=[
OpenAIChoiceDeltaToolCall(
id="",
index=2,
function=OpenAIChoiceDeltaToolCallFunction(
name="multiple",
arguments='{"first_number": 4, "second_number": 7}',
),
),
],
),
finish_reason=None,
index=0,
)
],
)
mock_chunk_3 = OpenAIChatCompletionChunk(
id="chunk-3",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
OpenAIChoiceChunk(
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
)
],
)
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 4
assert chunks[0].event.event_type.value == "start"
assert chunks[1].event.event_type.value == "progress"
assert chunks[1].event.delta.type == "tool_call"
assert chunks[1].event.delta.parse_status.value == "succeeded"
assert chunks[1].event.delta.tool_call.arguments == '{"number": 28, "power": 3}'
assert chunks[2].event.event_type.value == "progress"
assert chunks[2].event.delta.type == "tool_call"
assert chunks[2].event.delta.parse_status.value == "succeeded"
assert chunks[2].event.delta.tool_call.arguments == '{"first_number": 4, "second_number": 7}'
assert chunks[3].event.event_type.value == "complete"
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
completion request. This can happen when there's an error thrown
in vLLM for example.
"""
async def mock_stream():
choices = []
mock_chunk = OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=choices,
)
for chunk in [mock_chunk]:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 1
assert chunks[0].event.event_type.value == "start"
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
model="test_model",
messages=[UserMessage(content="test")],
)
params = await vllm_inference_adapter._get_params(request)
assert "tools" not in params
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
contains the finish reason (i.e., the last one).
We want to make sure the tool call is executed in this case, and the parameters are passed correctly.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": None,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {
"name": None,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": "tool_calls",
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
finish reason.
We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = json.dumps(mock_tool_arguments)
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments_str
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
Tool calls with no arguments should be treated as regular tool calls, which was not the case until now.
"""
mock_tool_name = "mock_tool"
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": "",
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 3
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == "{}"
assert call_args.kwargs["tool_choice"] == ToolChoice.none.value
async def test_health_status_success(vllm_inference_adapter):
@@ -688,96 +199,30 @@ async def test_should_refresh_models():
# Test case 1: refresh_models is True, api_token is None
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
adapter1 = VLLMInferenceAdapter(config1)
adapter1 = VLLMInferenceAdapter(config=config1)
result1 = await adapter1.should_refresh_models()
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 2: refresh_models is True, api_token is empty string
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
adapter2 = VLLMInferenceAdapter(config2)
adapter2 = VLLMInferenceAdapter(config=config2)
result2 = await adapter2.should_refresh_models()
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 3: refresh_models is True, api_token is "fake" (default)
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
adapter3 = VLLMInferenceAdapter(config3)
adapter3 = VLLMInferenceAdapter(config=config3)
result3 = await adapter3.should_refresh_models()
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 4: refresh_models is True, api_token is real token
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
adapter4 = VLLMInferenceAdapter(config4)
adapter4 = VLLMInferenceAdapter(config=config4)
result4 = await adapter4.should_refresh_models()
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
# Test case 5: refresh_models is False, api_token is real token
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
adapter5 = VLLMInferenceAdapter(config5)
adapter5 = VLLMInferenceAdapter(config=config5)
result5 = await adapter5.should_refresh_models()
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
"""
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
This ensures that dynamic provider data (like API tokens) can be passed through context.
Note: The base URL is always taken from config.url, not from provider data.
"""
# Mock the AsyncOpenAI class to capture provider data
with (
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
):
mock_client = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
mock_openai_class.return_value = mock_client
# Mock provider data to return test data
mock_provider_data = MagicMock()
mock_provider_data.vllm_api_token = "test-token-123"
mock_provider_data.vllm_url = "http://test-server:8000/v1"
mock_get_provider_data.return_value = mock_provider_data
# Mock the model
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
try:
# Execute chat completion
await vllm_inference_adapter.chat_completion(
"test-model",
[UserMessage(content="Hello")],
stream=False,
tools=None,
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
# Verify that ALL client calls were made with the correct parameters
calls = mock_openai_class.call_args_list
incorrect_calls = []
for i, call in enumerate(calls):
api_key = call[1]["api_key"]
base_url = call[1]["base_url"]
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
if incorrect_calls:
error_msg = (
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
)
for incorrect_call in incorrect_calls:
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
raise AssertionError(error_msg)
# Ensure at least one call was made
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
# Verify that chat completion was called
mock_client.chat.completions.create.assert_called_once()
finally:
# Clean up context
pass

View file

@@ -5,6 +5,8 @@
# the root directory of this source tree.
import json
from collections.abc import Iterable
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
@@ -13,6 +15,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -29,7 +32,7 @@ class OpenAIMixinImpl(OpenAIMixin):
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
"""Test implementation with embedding model metadata"""
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
@@ -38,7 +41,8 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinImpl(config=config)
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
@@ -53,7 +57,8 @@ def mixin():
@pytest.fixture
def mixin_with_embeddings():
"""Create a test instance of OpenAIMixin with embedding model metadata"""
return OpenAIMixinWithEmbeddingsImpl()
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
@@ -362,6 +367,124 @@ class TestOpenAIMixinAllowedModels:
assert not await mixin.check_model_availability("another-mock-model-id")
class TestOpenAIMixinModelRegistration:
"""Test cases for model registration functionality"""
async def test_register_model_success(self, mixin, mock_client_with_models, mock_client_context):
"""Test successful model registration when model is available"""
model = Model(
provider_id="test-provider",
provider_resource_id="some-mock-model-id",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.register_model(model)
assert result == model
assert result.provider_id == "test-provider"
assert result.provider_resource_id == "some-mock-model-id"
assert result.identifier == "test-model"
assert result.model_type == ModelType.llm
mock_client_with_models.models.list.assert_called_once()
async def test_register_model_not_available(self, mixin, mock_client_with_models, mock_client_context):
"""Test model registration failure when model is not available from provider"""
model = Model(
provider_id="test-provider",
provider_resource_id="non-existent-model",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
with pytest.raises(
ValueError, match="Model non-existent-model is not available from provider test-provider"
):
await mixin.register_model(model)
mock_client_with_models.models.list.assert_called_once()
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
"""Test model registration with allowed_models filtering"""
mixin.allowed_models = {"some-mock-model-id"}
# Test with allowed model
allowed_model = Model(
provider_id="test-provider",
provider_resource_id="some-mock-model-id",
identifier="allowed-model",
model_type=ModelType.llm,
)
# Test with disallowed model
disallowed_model = Model(
provider_id="test-provider",
provider_resource_id="final-mock-model-id",
identifier="disallowed-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_models):
result = await mixin.register_model(allowed_model)
assert result == allowed_model
with pytest.raises(
ValueError, match="Model final-mock-model-id is not available from provider test-provider"
):
await mixin.register_model(disallowed_model)
mock_client_with_models.models.list.assert_called_once()
async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context):
"""Test registration of embedding models with metadata"""
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_models = [mock_embedding_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
embedding_model = Model(
provider_id="test-provider",
provider_resource_id="text-embedding-3-small",
identifier="embedding-test",
model_type=ModelType.embedding,
)
with mock_client_context(mixin_with_embeddings, mock_client):
result = await mixin_with_embeddings.register_model(embedding_model)
assert result == embedding_model
assert result.model_type == ModelType.embedding
async def test_unregister_model(self, mixin):
"""Test model unregistration (should be no-op)"""
# unregister_model should not raise any exceptions and return None
result = await mixin.unregister_model("any-model-id")
assert result is None
async def test_should_refresh_models(self, mixin):
"""Test should_refresh_models method (should always return False)"""
result = await mixin.should_refresh_models()
assert result is False
async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
"""Test that errors from provider API are properly propagated during registration"""
model = Model(
provider_id="test-provider",
provider_resource_id="some-model",
identifier="test-model",
model_type=ModelType.llm,
)
with mock_client_context(mixin, mock_client_with_exception):
# The exception from the API should be propagated
with pytest.raises(Exception, match="API Error"):
await mixin.register_model(model)
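# A small illustrative helper (not part of the suite) capturing the async
# "models.list()" stub pattern used in test_register_embedding_model above; any
# test that needs a fake client yielding a fixed set of model ids could reuse it.
def make_mock_client_with_model_ids(model_ids):
    client = MagicMock()
    async def _models_list():
        for model_id in model_ids:
            yield MagicMock(id=model_id)
    client.models.list.return_value = _models_list()
    return client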
class ProviderDataValidator(BaseModel):
"""Validator for provider data in tests"""
@ -380,13 +503,145 @@ class OpenAIMixinWithProviderData(OpenAIMixinImpl):
return "default-base-url"
class CustomListProviderModelIdsImplementation(OpenAIMixinImpl):
"""Test implementation with custom list_provider_model_ids override"""
custom_model_ids: Any
async def list_provider_model_ids(self) -> Iterable[str]:
"""Return custom model IDs list"""
return self.custom_model_ids
class TestOpenAIMixinCustomListProviderModelIds:
"""Test cases for custom list_provider_model_ids() implementation functionality"""
@pytest.fixture
def custom_model_ids_list(self):
"""Create a list of custom model ID strings"""
return ["custom-model-1", "custom-model-2", "custom-embedding"]
@pytest.fixture
def config(self):
"""Create RemoteInferenceProviderConfig instance"""
return RemoteInferenceProviderConfig()
@pytest.fixture
def adapter(self, custom_model_ids_list, config):
"""Create mixin instance with custom list_provider_model_ids implementation"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list)
mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}}
return mixin
async def test_is_used(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() implementation is used instead of client.models.list()"""
result = await adapter.list_models()
assert result is not None
assert len(result) == 3
assert set(custom_model_ids_list) == {m.identifier for m in result}
async def test_populates_cache(self, adapter, custom_model_ids_list):
"""Test that custom list_provider_model_ids() results are cached"""
assert len(adapter._model_cache) == 0
await adapter.list_models()
assert set(custom_model_ids_list) == set(adapter._model_cache.keys())
async def test_respects_allowed_models(self, config):
"""Test that custom list_provider_model_ids() respects allowed_models filtering"""
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
)
mixin.allowed_models = ["model-1"]
result = await mixin.list_models()
assert result is not None
assert len(result) == 1
assert result[0].identifier == "model-1"
async def test_with_empty_list(self, config):
"""Test that custom list_provider_model_ids() handles empty list correctly"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[])
result = await mixin.list_models()
assert result is not None
assert len(result) == 0
assert len(mixin._model_cache) == 0
async def test_wrong_type_raises_error(self, config):
"""Test that list_provider_model_ids() returning unhashable items results in an error"""
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=["valid-model", ["nested", "list"]]
)
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=[{"key": "value"}, "valid-model"]
)
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0])
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None])
with pytest.raises(Exception, match="is not a string"):
await mixin.list_models()
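# A minimal sketch (illustrative, not part of the suite) of how the non-string
# cases above could instead be collapsed into one parametrized method on this
# class, reusing the same config fixture:
@pytest.mark.parametrize(
    "bad_ids",
    [
        ["valid-model", ["nested", "list"]],
        [{"key": "value"}, "valid-model"],
        ["valid-model", 42.0],
        [None],
    ],
)
async def test_wrong_type_raises_error_parametrized(self, config, bad_ids):
    mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=bad_ids)
    with pytest.raises(Exception, match="is not a string"):
        await mixin.list_models()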
async def test_non_iterable_raises_error(self, config):
"""Test that list_provider_model_ids() returning non-iterable type raises error"""
mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=42)
with pytest.raises(
TypeError,
match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int",
):
await mixin.list_models()
async def test_accepts_various_iterables(self, config):
"""Test that list_provider_model_ids() accepts tuples, sets, generators, etc."""
tuples = CustomListProviderModelIdsImplementation(
config=config, custom_model_ids=("model-1", "model-2", "model-3")
)
result = await tuples.list_models()
assert result is not None
assert len(result) == 3
class GeneratorAdapter(OpenAIMixinImpl):
async def list_provider_model_ids(self) -> Iterable[str]:
def gen():
yield "gen-model-1"
yield "gen-model-2"
return gen()
mixin = GeneratorAdapter(config=config)
result = await mixin.list_models()
assert result is not None
assert len(result) == 2
sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"})
result = await sets.list_models()
assert result is not None
assert len(result) == 2
class TestOpenAIMixinProviderDataApiKey:
"""Test cases for provider_data_api_key_field functionality"""
@pytest.fixture
def mixin_with_provider_data_field(self):
"""Mixin instance with provider_data_api_key_field set"""
mixin_instance = OpenAIMixinWithProviderData()
config = RemoteInferenceProviderConfig()
mixin_instance = OpenAIMixinWithProviderData(config=config)
# Mock provider_spec for provider data validation
mock_provider_spec = MagicMock()

View file

@ -26,13 +26,15 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, Mi
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector"])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
def vector_provider(request):
return request.param
@ -448,6 +450,71 @@ async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
await adapter.shutdown()
@pytest.fixture(scope="session")
def weaviate_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
return db_path
@pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
index = WeaviateIndex(client=client, collection_name="Testcollection")
await index.initialize()
yield index
await index.delete()
client.close()
@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
config = WeaviateVectorIOConfig(
weaviate_cluster_url="localhost:8080",
weaviate_api_key=None,
kvstore=SqliteKVStoreConfig(),
)
adapter = WeaviateVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
client.close()
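# The two Weaviate fixtures above repeat the same embedded-client setup; a small
# shared helper like this (illustrative only, using exactly the calls already
# present in this file) could remove that duplication.
def connect_embedded_weaviate(persistence_data_path: str):
    import pytest_socket
    import weaviate
    pytest_socket.enable_socket()
    return weaviate.connect_to_embedded(
        hostname="localhost",
        port=8080,
        grpc_port=50051,
        persistence_data_path=persistence_data_path,
    )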
@pytest.fixture
def vector_io_adapter(vector_provider, request):
vector_provider_dict = {
@ -457,6 +524,7 @@ def vector_io_adapter(vector_provider, request):
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
"weaviate": "weaviate_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@ -14,6 +14,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseObject,
)
from llama_stack.apis.inference import OpenAIMessageParam, OpenAIUserMessageParam
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@ -44,6 +45,11 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
)
def create_test_messages(content: str) -> list[OpenAIMessageParam]:
"""Helper to create test messages for chat completion."""
return [OpenAIUserMessageParam(content=content)]
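# The tests below repeatedly build an input list and the matching chat messages
# from the same content string; a tiny helper like this (illustrative, not part
# of the suite) could do both in one call using only helpers defined above.
def create_test_inputs_and_messages(
    content: str, input_id: str
) -> tuple[list[OpenAIResponseInput], list[OpenAIMessageParam]]:
    """Helper building both the response inputs and the chat messages from one content string."""
    return [create_test_response_input(content, input_id)], create_test_messages(content)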
async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir:
@ -65,7 +71,8 @@ async def test_responses_store_pagination_basic():
for response_id, timestamp in test_data:
response = create_test_response_object(response_id, timestamp)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -111,7 +118,8 @@ async def test_responses_store_pagination_ascending():
for response_id, timestamp in test_data:
response = create_test_response_object(response_id, timestamp)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -149,7 +157,8 @@ async def test_responses_store_pagination_with_model_filter():
for response_id, timestamp, model in test_data:
response = create_test_response_object(response_id, timestamp, model)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -199,7 +208,8 @@ async def test_responses_store_pagination_no_limit():
for response_id, timestamp in test_data:
response = create_test_response_object(response_id, timestamp)
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
await store.store_response_object(response, input_list)
messages = create_test_messages(f"Input for {response_id}")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -222,7 +232,8 @@ async def test_responses_store_get_response_object():
# Store a test response
response = create_test_response_object("test-resp", int(time.time()))
input_list = [create_test_response_input("Test input content", "input-test-resp")]
await store.store_response_object(response, input_list)
messages = create_test_messages("Test input content")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -255,7 +266,8 @@ async def test_responses_store_input_items_pagination():
create_test_response_input("Fourth input", "input-4"),
create_test_response_input("Fifth input", "input-5"),
]
await store.store_response_object(response, input_list)
messages = create_test_messages("First input")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()
@ -335,7 +347,8 @@ async def test_responses_store_input_items_before_pagination():
create_test_response_input("Fourth input", "before-4"),
create_test_response_input("Fifth input", "before-5"),
]
await store.store_response_object(response, input_list)
messages = create_test_messages("First input")
await store.store_response_object(response, input_list, messages)
# Wait for all queued writes to complete
await store.flush()

View file

@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete():
assert {r["id"] for r in rows_after} == {1, 3}
async def test_batch_insert():
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
await store.create_table(
"batch_test",
{
"id": ColumnType.INTEGER,
"name": ColumnType.STRING,
"value": ColumnType.INTEGER,
},
)
batch_data = [
{"id": 1, "name": "first", "value": 10},
{"id": 2, "name": "second", "value": 20},
{"id": 3, "name": "third", "value": 30},
]
await store.insert("batch_test", batch_data)
result = await store.fetch_all("batch_test", order_by=[("id", "asc")])
assert result.data == batch_data
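# Follow-up read-back sketch (illustrative, not part of the suite): the same
# rows fetched in descending id order; the "desc" literal is assumed to mirror
# the "asc" form already used in this test.
result_desc = await store.fetch_all("batch_test", order_by=[("id", "desc")])
assert [row["id"] for row in result_desc.data] == [3, 2, 1]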
async def test_where_operator_edge_cases():
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"