mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-10 21:34:36 +00:00
Merge branch 'main' into remove-deprecated-embeddings
This commit is contained in:
commit
5c44dcdf0e
770 changed files with 176834 additions and 27431 deletions
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.datatypes import Api
|
||||
|
@ -105,6 +106,9 @@ class ScoringFunctionsImpl(Impl):
|
|||
async def register_scoring_function(self, scoring_fn):
|
||||
return scoring_fn
|
||||
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str):
|
||||
return scoring_fn_id
|
||||
|
||||
|
||||
class BenchmarksImpl(Impl):
|
||||
def __init__(self):
|
||||
|
@ -113,6 +117,9 @@ class BenchmarksImpl(Impl):
|
|||
async def register_benchmark(self, benchmark):
|
||||
return benchmark
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str):
|
||||
return benchmark_id
|
||||
|
||||
|
||||
class ToolGroupsImpl(Impl):
|
||||
def __init__(self):
|
||||
|
@ -146,6 +153,20 @@ class VectorDBImpl(Impl):
|
|||
async def unregister_vector_db(self, vector_db_id: str):
|
||||
return vector_db_id
|
||||
|
||||
async def openai_create_vector_store(self, **kwargs):
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
|
||||
|
||||
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
|
||||
return VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=kwargs.get("name", vector_store_id),
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
|
||||
|
||||
async def test_models_routing_table(cached_disk_dist_registry):
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
|
@ -247,17 +268,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert "test-vectordb" in vector_db_ids
|
||||
assert "test-vectordb-2" in vector_db_ids
|
||||
assert vdb1.identifier in vector_db_ids
|
||||
assert vdb2.identifier in vector_db_ids
|
||||
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
||||
# Verify they have UUID-based identifiers
|
||||
assert vdb1.identifier.startswith("vs_")
|
||||
assert vdb2.identifier.startswith("vs_")
|
||||
|
||||
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
|
||||
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
@ -312,6 +337,13 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
|
|||
assert "test-scoring-fn" in scoring_fn_ids
|
||||
assert "test-scoring-fn-2" in scoring_fn_ids
|
||||
|
||||
# Unregister scoring functions and verify listing
|
||||
for i in range(len(scoring_functions.data)):
|
||||
await table.unregister_scoring_function(scoring_functions.data[i].scoring_fn_id)
|
||||
|
||||
scoring_functions_list_after_deletion = await table.list_scoring_functions()
|
||||
assert len(scoring_functions_list_after_deletion.data) == 0
|
||||
|
||||
|
||||
async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
||||
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
|
||||
|
@ -329,6 +361,15 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
|
|||
benchmark_ids = {b.identifier for b in benchmarks.data}
|
||||
assert "test-benchmark" in benchmark_ids
|
||||
|
||||
# Unregister the benchmark and verify removal
|
||||
await table.unregister_benchmark(benchmark_id="test-benchmark")
|
||||
benchmarks_after = await table.list_benchmarks()
|
||||
assert len(benchmarks_after.data) == 0
|
||||
|
||||
# Unregistering a non-existent benchmark should raise a clear error
|
||||
with pytest.raises(ValueError, match="Benchmark 'dummy_benchmark' not found"):
|
||||
await table.unregister_benchmark(benchmark_id="dummy_benchmark")
|
||||
|
||||
|
||||
async def test_tool_groups_routing_table(cached_disk_dist_registry):
|
||||
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
|
||||
|
@ -605,3 +646,25 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
|
|||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
|
||||
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
|
||||
|
||||
exception_throwing_tool_groups_impl = ToolGroupsImpl()
|
||||
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
|
||||
|
||||
table = ToolGroupsRoutingTable(
|
||||
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
await table.register_tool_group(
|
||||
toolgroup_id="test-toolgroup-exceptions",
|
||||
provider_id="test_provider",
|
||||
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
|
||||
)
|
||||
|
||||
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
|
||||
|
||||
assert len(tools.data) == 0
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
# Unit tests for the routing tables vector_dbs
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
|
|||
class VectorDBImpl(Impl):
|
||||
def __init__(self):
|
||||
super().__init__(Api.vector_io)
|
||||
self.vector_stores = {}
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB):
|
||||
return vector_db
|
||||
|
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
|
|||
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name=None,
|
||||
embedding_model=None,
|
||||
embedding_dimension=None,
|
||||
provider_id=None,
|
||||
provider_vector_db_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||
vector_store = VectorStoreObject(
|
||||
id=vector_store_id,
|
||||
name=name or vector_store_id,
|
||||
created_at=int(time.time()),
|
||||
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||
)
|
||||
self.vector_stores[vector_store_id] = vector_store
|
||||
return vector_store
|
||||
|
||||
async def openai_list_vector_stores(self, **kwargs):
|
||||
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
|
||||
|
||||
return VectorStoreListResponse(
|
||||
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
|
||||
)
|
||||
|
||||
|
||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||
n = 10
|
||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
|
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|||
)
|
||||
|
||||
# Register multiple vector databases and verify listing
|
||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
|
||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
|
||||
assert len(vector_dbs.data) == 2
|
||||
assert len(vector_dbs.data) == len(vdb_dict)
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
assert "test-vectordb" in vector_db_ids
|
||||
assert "test-vectordb-2" in vector_db_ids
|
||||
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
||||
for k in vdb_dict:
|
||||
assert vdb_dict[k].identifier in vector_db_ids
|
||||
for k in vdb_dict:
|
||||
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
|
||||
n = 10
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
vdb_dict = {}
|
||||
for i in range(n):
|
||||
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
vector_store_ids = {v.id for v in vector_stores.data}
|
||||
|
||||
assert vector_db_ids == vector_store_ids, (
|
||||
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
|
||||
)
|
||||
|
||||
for vector_store in vector_stores.data:
|
||||
vector_db = await table.get_vector_db(vector_store.id)
|
||||
assert vector_store.name == vector_db.vector_db_name, (
|
||||
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
|
||||
)
|
||||
|
||||
for vector_db_id in vector_db_ids:
|
||||
await table.unregister_vector_db(vector_db_id)
|
||||
|
||||
assert len((await table.list_vector_dbs()).data) == 0
|
||||
|
||||
|
||||
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await m_table.initialize()
|
||||
await m_table.register_model(
|
||||
model_id="test-model",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 128},
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
|
||||
user_provided_id = "my-custom-vector-db"
|
||||
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
|
||||
|
||||
vector_stores = await impl.openai_list_vector_stores()
|
||||
assert len(vector_stores.data) == 1
|
||||
|
||||
vector_store = vector_stores.data[0]
|
||||
|
||||
assert vector_store.name == user_provided_id
|
||||
|
||||
assert vector_store.id.startswith("vs_")
|
||||
assert vector_store.id != user_provided_id
|
||||
|
||||
vector_dbs = await table.list_vector_dbs()
|
||||
assert len(vector_dbs.data) == 1
|
||||
assert vector_dbs.data[0].identifier == vector_store.id
|
||||
|
||||
await table.unregister_vector_db(vector_store.id)
|
||||
|
||||
|
||||
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||
impl = VectorDBImpl()
|
||||
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||
|
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
|
|||
|
||||
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||
authorized_table = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
# Authorized reader
|
||||
with request_provider_data_context({}, authorized_user):
|
||||
|
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
|
|||
)
|
||||
|
||||
with request_provider_data_context({}, admin_user):
|
||||
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||
vector_db_id = registered_vdb.identifier # Use the actual generated ID
|
||||
|
||||
read_methods = [
|
||||
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
||||
|
|
|
@ -12,7 +12,7 @@ import yaml
|
|||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
|
@ -66,10 +66,9 @@ def base_config(tmp_path):
|
|||
def provider_spec_yaml():
|
||||
"""Common provider spec YAML for testing."""
|
||||
return """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
config_class: test_provider.config.TestProviderConfig
|
||||
module: test_provider
|
||||
adapter_type: test_provider
|
||||
config_class: test_provider.config.TestProviderConfig
|
||||
module: test_provider
|
||||
api_dependencies:
|
||||
- safety
|
||||
"""
|
||||
|
@ -152,6 +151,24 @@ class TestProviderRegistry:
|
|||
assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
|
||||
assert registry[Api.inference]["test_provider"].api == Api.inference
|
||||
|
||||
def test_internal_apis_excluded(self):
|
||||
"""Test that internal APIs are excluded and APIs without provider registries are marked as internal."""
|
||||
import importlib
|
||||
|
||||
apis = providable_apis()
|
||||
|
||||
for internal_api in INTERNAL_APIS:
|
||||
assert internal_api not in apis, f"Internal API {internal_api} should not be in providable_apis"
|
||||
|
||||
for api in apis:
|
||||
module_name = f"llama_stack.providers.registry.{api.name.lower()}"
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
except ImportError as err:
|
||||
raise AssertionError(
|
||||
f"API {api} is in providable_apis but has no provider registry module ({module_name})"
|
||||
) from err
|
||||
|
||||
def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
|
||||
"""Test loading external remote providers from YAML files."""
|
||||
remote_dir, _ = api_directories
|
||||
|
@ -164,9 +181,9 @@ class TestProviderRegistry:
|
|||
assert Api.inference in registry
|
||||
assert "remote::test_provider" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["remote::test_provider"]
|
||||
assert provider.adapter.adapter_type == "test_provider"
|
||||
assert provider.adapter.module == "test_provider"
|
||||
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert provider.adapter_type == "test_provider"
|
||||
assert provider.module == "test_provider"
|
||||
assert provider.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert Api.safety in provider.api_dependencies
|
||||
|
||||
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
|
||||
|
@ -228,8 +245,7 @@ class TestProviderRegistry:
|
|||
"""Test handling of malformed remote provider spec (missing required fields)."""
|
||||
remote_dir, _ = api_directories
|
||||
malformed_spec = """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
adapter_type: test_provider
|
||||
# Missing required fields
|
||||
api_dependencies:
|
||||
- safety
|
||||
|
@ -252,7 +268,7 @@ pip_packages:
|
|||
with open(inline_dir / "malformed.yaml", "w") as f:
|
||||
f.write(malformed_spec)
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
get_provider_registry(base_config)
|
||||
assert "config_class" in str(exc_info.value)
|
||||
|
||||
|
|
|
@ -6,16 +6,18 @@
|
|||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from openai import AsyncOpenAI
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
from openai.types.model import Model as OpenAIModel
|
||||
|
||||
# Import the real Pydantic response types instead of using Mocks
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChoice,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
|
@ -153,24 +155,22 @@ class TestInferenceRecording:
|
|||
|
||||
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that recording mode captures and stores responses."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_recording_mode"
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
# Verify the response was returned correctly
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
# Verify the response was returned correctly
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
client.chat.completions._post.assert_called_once()
|
||||
|
||||
# Verify recording was stored
|
||||
storage = ResponseStorage(temp_storage_dir)
|
||||
|
@ -178,40 +178,74 @@ class TestInferenceRecording:
|
|||
|
||||
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||
"""Test that replay mode returns stored responses without making real calls."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_openai_chat_response
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode"
|
||||
# First, record a response
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
client.chat.completions._post.assert_called_once()
|
||||
|
||||
# Now test replay mode - should not call the original method
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model="llama3.2:3b",
|
||||
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
# Verify we got the recorded response
|
||||
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
|
||||
|
||||
# Verify the original method was NOT called
|
||||
mock_create_patch.assert_not_called()
|
||||
# Verify the original method was NOT called
|
||||
client.chat.completions._post.assert_not_called()
|
||||
|
||||
async def test_replay_mode_models(self, temp_storage_dir):
|
||||
"""Test that replay mode returns stored responses without making real model listing calls."""
|
||||
|
||||
async def _async_iterator(models):
|
||||
for model in models:
|
||||
yield model
|
||||
|
||||
models = [
|
||||
OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
|
||||
OpenAIModel(id="bar", created=2, object="model", owned_by="test"),
|
||||
]
|
||||
|
||||
expected_ids = {m.id for m in models}
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_replay_mode_models"
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# record the call
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_called_once()
|
||||
|
||||
# replay the call
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.models._get_api_list = Mock(return_value=_async_iterator(models))
|
||||
assert {m.id async for m in client.models.list()} == expected_ids
|
||||
client.models._get_api_list.assert_not_called()
|
||||
|
||||
async def test_replay_missing_recording(self, temp_storage_dir):
|
||||
"""Test that replay mode fails when no recording is found."""
|
||||
|
@ -228,36 +262,110 @@ class TestInferenceRecording:
|
|||
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
|
||||
"""Test recording and replay of embeddings calls."""
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
return real_embeddings_response
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
encoding_format=NOT_GIVEN,
|
||||
)
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
client.embeddings._post.assert_called_once()
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
|
||||
# Record
|
||||
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model="nomic-embed-text", input=["Hello world", "Test embedding"]
|
||||
)
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
encoding_format=NOT_GIVEN,
|
||||
dimensions=NOT_GIVEN,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
assert len(response.data) == 2
|
||||
assert len(response.data) == 2
|
||||
|
||||
# Replay
|
||||
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
|
||||
|
||||
response = await client.embeddings.create(
|
||||
model="nomic-embed-text", input=["Hello world", "Test embedding"]
|
||||
)
|
||||
response = await client.embeddings.create(
|
||||
model=real_embeddings_response.model,
|
||||
input=["Hello world", "Test embedding"],
|
||||
)
|
||||
|
||||
# Verify we got the recorded response
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
# Verify we got the recorded response
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0].embedding == [0.1, 0.2, 0.3]
|
||||
|
||||
# Verify original method was not called
|
||||
mock_create_patch.assert_not_called()
|
||||
# Verify original method was not called
|
||||
client.embeddings._post.assert_not_called()
|
||||
|
||||
async def test_completions_recording(self, temp_storage_dir):
|
||||
real_completions_response = OpenAICompletion(
|
||||
id="test_completion",
|
||||
object="text_completion",
|
||||
created=1234567890,
|
||||
model="llama3.2:3b",
|
||||
choices=[
|
||||
{
|
||||
"text": "Hello! I'm doing well, thank you for asking.",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
temp_storage_dir = temp_storage_dir / "test_completions_recording"
|
||||
|
||||
# baseline - mock works without recording
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_called_once()
|
||||
|
||||
# Record
|
||||
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
user=NOT_GIVEN,
|
||||
)
|
||||
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_called_once()
|
||||
|
||||
# Replay
|
||||
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
client.completions._post = AsyncMock(return_value=real_completions_response)
|
||||
response = await client.completions.create(
|
||||
model=real_completions_response.model,
|
||||
prompt="Hello, how are you?",
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
)
|
||||
assert response.choices[0].text == real_completions_response.choices[0].text
|
||||
client.completions._post.assert_not_called()
|
||||
|
||||
async def test_live_mode(self, real_openai_chat_response):
|
||||
"""Test that live mode passes through to original methods."""
|
||||
|
@ -266,7 +374,7 @@ class TestInferenceRecording:
|
|||
return real_openai_chat_response
|
||||
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with inference_recording(mode=InferenceMode.LIVE):
|
||||
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
|
|
|
@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
|
|||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
class MockStack:
|
||||
def __init__(self, config, custom_provider_registry=None):
|
||||
self.impls = mock_impls
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
|
5
tests/unit/prompts/prompts/__init__.py
Normal file
5
tests/unit/prompts/prompts/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
30
tests/unit/prompts/prompts/conftest.py
Normal file
30
tests/unit/prompts/prompts/conftest.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_prompt_store(tmp_path_factory):
|
||||
unique_id = f"prompt_store_{random.randint(1, 1000000)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
mock_run_config = StackRunConfig(image_name="test-distribution", apis=[], providers={})
|
||||
config = PromptServiceConfig(run_config=mock_run_config)
|
||||
store = PromptServiceImpl(config, deps={})
|
||||
|
||||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||
|
||||
yield store
|
144
tests/unit/prompts/prompts/test_prompts.py
Normal file
144
tests/unit/prompts/prompts/test_prompts.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPrompts:
|
||||
async def test_create_and_get_prompt(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("Hello world!", ["name"])
|
||||
assert prompt.prompt == "Hello world!"
|
||||
assert prompt.version == 1
|
||||
assert prompt.prompt_id.startswith("pmpt_")
|
||||
assert prompt.variables == ["name"]
|
||||
|
||||
retrieved = await temp_prompt_store.get_prompt(prompt.prompt_id)
|
||||
assert retrieved.prompt_id == prompt.prompt_id
|
||||
assert retrieved.prompt == prompt.prompt
|
||||
|
||||
async def test_update_prompt(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("Original")
|
||||
updated = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", 1, ["v"])
|
||||
assert updated.version == 2
|
||||
assert updated.prompt == "Updated"
|
||||
|
||||
async def test_update_prompt_with_version(self, temp_prompt_store):
|
||||
version_for_update = 1
|
||||
|
||||
prompt = await temp_prompt_store.create_prompt("Original")
|
||||
assert prompt.version == 1
|
||||
prompt = await temp_prompt_store.update_prompt(prompt.prompt_id, "Updated", version_for_update, ["v"])
|
||||
assert prompt.version == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# now this is a stale version
|
||||
await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", version_for_update, ["v"])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# this version does not exist
|
||||
await temp_prompt_store.update_prompt(prompt.prompt_id, "Another Update", 99, ["v"])
|
||||
|
||||
async def test_delete_prompt(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("to be deleted")
|
||||
await temp_prompt_store.delete_prompt(prompt.prompt_id)
|
||||
with pytest.raises(ValueError):
|
||||
await temp_prompt_store.get_prompt(prompt.prompt_id)
|
||||
|
||||
async def test_list_prompts(self, temp_prompt_store):
|
||||
response = await temp_prompt_store.list_prompts()
|
||||
assert response.data == []
|
||||
|
||||
await temp_prompt_store.create_prompt("first")
|
||||
await temp_prompt_store.create_prompt("second")
|
||||
|
||||
response = await temp_prompt_store.list_prompts()
|
||||
assert len(response.data) == 2
|
||||
|
||||
async def test_version(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("V1")
|
||||
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||
|
||||
v1 = await temp_prompt_store.get_prompt(prompt.prompt_id, version=1)
|
||||
assert v1.version == 1 and v1.prompt == "V1"
|
||||
|
||||
latest = await temp_prompt_store.get_prompt(prompt.prompt_id)
|
||||
assert latest.version == 2 and latest.prompt == "V2"
|
||||
|
||||
async def test_set_default_version(self, temp_prompt_store):
|
||||
prompt0 = await temp_prompt_store.create_prompt("V1")
|
||||
prompt1 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V2", 1)
|
||||
|
||||
assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 2
|
||||
prompt_default = await temp_prompt_store.set_default_version(prompt0.prompt_id, 1)
|
||||
assert (await temp_prompt_store.get_prompt(prompt0.prompt_id)).version == 1
|
||||
assert prompt_default.version == 1
|
||||
|
||||
prompt2 = await temp_prompt_store.update_prompt(prompt0.prompt_id, "V3", prompt1.version)
|
||||
assert prompt2.version == 3
|
||||
|
||||
async def test_prompt_id_generation_and_validation(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("Test")
|
||||
assert prompt.prompt_id.startswith("pmpt_")
|
||||
assert len(prompt.prompt_id) == 53
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await temp_prompt_store.get_prompt("invalid_id")
|
||||
|
||||
async def test_list_shows_default_versions(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("V1")
|
||||
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||
await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
|
||||
|
||||
response = await temp_prompt_store.list_prompts()
|
||||
listed_prompt = response.data[0]
|
||||
assert listed_prompt.version == 3 and listed_prompt.prompt == "V3"
|
||||
|
||||
await temp_prompt_store.set_default_version(prompt.prompt_id, 1)
|
||||
|
||||
response = await temp_prompt_store.list_prompts()
|
||||
listed_prompt = response.data[0]
|
||||
assert listed_prompt.version == 1 and listed_prompt.prompt == "V1"
|
||||
assert not (await temp_prompt_store.get_prompt(prompt.prompt_id, 3)).is_default
|
||||
|
||||
async def test_get_all_prompt_versions(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("V1")
|
||||
await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1)
|
||||
await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2)
|
||||
|
||||
versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
|
||||
assert len(versions) == 3
|
||||
assert [v.version for v in versions] == [1, 2, 3]
|
||||
assert [v.is_default for v in versions] == [False, False, True]
|
||||
|
||||
await temp_prompt_store.set_default_version(prompt.prompt_id, 2)
|
||||
versions = (await temp_prompt_store.list_prompt_versions(prompt.prompt_id)).data
|
||||
assert [v.is_default for v in versions] == [False, True, False]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await temp_prompt_store.list_prompt_versions("nonexistent")
|
||||
|
||||
async def test_prompt_variable_validation(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("Hello {{ name }}, you live in {{ city }}!", ["name", "city"])
|
||||
assert prompt.variables == ["name", "city"]
|
||||
|
||||
prompt_no_vars = await temp_prompt_store.create_prompt("Hello world!", [])
|
||||
assert prompt_no_vars.variables == []
|
||||
|
||||
with pytest.raises(ValueError, match="undeclared variables"):
|
||||
await temp_prompt_store.create_prompt("Hello {{ name }}, invalid {{ unknown }}!", ["name"])
|
||||
|
||||
async def test_update_prompt_set_as_default_behavior(self, temp_prompt_store):
|
||||
prompt = await temp_prompt_store.create_prompt("V1")
|
||||
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 1
|
||||
|
||||
prompt_v2 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V2", 1, [], set_as_default=True)
|
||||
assert prompt_v2.version == 2
|
||||
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2
|
||||
|
||||
prompt_v3 = await temp_prompt_store.update_prompt(prompt.prompt_id, "V3", 2, [], set_as_default=False)
|
||||
assert prompt_v3.version == 3
|
||||
assert (await temp_prompt_store.get_prompt(prompt.prompt_id)).version == 2
|
|
@ -16,9 +16,11 @@ from llama_stack.apis.agents import (
|
|||
)
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.tools import ListToolsResponse, Tool, ToolGroups, ToolParameter, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
|
||||
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
|
||||
|
@ -75,11 +77,11 @@ def sample_agent_config():
|
|||
},
|
||||
input_shields=["string"],
|
||||
output_shields=["string"],
|
||||
toolgroups=["string"],
|
||||
toolgroups=["mcp::my_mcp_server"],
|
||||
client_tools=[
|
||||
{
|
||||
"name": "string",
|
||||
"description": "string",
|
||||
"name": "client_tool",
|
||||
"description": "Client Tool",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "string",
|
||||
|
@ -226,3 +228,83 @@ async def test_delete_agent(agents_impl, sample_agent_config):
|
|||
# Verify the agent was deleted
|
||||
with pytest.raises(ValueError):
|
||||
await agents_impl.get_agent(agent_id)
|
||||
|
||||
|
||||
async def test__initialize_tools(agents_impl, sample_agent_config):
|
||||
# Mock tool_groups_api.list_tools()
|
||||
agents_impl.tool_groups_api.list_tools.return_value = ListToolsResponse(
|
||||
data=[
|
||||
Tool(
|
||||
identifier="story_maker",
|
||||
provider_id="model-context-protocol",
|
||||
type=ResourceType.tool,
|
||||
toolgroup_id="mcp::my_mcp_server",
|
||||
description="Make a story",
|
||||
parameters=[
|
||||
ToolParameter(
|
||||
name="story_title",
|
||||
parameter_type="string",
|
||||
description="Title of the story",
|
||||
required=True,
|
||||
title="Story Title",
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_words",
|
||||
parameter_type="array",
|
||||
description="Input words",
|
||||
required=False,
|
||||
items={"type": "string"},
|
||||
title="Input Words",
|
||||
default=[],
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(sample_agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Get an instance of ChatAgent
|
||||
chat_agent = await agents_impl._get_agent_impl(agent_id)
|
||||
assert chat_agent is not None
|
||||
assert isinstance(chat_agent, ChatAgent)
|
||||
|
||||
# Initialize tool definitions
|
||||
await chat_agent._initialize_tools()
|
||||
assert len(chat_agent.tool_defs) == 2
|
||||
|
||||
# Verify the first tool, which is a client tool
|
||||
first_tool = chat_agent.tool_defs[0]
|
||||
assert first_tool.tool_name == "client_tool"
|
||||
assert first_tool.description == "Client Tool"
|
||||
|
||||
# Verify the second tool, which is an MCP tool that has an array-type property
|
||||
second_tool = chat_agent.tool_defs[1]
|
||||
assert second_tool.tool_name == "story_maker"
|
||||
assert second_tool.description == "Make a story"
|
||||
|
||||
parameters = second_tool.parameters
|
||||
assert len(parameters) == 2
|
||||
|
||||
# Verify a string property
|
||||
story_title = parameters.get("story_title")
|
||||
assert story_title is not None
|
||||
assert story_title.param_type == "string"
|
||||
assert story_title.description == "Title of the story"
|
||||
assert story_title.required
|
||||
assert story_title.items is None
|
||||
assert story_title.title == "Story Title"
|
||||
assert story_title.default is None
|
||||
|
||||
# Verify an array property
|
||||
input_words = parameters.get("input_words")
|
||||
assert input_words is not None
|
||||
assert input_words.param_type == "array"
|
||||
assert input_words.description == "Input words"
|
||||
assert not input_words.required
|
||||
assert input_words.items is not None
|
||||
assert len(input_words.items) == 1
|
||||
assert input_words.items.get("type") == "string"
|
||||
assert input_words.title == "Input Words"
|
||||
assert input_words.default == []
|
||||
|
|
|
@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
|
|||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
|
||||
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
|
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
|
|||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
|
|||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
async def test_validate_input_missing_parameters_chat_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
|
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
|
|||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for text completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {"model": "test-model", "prompt": "Hello"},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
|
|
63
tests/unit/providers/inference/bedrock/test_config.py
Normal file
63
tests/unit/providers/inference/bedrock/test_config.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
class TestBedrockBaseConfig:
|
||||
def test_defaults_work_without_env_vars(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
config = BedrockBaseConfig()
|
||||
|
||||
# Basic creds should be None
|
||||
assert config.aws_access_key_id is None
|
||||
assert config.aws_secret_access_key is None
|
||||
assert config.region_name is None
|
||||
|
||||
# Timeouts get defaults
|
||||
assert config.connect_timeout == 60.0
|
||||
assert config.read_timeout == 60.0
|
||||
assert config.session_ttl == 3600
|
||||
|
||||
def test_env_vars_get_picked_up(self):
|
||||
env_vars = {
|
||||
"AWS_ACCESS_KEY_ID": "AKIATEST123",
|
||||
"AWS_SECRET_ACCESS_KEY": "secret123",
|
||||
"AWS_DEFAULT_REGION": "us-west-2",
|
||||
"AWS_MAX_ATTEMPTS": "5",
|
||||
"AWS_RETRY_MODE": "adaptive",
|
||||
"AWS_CONNECT_TIMEOUT": "30",
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, env_vars, clear=True):
|
||||
config = BedrockBaseConfig()
|
||||
|
||||
assert config.aws_access_key_id == "AKIATEST123"
|
||||
assert config.aws_secret_access_key == "secret123"
|
||||
assert config.region_name == "us-west-2"
|
||||
assert config.total_max_attempts == 5
|
||||
assert config.retry_mode == "adaptive"
|
||||
assert config.connect_timeout == 30.0
|
||||
|
||||
def test_partial_env_setup(self):
|
||||
# Just setting one timeout var
|
||||
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
|
||||
config = BedrockBaseConfig()
|
||||
|
||||
assert config.connect_timeout == 120.0
|
||||
assert config.read_timeout == 60.0 # still default
|
||||
assert config.aws_access_key_id is None
|
||||
|
||||
def test_bad_max_attempts_breaks(self):
|
||||
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
|
||||
try:
|
||||
BedrockBaseConfig()
|
||||
raise AssertionError("Should have failed on bad int conversion")
|
||||
except ValueError:
|
||||
pass # expected
|
|
@ -33,8 +33,7 @@ def test_groq_provider_openai_client_caching():
|
|||
with request_provider_data_context(
|
||||
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
|
||||
):
|
||||
openai_client = inference_adapter._get_openai_client()
|
||||
assert openai_client.api_key == api_key
|
||||
assert inference_adapter.client.api_key == api_key
|
||||
|
||||
|
||||
def test_openai_provider_openai_client_caching():
|
||||
|
|
|
@ -26,7 +26,6 @@ class TestProviderDataValidator(BaseModel):
|
|||
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: TestConfig):
|
||||
super().__init__(
|
||||
model_entries=[],
|
||||
litellm_provider_name="test",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="test_api_key",
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
|
||||
|
@ -80,11 +80,22 @@ class TestOpenAIBaseURLConfig:
|
|||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
||||
# Mock the AsyncOpenAI client and its models.retrieve method
|
||||
# Mock a model object that will be returned by models.list()
|
||||
mock_model = MagicMock()
|
||||
mock_model.id = "gpt-4"
|
||||
|
||||
# Create an async iterator that yields our mock model
|
||||
async def mock_async_iterator():
|
||||
yield mock_model
|
||||
|
||||
# Mock the AsyncOpenAI client and its models.list method
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
|
||||
mock_client.models.list = MagicMock(return_value=mock_async_iterator())
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
# Set the __provider_id__ attribute that's expected by list_models
|
||||
adapter.__provider_id__ = "openai"
|
||||
|
||||
# Call check_model_availability and verify it returns True
|
||||
assert await adapter.check_model_availability("gpt-4")
|
||||
|
||||
|
@ -94,8 +105,8 @@ class TestOpenAIBaseURLConfig:
|
|||
base_url=custom_url,
|
||||
)
|
||||
|
||||
# Verify the method was called and returned True
|
||||
mock_client.models.retrieve.assert_called_once_with("gpt-4")
|
||||
# Verify the models.list method was called
|
||||
mock_client.models.list.assert_called_once()
|
||||
|
||||
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
|
||||
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
|
||||
|
@ -110,11 +121,22 @@ class TestOpenAIBaseURLConfig:
|
|||
# Mock the get_api_key method
|
||||
adapter.get_api_key = MagicMock(return_value="test-key")
|
||||
|
||||
# Mock the AsyncOpenAI client
|
||||
# Mock a model object that will be returned by models.list()
|
||||
mock_model = MagicMock()
|
||||
mock_model.id = "gpt-4"
|
||||
|
||||
# Create an async iterator that yields our mock model
|
||||
async def mock_async_iterator():
|
||||
yield mock_model
|
||||
|
||||
# Mock the AsyncOpenAI client and its models.list method
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
|
||||
mock_client.models.list = MagicMock(return_value=mock_async_iterator())
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
# Set the __provider_id__ attribute that's expected by list_models
|
||||
adapter.__provider_id__ = "openai"
|
||||
|
||||
# Call check_model_availability and verify it returns True
|
||||
assert await adapter.check_model_availability("gpt-4")
|
||||
|
||||
|
|
|
@ -6,19 +6,15 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import logging # allow-direct-logging
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
Choice as OpenAIChoice,
|
||||
Choice as OpenAIChoiceChunk,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDelta as OpenAIChoiceDelta,
|
||||
|
@ -35,6 +31,9 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionRequest,
|
||||
ChatCompletionResponseEventType,
|
||||
CompletionMessage,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChoice,
|
||||
SystemMessage,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
|
@ -61,52 +60,21 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
|
|||
# -v -s --tb=short --disable-warnings
|
||||
|
||||
|
||||
class MockInferenceAdapterWithSleep:
|
||||
def __init__(self, sleep_time: int, response: dict[str, Any]):
|
||||
self.httpd = None
|
||||
|
||||
class DelayedRequestHandler(BaseHTTPRequestHandler):
|
||||
# ruff: noqa: N802
|
||||
def do_POST(self):
|
||||
time.sleep(sleep_time)
|
||||
response_body = json.dumps(response).encode("utf-8")
|
||||
self.send_response(code=200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.send_header("Content-Length", len(response_body))
|
||||
self.end_headers()
|
||||
self.wfile.write(response_body)
|
||||
|
||||
self.request_handler = DelayedRequestHandler
|
||||
|
||||
def __enter__(self):
|
||||
httpd = HTTPServer(("", 0), self.request_handler)
|
||||
self.httpd = httpd
|
||||
host, port = httpd.server_address
|
||||
httpd_thread = threading.Thread(target=httpd.serve_forever)
|
||||
httpd_thread.daemon = True # stop server if this thread terminates
|
||||
httpd_thread.start()
|
||||
|
||||
config = VLLMInferenceAdapterConfig(url=f"http://{host}:{port}")
|
||||
inference_adapter = VLLMInferenceAdapter(config)
|
||||
return inference_adapter
|
||||
|
||||
def __exit__(self, _exc_type, _exc_value, _traceback):
|
||||
if self.httpd:
|
||||
self.httpd.shutdown()
|
||||
self.httpd.server_close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_openai_models_list():
|
||||
with patch("openai.resources.models.AsyncModels.list", new_callable=AsyncMock) as mock_list:
|
||||
with patch("openai.resources.models.AsyncModels.list") as mock_list:
|
||||
yield mock_list
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope="function")
|
||||
async def vllm_inference_adapter():
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
inference_adapter = VLLMInferenceAdapter(config)
|
||||
inference_adapter.model_store = AsyncMock()
|
||||
# Mock the __provider_spec__ attribute that would normally be set by the resolver
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_type = "vllm-inference"
|
||||
inference_adapter.__provider_spec__.provider_data_validator = MagicMock()
|
||||
await inference_adapter.initialize()
|
||||
return inference_adapter
|
||||
|
||||
|
@ -150,10 +118,16 @@ async def test_tool_call_response(vllm_inference_adapter):
|
|||
"""Verify that tool call arguments from a CompletionMessage are correctly converted
|
||||
into the expected JSON format."""
|
||||
|
||||
# Patch the call to vllm so we can inspect the arguments sent were correct
|
||||
with patch.object(
|
||||
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_nonstream_completion:
|
||||
# Patch the client property to avoid instantiating a real AsyncOpenAI client
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock()
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Mock the model to return a proper provider_resource_id
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
|
||||
vllm_inference_adapter.model_store.get_model.return_value = mock_model
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant"),
|
||||
UserMessage(content="How many?"),
|
||||
|
@ -179,7 +153,7 @@ async def test_tool_call_response(vllm_inference_adapter):
|
|||
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
|
||||
)
|
||||
|
||||
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
|
||||
assert mock_client.chat.completions.create.call_args.kwargs["messages"][2]["tool_calls"] == [
|
||||
{
|
||||
"id": "foo",
|
||||
"type": "function",
|
||||
|
@ -199,7 +173,7 @@ async def test_tool_call_delta_empty_tool_call_buf():
|
|||
|
||||
async def mock_stream():
|
||||
delta = OpenAIChoiceDelta(content="", tool_calls=None)
|
||||
choices = [OpenAIChoice(delta=delta, finish_reason="stop", index=0)]
|
||||
choices = [OpenAIChoiceChunk(delta=delta, finish_reason="stop", index=0)]
|
||||
mock_chunk = OpenAIChatCompletionChunk(
|
||||
id="chunk-1",
|
||||
created=1,
|
||||
|
@ -225,7 +199,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
|
|||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
|
@ -250,7 +224,7 @@ async def test_tool_call_delta_streaming_arguments_dict():
|
|||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
|
@ -275,7 +249,9 @@ async def test_tool_call_delta_streaming_arguments_dict():
|
|||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
|
||||
)
|
||||
],
|
||||
)
|
||||
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
|
||||
|
@ -299,7 +275,7 @@ async def test_multiple_tool_calls():
|
|||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
|
@ -324,7 +300,7 @@ async def test_multiple_tool_calls():
|
|||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(
|
||||
content="",
|
||||
tool_calls=[
|
||||
|
@ -349,7 +325,9 @@ async def test_multiple_tool_calls():
|
|||
model="foo",
|
||||
object="chat.completion.chunk",
|
||||
choices=[
|
||||
OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
|
||||
OpenAIChoiceChunk(
|
||||
delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0
|
||||
)
|
||||
],
|
||||
)
|
||||
for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
|
||||
|
@ -393,59 +371,6 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
|
|||
assert chunks[0].event.event_type.value == "start"
|
||||
|
||||
|
||||
@pytest.mark.allow_network
|
||||
def test_chat_completion_doesnt_block_event_loop(caplog):
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.set_debug(True)
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
# Log when event loop is blocked for more than 200ms
|
||||
loop.slow_callback_duration = 0.5
|
||||
# Sleep for 500ms in our delayed http response
|
||||
sleep_time = 0.5
|
||||
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
|
||||
mock_response = {
|
||||
"id": "chatcmpl-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1,
|
||||
"modle": "mock-model",
|
||||
"choices": [
|
||||
{
|
||||
"message": {"content": ""},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
async def do_chat_completion():
|
||||
await inference_adapter.chat_completion(
|
||||
"mock-model",
|
||||
[],
|
||||
stream=False,
|
||||
tools=None,
|
||||
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
|
||||
)
|
||||
|
||||
with MockInferenceAdapterWithSleep(sleep_time, mock_response) as inference_adapter:
|
||||
inference_adapter.model_store = AsyncMock()
|
||||
inference_adapter.model_store.get_model.return_value = mock_model
|
||||
loop.run_until_complete(inference_adapter.initialize())
|
||||
|
||||
# Clear the logs so far and run the actual chat completion we care about
|
||||
caplog.clear()
|
||||
loop.run_until_complete(do_chat_completion())
|
||||
|
||||
# Ensure we don't have any asyncio warnings in the captured log
|
||||
# records from our chat completion call. A message gets logged
|
||||
# here any time we exceed the slow_callback_duration configured
|
||||
# above.
|
||||
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
|
||||
assert not asyncio_warnings
|
||||
|
||||
|
||||
async def test_get_params_empty_tools(vllm_inference_adapter):
|
||||
request = ChatCompletionRequest(
|
||||
tools=[],
|
||||
|
@ -638,33 +563,29 @@ async def test_health_status_success(vllm_inference_adapter):
|
|||
"""
|
||||
Test the health method of VLLM InferenceAdapter when the connection is successful.
|
||||
|
||||
This test verifies that the health method returns a HealthResponse with status OK, only
|
||||
when the connection to the vLLM server is successful.
|
||||
This test verifies that the health method returns a HealthResponse with status OK
|
||||
when the /health endpoint responds successfully.
|
||||
"""
|
||||
# Set vllm_inference_adapter.client to None to ensure _create_client is called
|
||||
vllm_inference_adapter.client = None
|
||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||
# Create mock client and models
|
||||
mock_client = MagicMock()
|
||||
mock_models = MagicMock()
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
# Create mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
# Create a mock async iterator that yields a model when iterated
|
||||
async def mock_list():
|
||||
for model in [MagicMock()]:
|
||||
yield model
|
||||
|
||||
# Set up the models.list to return our mock async iterator
|
||||
mock_models.list.return_value = mock_list()
|
||||
mock_client.models = mock_models
|
||||
mock_create_client.return_value = mock_client
|
||||
# Create mock client instance
|
||||
mock_client_instance = MagicMock()
|
||||
mock_client_instance.get = AsyncMock(return_value=mock_response)
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
# Call the health method
|
||||
health_response = await vllm_inference_adapter.health()
|
||||
|
||||
# Verify the response
|
||||
assert health_response["status"] == HealthStatus.OK
|
||||
|
||||
# Verify that models.list was called
|
||||
mock_models.list.assert_called_once()
|
||||
# Verify that the health endpoint was called
|
||||
mock_client_instance.get.assert_called_once()
|
||||
call_args = mock_client_instance.get.call_args[0]
|
||||
assert call_args[0].endswith("/health")
|
||||
|
||||
|
||||
async def test_health_status_failure(vllm_inference_adapter):
|
||||
|
@ -674,26 +595,190 @@ async def test_health_status_failure(vllm_inference_adapter):
|
|||
This test verifies that the health method returns a HealthResponse with status ERROR
|
||||
and an appropriate error message when the connection to the vLLM server fails.
|
||||
"""
|
||||
vllm_inference_adapter.client = None
|
||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||
# Create mock client and models
|
||||
mock_client = MagicMock()
|
||||
mock_models = MagicMock()
|
||||
|
||||
# Create a mock async iterator that raises an exception when iterated
|
||||
async def mock_list():
|
||||
raise Exception("Connection failed")
|
||||
yield # Unreachable code
|
||||
|
||||
# Set up the models.list to return our mock async iterator
|
||||
mock_models.list.return_value = mock_list()
|
||||
mock_client.models = mock_models
|
||||
mock_create_client.return_value = mock_client
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
# Create mock client instance that raises an exception
|
||||
mock_client_instance = MagicMock()
|
||||
mock_client_instance.get.side_effect = Exception("Connection failed")
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
# Call the health method
|
||||
health_response = await vllm_inference_adapter.health()
|
||||
|
||||
# Verify the response
|
||||
assert health_response["status"] == HealthStatus.ERROR
|
||||
assert "Health check failed: Connection failed" in health_response["message"]
|
||||
|
||||
mock_models.list.assert_called_once()
|
||||
|
||||
async def test_health_status_no_static_api_key(vllm_inference_adapter):
|
||||
"""
|
||||
Test the health method of VLLM InferenceAdapter when no static API key is provided.
|
||||
|
||||
This test verifies that the health method returns a HealthResponse with status OK
|
||||
when the /health endpoint responds successfully, regardless of API token configuration.
|
||||
"""
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
# Create mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
# Create mock client instance
|
||||
mock_client_instance = MagicMock()
|
||||
mock_client_instance.get = AsyncMock(return_value=mock_response)
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
# Call the health method
|
||||
health_response = await vllm_inference_adapter.health()
|
||||
|
||||
# Verify the response
|
||||
assert health_response["status"] == HealthStatus.OK
|
||||
|
||||
|
||||
async def test_openai_chat_completion_is_async(vllm_inference_adapter):
|
||||
"""
|
||||
Verify that openai_chat_completion is async and doesn't block the event loop.
|
||||
|
||||
To do this we mock the underlying inference with a sleep, start multiple
|
||||
inference calls in parallel, and ensure the total time taken is less
|
||||
than the sum of the individual sleep times.
|
||||
"""
|
||||
sleep_time = 0.5
|
||||
|
||||
async def mock_create(*args, **kwargs):
|
||||
await asyncio.sleep(sleep_time)
|
||||
return OpenAIChatCompletion(
|
||||
id="chatcmpl-abc123",
|
||||
created=1,
|
||||
model="mock-model",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content="nothing interesting",
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def do_inference():
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
"mock-model", messages=["one fish", "two fish"], stream=False
|
||||
)
|
||||
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(side_effect=mock_create)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
start_time = time.time()
|
||||
await asyncio.gather(do_inference(), do_inference(), do_inference(), do_inference())
|
||||
total_time = time.time() - start_time
|
||||
|
||||
assert mock_create_client.call_count == 4 # no cheating
|
||||
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
|
||||
|
||||
|
||||
async def test_should_refresh_models():
|
||||
"""
|
||||
Test the should_refresh_models method with different refresh_models configurations.
|
||||
|
||||
This test verifies that:
|
||||
1. When refresh_models is True, should_refresh_models returns True regardless of api_token
|
||||
2. When refresh_models is False, should_refresh_models returns False regardless of api_token
|
||||
"""
|
||||
|
||||
# Test case 1: refresh_models is True, api_token is None
|
||||
config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
|
||||
adapter1 = VLLMInferenceAdapter(config1)
|
||||
result1 = await adapter1.should_refresh_models()
|
||||
assert result1 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
|
||||
# Test case 2: refresh_models is True, api_token is empty string
|
||||
config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
|
||||
adapter2 = VLLMInferenceAdapter(config2)
|
||||
result2 = await adapter2.should_refresh_models()
|
||||
assert result2 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
|
||||
# Test case 3: refresh_models is True, api_token is "fake" (default)
|
||||
config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
|
||||
adapter3 = VLLMInferenceAdapter(config3)
|
||||
result3 = await adapter3.should_refresh_models()
|
||||
assert result3 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
|
||||
# Test case 4: refresh_models is True, api_token is real token
|
||||
config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
|
||||
adapter4 = VLLMInferenceAdapter(config4)
|
||||
result4 = await adapter4.should_refresh_models()
|
||||
assert result4 is True, "should_refresh_models should return True when refresh_models is True"
|
||||
|
||||
# Test case 5: refresh_models is False, api_token is real token
|
||||
config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
|
||||
adapter5 = VLLMInferenceAdapter(config5)
|
||||
result5 = await adapter5.should_refresh_models()
|
||||
assert result5 is False, "should_refresh_models should return False when refresh_models is False"
|
||||
|
||||
|
||||
async def test_provider_data_var_context_propagation(vllm_inference_adapter):
|
||||
"""
|
||||
Test that PROVIDER_DATA_VAR context is properly propagated through the vLLM inference adapter.
|
||||
This ensures that dynamic provider data (like API tokens) can be passed through context.
|
||||
Note: The base URL is always taken from config.url, not from provider data.
|
||||
"""
|
||||
# Mock the AsyncOpenAI class to capture provider data
|
||||
with (
|
||||
patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class,
|
||||
patch.object(vllm_inference_adapter, "get_request_provider_data") as mock_get_provider_data,
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.chat.completions.create = AsyncMock()
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
# Mock provider data to return test data
|
||||
mock_provider_data = MagicMock()
|
||||
mock_provider_data.vllm_api_token = "test-token-123"
|
||||
mock_provider_data.vllm_url = "http://test-server:8000/v1"
|
||||
mock_get_provider_data.return_value = mock_provider_data
|
||||
|
||||
# Mock the model
|
||||
mock_model = Model(identifier="test-model", provider_resource_id="test-model", provider_id="vllm-inference")
|
||||
vllm_inference_adapter.model_store.get_model.return_value = mock_model
|
||||
|
||||
try:
|
||||
# Execute chat completion
|
||||
await vllm_inference_adapter.chat_completion(
|
||||
"test-model",
|
||||
[UserMessage(content="Hello")],
|
||||
stream=False,
|
||||
tools=None,
|
||||
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
|
||||
)
|
||||
|
||||
# Verify that ALL client calls were made with the correct parameters
|
||||
calls = mock_openai_class.call_args_list
|
||||
incorrect_calls = []
|
||||
|
||||
for i, call in enumerate(calls):
|
||||
api_key = call[1]["api_key"]
|
||||
base_url = call[1]["base_url"]
|
||||
|
||||
if api_key != "test-token-123" or base_url != "http://mocked.localhost:12345":
|
||||
incorrect_calls.append({"call_index": i, "api_key": api_key, "base_url": base_url})
|
||||
|
||||
if incorrect_calls:
|
||||
error_msg = (
|
||||
f"Found {len(incorrect_calls)} calls with incorrect parameters out of {len(calls)} total calls:\n"
|
||||
)
|
||||
for incorrect_call in incorrect_calls:
|
||||
error_msg += f" Call {incorrect_call['call_index']}: api_key='{incorrect_call['api_key']}', base_url='{incorrect_call['base_url']}'\n"
|
||||
error_msg += "Expected: api_key='test-token-123', base_url='http://mocked.localhost:12345'"
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
# Ensure at least one call was made
|
||||
assert len(calls) >= 1, "No AsyncOpenAI client calls were made"
|
||||
|
||||
# Verify that chat completion was called
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
finally:
|
||||
# Clean up context
|
||||
pass
|
||||
|
|
|
@ -52,14 +52,19 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
|
|||
self.evaluator_post_patcher = patch(
|
||||
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
|
||||
)
|
||||
self.evaluator_delete_patcher = patch(
|
||||
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_delete"
|
||||
)
|
||||
|
||||
self.mock_evaluator_get = self.evaluator_get_patcher.start()
|
||||
self.mock_evaluator_post = self.evaluator_post_patcher.start()
|
||||
self.mock_evaluator_delete = self.evaluator_delete_patcher.start()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.evaluator_get_patcher.stop()
|
||||
self.evaluator_post_patcher.stop()
|
||||
self.evaluator_delete_patcher.stop()
|
||||
|
||||
def _assert_request_body(self, expected_json):
|
||||
"""Helper method to verify request body in Evaluator POST request is correct"""
|
||||
|
@ -115,6 +120,13 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
|
|||
self.mock_evaluator_post.assert_called_once()
|
||||
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
|
||||
|
||||
def test_unregister_benchmark(self):
|
||||
# Unregister the benchmark
|
||||
self.run_async(self.eval_impl.unregister_benchmark(benchmark_id=MOCK_BENCHMARK_ID))
|
||||
|
||||
# Verify the Evaluator API was called correctly
|
||||
self.mock_evaluator_delete.assert_called_once_with(f"/v1/evaluation/configs/nvidia/{MOCK_BENCHMARK_ID}")
|
||||
|
||||
def test_run_eval(self):
|
||||
benchmark_config = BenchmarkConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
|
@ -138,7 +150,7 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
|
|||
self._assert_request_body(
|
||||
{
|
||||
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
|
||||
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
|
||||
"target": {"type": "model", "model": "Llama3.1-8B-Instruct"},
|
||||
}
|
||||
)
|
||||
|
||||
|
|
53
tests/unit/providers/test_bedrock.py
Normal file
53
tests/unit/providers/test_bedrock.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.bedrock.bedrock import (
|
||||
_get_region_prefix,
|
||||
_to_inference_profile_id,
|
||||
)
|
||||
|
||||
|
||||
def test_region_prefixes():
|
||||
assert _get_region_prefix("us-east-1") == "us."
|
||||
assert _get_region_prefix("eu-west-1") == "eu."
|
||||
assert _get_region_prefix("ap-south-1") == "ap."
|
||||
assert _get_region_prefix("ca-central-1") == "us."
|
||||
|
||||
# Test case insensitive
|
||||
assert _get_region_prefix("US-EAST-1") == "us."
|
||||
assert _get_region_prefix("EU-WEST-1") == "eu."
|
||||
assert _get_region_prefix("Ap-South-1") == "ap."
|
||||
|
||||
# Test None region
|
||||
assert _get_region_prefix(None) == "us."
|
||||
|
||||
|
||||
def test_model_id_conversion():
|
||||
# Basic conversion
|
||||
assert (
|
||||
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0"
|
||||
)
|
||||
|
||||
# Already has prefix
|
||||
assert (
|
||||
_to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1")
|
||||
== "us.meta.llama3-1-70b-instruct-v1:0"
|
||||
)
|
||||
|
||||
# ARN should be returned unchanged
|
||||
arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0"
|
||||
assert _to_inference_profile_id(arn, "us-east-1") == arn
|
||||
|
||||
# ARN should be returned unchanged even without region
|
||||
assert _to_inference_profile_id(arn) == arn
|
||||
|
||||
# Optional region parameter defaults to us-east-1
|
||||
assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0"
|
||||
|
||||
# Different regions work with optional parameter
|
||||
assert (
|
||||
_to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0"
|
||||
)
|
368
tests/unit/providers/utils/inference/test_openai_mixin.py
Normal file
368
tests/unit/providers/utils/inference/test_openai_mixin.py
Normal file
|
@ -0,0 +1,368 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
|
||||
def __init__(self):
|
||||
self.__provider_id__ = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
||||
"""Test implementation with embedding model metadata"""
|
||||
|
||||
embedding_model_metadata = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
||||
__provider_id__ = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
mixin_instance = OpenAIMixinImpl()
|
||||
|
||||
# just enough to satisfy _get_provider_model_id calls
|
||||
mock_model_store = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||
mixin_instance.model_store = mock_model_store
|
||||
|
||||
return mixin_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_embeddings():
|
||||
"""Create a test instance of OpenAIMixin with embedding model metadata"""
|
||||
return OpenAIMixinWithEmbeddingsImpl()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models():
|
||||
"""Create multiple mock OpenAI model objects"""
|
||||
models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]]
|
||||
return models
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_models(mock_models):
|
||||
"""Create a mock client with models.list() set up to return mock_models"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_empty_models():
|
||||
"""Create a mock client with models.list() set up to return empty list"""
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_empty_models_list():
|
||||
return
|
||||
yield # Make it an async generator but don't yield anything
|
||||
|
||||
mock_client.models.list.return_value = mock_empty_models_list()
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_with_exception():
|
||||
"""Create a mock client with models.list() set up to raise an exception"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.models.list.side_effect = Exception("API Error")
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_context():
|
||||
"""Fixture that provides a context manager for mocking the OpenAI client"""
|
||||
|
||||
def _mock_client_context(mixin, mock_client):
|
||||
return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client)
|
||||
|
||||
return _mock_client_context
|
||||
|
||||
|
||||
class TestOpenAIMixinListModels:
|
||||
"""Test cases for the list_models method"""
|
||||
|
||||
async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test successful model listing"""
|
||||
assert len(mixin._model_cache) == 0
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
|
||||
model_ids = [model.identifier for model in result]
|
||||
assert "some-mock-model-id" in model_ids
|
||||
assert "another-mock-model-id" in model_ids
|
||||
assert "final-mock-model-id" in model_ids
|
||||
|
||||
for model in result:
|
||||
assert model.provider_id == "test-provider"
|
||||
assert model.model_type == ModelType.llm
|
||||
assert model.provider_resource_id == model.identifier
|
||||
|
||||
assert len(mixin._model_cache) == 3
|
||||
for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]:
|
||||
assert model_id in mixin._model_cache
|
||||
cached_model = mixin._model_cache[model_id]
|
||||
assert cached_model.identifier == model_id
|
||||
assert cached_model.provider_resource_id == model_id
|
||||
|
||||
async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context):
|
||||
"""Test handling of empty model list"""
|
||||
with mock_client_context(mixin, mock_client_with_empty_models):
|
||||
result = await mixin.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 0
|
||||
assert len(mixin._model_cache) == 0
|
||||
|
||||
|
||||
class TestOpenAIMixinCheckModelAvailability:
|
||||
"""Test cases for the check_model_availability method"""
|
||||
|
||||
async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model availability check when cache is populated"""
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
await mixin.list_models()
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
assert await mixin.check_model_availability("some-mock-model-id")
|
||||
assert await mixin.check_model_availability("another-mock-model-id")
|
||||
assert await mixin.check_model_availability("final-mock-model-id")
|
||||
assert not await mixin.check_model_availability("non-existent-model")
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model availability check when cache is empty (calls list_models)"""
|
||||
assert len(mixin._model_cache) == 0
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
assert await mixin.check_model_availability("some-mock-model-id")
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
assert len(mixin._model_cache) == 3
|
||||
assert "some-mock-model-id" in mixin._model_cache
|
||||
|
||||
async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model availability check for non-existent model"""
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
assert not await mixin.check_model_availability("non-existent-model")
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
|
||||
assert len(mixin._model_cache) == 3
|
||||
|
||||
|
||||
class TestOpenAIMixinCacheBehavior:
|
||||
"""Test cases for cache behavior and edge cases"""
|
||||
|
||||
async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test that calling list_models overwrites existing cache"""
|
||||
initial_model = Model(
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="old-model",
|
||||
identifier="old-model",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
mixin._model_cache = {"old-model": initial_model}
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
await mixin.list_models()
|
||||
|
||||
assert len(mixin._model_cache) == 3
|
||||
assert "old-model" not in mixin._model_cache
|
||||
assert "some-mock-model-id" in mixin._model_cache
|
||||
assert "another-mock-model-id" in mixin._model_cache
|
||||
assert "final-mock-model-id" in mixin._model_cache
|
||||
|
||||
|
||||
class TestOpenAIMixinImagePreprocessing:
|
||||
"""Test cases for image preprocessing functionality"""
|
||||
|
||||
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
|
||||
"""Test that image URLs are converted to base64 when download_images is True"""
|
||||
mixin.download_images = True
|
||||
|
||||
message = OpenAIUserMessageParam(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||
],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
|
||||
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
processed_messages = call_args[1]["messages"]
|
||||
assert len(processed_messages) == 1
|
||||
content = processed_messages[0]["content"]
|
||||
assert len(content) == 2
|
||||
assert content[0]["type"] == "text"
|
||||
assert content[1]["type"] == "image_url"
|
||||
assert content[1]["image_url"]["url"] == ""
|
||||
|
||||
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
|
||||
"""Test that image URLs are not modified when download_images is False"""
|
||||
mixin.download_images = False # explicitly set to False
|
||||
|
||||
message = OpenAIUserMessageParam(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||
],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
|
||||
mock_localize.assert_not_called()
|
||||
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
processed_messages = call_args[1]["messages"]
|
||||
assert len(processed_messages) == 1
|
||||
content = processed_messages[0]["content"]
|
||||
assert len(content) == 2
|
||||
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
|
||||
|
||||
|
||||
class TestOpenAIMixinEmbeddingModelMetadata:
|
||||
"""Test cases for embedding_model_metadata attribute functionality"""
|
||||
|
||||
async def test_embedding_model_identified_and_augmented(self, mixin_with_embeddings, mock_client_context):
|
||||
"""Test that models in embedding_model_metadata are correctly identified as embeddings with metadata"""
|
||||
# Create mock models: 1 embedding model and 1 LLM, while there are 2 known embedding models
|
||||
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||
mock_llm_model = MagicMock(id="gpt-4")
|
||||
mock_models = [mock_embedding_model, mock_llm_model]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
|
||||
with mock_client_context(mixin_with_embeddings, mock_client):
|
||||
result = await mixin_with_embeddings.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
# Find the models in the result
|
||||
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
|
||||
llm_model = next(m for m in result if m.identifier == "gpt-4")
|
||||
|
||||
# Check embedding model
|
||||
assert embedding_model.model_type == ModelType.embedding
|
||||
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
|
||||
assert embedding_model.provider_id == "test-provider"
|
||||
assert embedding_model.provider_resource_id == "text-embedding-3-small"
|
||||
|
||||
# Check LLM model
|
||||
assert llm_model.model_type == ModelType.llm
|
||||
assert llm_model.metadata == {} # No metadata for LLMs
|
||||
assert llm_model.provider_id == "test-provider"
|
||||
assert llm_model.provider_resource_id == "gpt-4"
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModels:
|
||||
"""Test cases for allowed_models filtering functionality"""
|
||||
|
||||
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test that list_models filters models based on allowed_models set"""
|
||||
mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
model_ids = [model.identifier for model in result]
|
||||
assert "some-mock-model-id" in model_ids
|
||||
assert "another-mock-model-id" in model_ids
|
||||
assert "final-mock-model-id" not in model_ids
|
||||
|
||||
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test that empty allowed_models set allows all models"""
|
||||
assert len(mixin.allowed_models) == 0
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 3 # All models should be included
|
||||
|
||||
model_ids = [model.identifier for model in result]
|
||||
assert "some-mock-model-id" in model_ids
|
||||
assert "another-mock-model-id" in model_ids
|
||||
assert "final-mock-model-id" in model_ids
|
||||
|
||||
async def test_check_model_availability_with_allowed_models(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
):
|
||||
"""Test that check_model_availability respects allowed_models"""
|
||||
mixin.allowed_models = {"final-mock-model-id"}
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
assert await mixin.check_model_availability("final-mock-model-id")
|
||||
assert not await mixin.check_model_availability("some-mock-model-id")
|
||||
assert not await mixin.check_model_availability("another-mock-model-id")
|
|
@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
|
|||
# Should raise an exception instead of returning empty string
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
content_from_data_and_mime_type(data, mime_type)
|
||||
|
||||
|
||||
async def test_memory_tool_error_handling():
|
||||
"""Test that memory tool handles various failures gracefully without crashing."""
|
||||
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
|
||||
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
|
||||
|
||||
config = RagToolRuntimeConfig()
|
||||
memory_tool = MemoryToolRuntimeImpl(
|
||||
config=config,
|
||||
vector_io_api=AsyncMock(),
|
||||
inference_api=AsyncMock(),
|
||||
files_api=AsyncMock(),
|
||||
)
|
||||
|
||||
docs = [
|
||||
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
|
||||
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
|
||||
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
|
||||
]
|
||||
|
||||
mock_file1 = MagicMock()
|
||||
mock_file1.id = "file_good1"
|
||||
mock_file2 = MagicMock()
|
||||
mock_file2.id = "file_good2"
|
||||
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.side_effect = Exception("Bad URL")
|
||||
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||
|
||||
# won't raise exception despite one document failing
|
||||
await memory_tool.insert(docs, "vector_store_123")
|
||||
|
||||
# processed 2 documents successfully, skipped 1
|
||||
assert memory_tool.files_api.openai_upload_file.call_count == 2
|
||||
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2
|
||||
|
|
|
@ -84,14 +84,14 @@ def unknown_model() -> Model:
|
|||
|
||||
@pytest.fixture
|
||||
def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
|
||||
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
||||
return ModelRegistryHelper(model_entries=[known_provider_model, known_provider_model2])
|
||||
|
||||
|
||||
class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
|
||||
"""Test helper that simulates a provider with dynamically available models."""
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
|
||||
super().__init__(model_entries)
|
||||
super().__init__(model_entries=model_entries)
|
||||
self._available_models = available_models
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
|
|
|
@ -54,7 +54,9 @@ def mock_vector_db(vector_db_id) -> MagicMock:
|
|||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = 384
|
||||
mock_vector_db.model_dump_json.return_value = (
|
||||
'{"identifier": "' + vector_db_id + '", "embedding_model": "embedding_model", "embedding_dimension": 384}'
|
||||
'{"identifier": "'
|
||||
+ vector_db_id
|
||||
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
|
||||
)
|
||||
return mock_vector_db
|
||||
|
||||
|
|
|
@ -26,9 +26,9 @@ def test_generate_chunk_id():
|
|||
|
||||
chunk_ids = sorted([chunk.chunk_id for chunk in chunks])
|
||||
assert chunk_ids == [
|
||||
"177a1368-f6a8-0c50-6e92-18677f2c3de3",
|
||||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
||||
"31d1f9a3-c8d2-66e7-3c37-af2acd329778",
|
||||
"d07dade7-29c0-cda7-df29-0249a1dcbc3e",
|
||||
"d14f75a1-5855-7f72-2c78-d9fc4275a346",
|
||||
]
|
||||
|
||||
|
||||
|
@ -36,14 +36,14 @@ def test_generate_chunk_id_with_window():
|
|||
chunk = Chunk(content="test", metadata={"document_id": "doc-1"})
|
||||
chunk_id1 = generate_chunk_id("doc-1", chunk, chunk_window="0-1")
|
||||
chunk_id2 = generate_chunk_id("doc-1", chunk, chunk_window="1-2")
|
||||
assert chunk_id1 == "149018fe-d0eb-0f8d-5f7f-726bdd2aeedb"
|
||||
assert chunk_id2 == "4562c1ee-9971-1f3b-51a6-7d05e5211154"
|
||||
assert chunk_id1 == "8630321a-d9cb-2bb6-cd28-ebf68dafd866"
|
||||
assert chunk_id2 == "13a1c09a-cbda-b61a-2d1a-7baa90888685"
|
||||
|
||||
|
||||
def test_chunk_id():
|
||||
# Test with existing chunk ID
|
||||
chunk_with_id = Chunk(content="test", metadata={"document_id": "existing-id"})
|
||||
assert chunk_with_id.chunk_id == "84ededcc-b80b-a83e-1a20-ca6515a11350"
|
||||
assert chunk_with_id.chunk_id == "11704f92-42b6-61df-bf85-6473e7708fbd"
|
||||
|
||||
# Test with document ID in metadata
|
||||
chunk_with_doc_id = Chunk(content="test", metadata={"document_id": "doc-1"})
|
||||
|
|
|
@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
|
|||
|
||||
class TestRagQuery:
|
||||
async def test_query_raises_on_empty_vector_db_ids(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
||||
|
||||
async def test_query_chunk_metadata_handling(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||
)
|
||||
content = "test query content"
|
||||
vector_db_ids = ["db1"]
|
||||
|
||||
|
@ -77,3 +81,58 @@ class TestRagQuery:
|
|||
# Test that invalid mode raises an error
|
||||
with pytest.raises(ValueError):
|
||||
RAGQueryConfig(mode="wrong_mode")
|
||||
|
||||
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
|
||||
rag_tool = MemoryToolRuntimeImpl(
|
||||
config=MagicMock(),
|
||||
vector_io_api=MagicMock(),
|
||||
inference_api=MagicMock(),
|
||||
files_api=MagicMock(),
|
||||
)
|
||||
|
||||
vector_db_ids = ["db1", "db2"]
|
||||
|
||||
# Fake chunks from each DB
|
||||
chunk_metadata1 = ChunkMetadata(
|
||||
document_id="doc1",
|
||||
chunk_id="chunk1",
|
||||
source="test_source1",
|
||||
metadata_token_count=5,
|
||||
)
|
||||
chunk1 = Chunk(
|
||||
content="chunk from db1",
|
||||
metadata={"vector_db_id": "db1", "document_id": "doc1"},
|
||||
stored_chunk_id="c1",
|
||||
chunk_metadata=chunk_metadata1,
|
||||
)
|
||||
|
||||
chunk_metadata2 = ChunkMetadata(
|
||||
document_id="doc2",
|
||||
chunk_id="chunk2",
|
||||
source="test_source2",
|
||||
metadata_token_count=5,
|
||||
)
|
||||
chunk2 = Chunk(
|
||||
content="chunk from db2",
|
||||
metadata={"vector_db_id": "db2", "document_id": "doc2"},
|
||||
stored_chunk_id="c2",
|
||||
chunk_metadata=chunk_metadata2,
|
||||
)
|
||||
|
||||
rag_tool.vector_io_api.query_chunks = AsyncMock(
|
||||
side_effect=[
|
||||
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
|
||||
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
|
||||
]
|
||||
)
|
||||
|
||||
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
|
||||
returned_chunks = result.metadata["chunks"]
|
||||
returned_scores = result.metadata["scores"]
|
||||
returned_doc_ids = result.metadata["document_ids"]
|
||||
returned_vector_db_ids = result.metadata["vector_db_ids"]
|
||||
|
||||
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
||||
assert returned_scores == (0.9, 0.8)
|
||||
assert returned_doc_ids == ["doc1", "doc2"]
|
||||
assert returned_vector_db_ids == ["db1", "db2"]
|
||||
|
|
|
@ -774,3 +774,136 @@ def test_has_required_scope_function():
|
|||
|
||||
# Test no user (auth disabled)
|
||||
assert _has_required_scope("test.read", None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kubernetes_api_server():
|
||||
return "https://api.cluster.example.com:6443"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kubernetes_auth_app(mock_kubernetes_api_server):
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_config={
|
||||
"type": "kubernetes",
|
||||
"api_server_url": mock_kubernetes_api_server,
|
||||
"verify_tls": False,
|
||||
"claims_mapping": {
|
||||
"username": "roles",
|
||||
"groups": "roles",
|
||||
"uid": "uid_attr",
|
||||
},
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kubernetes_auth_client(kubernetes_auth_app):
|
||||
return TestClient(kubernetes_auth_app)
|
||||
|
||||
|
||||
def test_missing_auth_header_kubernetes_auth(kubernetes_auth_client):
|
||||
response = kubernetes_auth_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_kubernetes_auth(kubernetes_auth_client):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_kubernetes_selfsubjectreview_success(*args, **kwargs):
|
||||
return MockResponse(
|
||||
201,
|
||||
{
|
||||
"apiVersion": "authentication.k8s.io/v1",
|
||||
"kind": "SelfSubjectReview",
|
||||
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||
"status": {
|
||||
"userInfo": {
|
||||
"username": "alice",
|
||||
"uid": "alice-uid-123",
|
||||
"groups": ["system:authenticated", "developers", "admins"],
|
||||
"extra": {"scopes.authorization.openshift.io": ["user:full"]},
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def mock_kubernetes_selfsubjectreview_failure(*args, **kwargs):
|
||||
return MockResponse(401, {"message": "Unauthorized"})
|
||||
|
||||
|
||||
async def mock_kubernetes_selfsubjectreview_http_error(*args, **kwargs):
|
||||
return MockResponse(500, {"message": "Internal Server Error"})
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_success)
|
||||
def test_valid_kubernetes_auth_authentication(kubernetes_auth_client, valid_token):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_failure)
|
||||
def test_invalid_kubernetes_auth_authentication(kubernetes_auth_client, invalid_token):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.post", new=mock_kubernetes_selfsubjectreview_http_error)
|
||||
def test_kubernetes_auth_http_error(kubernetes_auth_client, valid_token):
|
||||
response = kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Token validation failed" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_kubernetes_auth_request_payload(kubernetes_auth_client, valid_token, mock_kubernetes_api_server):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
{
|
||||
"apiVersion": "authentication.k8s.io/v1",
|
||||
"kind": "SelfSubjectReview",
|
||||
"metadata": {"creationTimestamp": "2025-07-15T13:53:56Z"},
|
||||
"status": {
|
||||
"userInfo": {
|
||||
"username": "test-user",
|
||||
"uid": "test-uid",
|
||||
"groups": ["test-group"],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
kubernetes_auth_client.get("/test", headers={"Authorization": f"Bearer {valid_token}"})
|
||||
|
||||
# Verify the request was made with correct parameters
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
|
||||
# Check URL (passed as positional argument)
|
||||
assert call_args[0][0] == f"{mock_kubernetes_api_server}/apis/authentication.k8s.io/v1/selfsubjectreviews"
|
||||
|
||||
# Check headers (passed as keyword argument)
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers["Authorization"] == f"Bearer {valid_token}"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
# Check request body (passed as keyword argument)
|
||||
request_body = call_args[1]["json"]
|
||||
assert request_body["apiVersion"] == "authentication.k8s.io/v1"
|
||||
assert request_body["kind"] == "SelfSubjectReview"
|
||||
|
|
|
@ -113,6 +113,15 @@ class TestTranslateException:
|
|||
assert result.status_code == 504
|
||||
assert result.detail == "Operation timed out: "
|
||||
|
||||
def test_translate_connection_error(self):
|
||||
"""Test that ConnectionError is translated to 502 HTTP status."""
|
||||
exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
|
||||
result = translate_exception(exc)
|
||||
|
||||
assert isinstance(result, HTTPException)
|
||||
assert result.status_code == 502
|
||||
assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"
|
||||
|
||||
def test_translate_not_implemented_error(self):
|
||||
"""Test that NotImplementedError is translated to 501 HTTP status."""
|
||||
exc = NotImplementedError("Not implemented")
|
||||
|
|
|
@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
|
|||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
|
@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
|
|||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
|
@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
|
|||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
|
@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
|
|||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Wait for all queued writes to complete
|
||||
await store.flush()
|
||||
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
|
|
|
@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
|||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
# Create table with access control
|
||||
await sqlstore.create_table(
|
||||
|
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
|||
mock_get_authenticated_user.return_value = admin_user
|
||||
|
||||
# Admin should see both documents
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["title"] == "Admin Document"
|
||||
|
||||
# User should only see their document
|
||||
mock_get_authenticated_user.return_value = regular_user
|
||||
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
|
||||
result = await sqlstore.fetch_all("documents", where={"id": 1})
|
||||
assert len(result.data) == 0
|
||||
|
||||
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
|
||||
result = await sqlstore.fetch_all("documents", where={"id": 2})
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0]["title"] == "User Document"
|
||||
|
||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
|
||||
row = await sqlstore.fetch_one("documents", where={"id": 1})
|
||||
assert row is None
|
||||
|
||||
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
|
||||
row = await sqlstore.fetch_one("documents", where={"id": 2})
|
||||
assert row is not None
|
||||
assert row["title"] == "User Document"
|
||||
|
||||
|
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
|||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore)
|
||||
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
await sqlstore.create_table(
|
||||
table="resources",
|
||||
|
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
|||
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
|
||||
mock_get_authenticated_user.return_value = user
|
||||
|
||||
sql_results = await sqlstore.fetch_all("resources", policy=policy)
|
||||
sql_results = await sqlstore.fetch_all("resources")
|
||||
sql_ids = {row["id"] for row in sql_results.data}
|
||||
policy_ids = set()
|
||||
for scenario in test_scenarios:
|
||||
|
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
|
|||
db_path=tmp_dir + "/" + db_name,
|
||||
)
|
||||
)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore)
|
||||
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
|
||||
|
||||
await authorized_store.create_table(
|
||||
table="user_data",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue