Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-16 01:22:37 +00:00
Merge remote-tracking branch 'origin/main' into test_isolation_server

Commit 889b2716ef
107 changed files with 817 additions and 1298 deletions
@@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
     non_existent = await table.get_object_by_identifier("model", "non-existent-model")
     assert non_existent is None
 
+    # Test has_model
+    assert await table.has_model("test_provider/test-model")
+    assert await table.has_model("test_provider/test-model-2")
+    assert not await table.has_model("non-existent-model")
+    assert not await table.has_model("test_provider/non-existent-model")
+
     await table.unregister_model(model_id="test_provider/test-model")
     await table.unregister_model(model_id="test_provider/test-model-2")
 
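The new assertions exercise a has_model helper on the routing table. Its implementation is not part of this hunk; a minimal sketch of what such a method could look like, reusing the get_object_by_identifier lookup the same test already calls (illustrative only, not the actual llama-stack code):

    async def has_model(self, model_id: str) -> bool:
        # Existence probe: map the identifier lookup's "not found" (None) to False.
        model = await self.get_object_by_identifier("model", model_id)
        return model is not None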
@@ -8,6 +8,7 @@
 import pytest
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
@@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
     OpenAIUserMessageParam,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
+    _extract_citations_from_text,
     convert_chat_choice_to_response_message,
     convert_response_content_to_chat_content,
     convert_response_input_to_chat_messages,
@@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
 
         result = is_function_tool_call(tool_call, tools)
         assert result is False
+
+
+class TestExtractCitationsFromText:
+    def test_extract_citations_and_annotations(self):
+        text = "Start [not-a-file]. New source <|file-abc123|>. "
+        text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
+        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
+
+        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
+
+        expected_annotations = [
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
+            OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
+        ]
+        expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
+
+        assert cleaned_text == expected_clean_text
+        assert annotations == expected_annotations
+        # OpenAI cites at the end of the sentence
+        assert cleaned_text[expected_annotations[0].index] == "."
+        assert cleaned_text[expected_annotations[1].index] == "?"
+        assert cleaned_text[expected_annotations[2].index] == "!"
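This test pins down the contract of _extract_citations_from_text: each <|file-...|> marker (and the space before it) is stripped, and the annotation's index points at the character that follows the removed marker, which is why the assertions land on the sentence-ending punctuation. A sketch that satisfies these assertions, assuming a regex over the marker syntax (the pattern and the handling of unknown file ids are assumptions, not the actual implementation):

    import re

    from llama_stack.apis.agents.openai_responses import OpenAIResponseAnnotationFileCitation

    _CITATION = re.compile(r"\s*<\|(file-[A-Za-z0-9_-]+)\|>")

    def _extract_citations_from_text(text: str, file_mapping: dict[str, str]):
        annotations, parts, last_end = [], [], 0
        for m in _CITATION.finditer(text):
            parts.append(text[last_end : m.start()])
            file_id = m.group(1)
            if file_id in file_mapping:
                # Position in the cleaned text of the character right after the removed marker.
                index = sum(len(p) for p in parts)
                annotations.append(
                    OpenAIResponseAnnotationFileCitation(
                        file_id=file_id, filename=file_mapping[file_id], index=index
                    )
                )
            last_end = m.end()
        parts.append(text[last_end:])
        return annotations, "".join(parts)

Walking the test string through this sketch gives indices 30, 44, and 59, matching the expected annotations above.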
@@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
 
 
 @pytest.mark.parametrize(
@@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator):
         {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
     ):
         assert inference_adapter.client.api_key == api_key
+
+
+@pytest.mark.parametrize(
+    "config_cls,adapter_cls,provider_data_validator",
+    [
+        (
+            WatsonXConfig,
+            WatsonXInferenceAdapter,
+            "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
+        ),
+    ],
+)
+def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
+    """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
+    assumption that there is an OpenAI-compatible client object."""
+
+    inference_adapter = adapter_cls(config=config_cls())
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            assert inference_adapter.get_api_key() == api_key
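The pattern under test is a per-request credential override: an API key supplied via the x-llamastack-provider-data header should win over whatever the adapter was configured with. A self-contained sketch of that mechanism using a contextvar, mimicking what request_provider_data_context does; the adapter class and the field name are illustrative, not llama-stack's actual internals:

    import json
    from contextlib import contextmanager
    from contextvars import ContextVar

    _provider_data: ContextVar[dict | None] = ContextVar("provider_data", default=None)

    @contextmanager
    def provider_data_context(headers: dict[str, str]):
        # Parse the header once and expose it for the duration of the request.
        token = _provider_data.set(json.loads(headers.get("x-llamastack-provider-data", "{}")))
        try:
            yield
        finally:
            _provider_data.reset(token)

    class LiteLLMStyleAdapter:
        provider_data_api_key_field = "watsonx_api_key"  # hypothetical field name

        def __init__(self, config_api_key: str | None = None):
            self._config_api_key = config_api_key

        def get_api_key(self) -> str | None:
            # Per-request provider data wins; fall back to the static config.
            data = _provider_data.get()
            if data and self.provider_data_api_key_field in data:
                return data[self.provider_data_api_key_field]
            return self._config_api_key

    adapter = LiteLLMStyleAdapter(config_api_key="from-config")
    with provider_data_context({"x-llamastack-provider-data": json.dumps({"watsonx_api_key": "from-header"})}):
        assert adapter.get_api_key() == "from-header"
    assert adapter.get_api_key() == "from-config"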
@@ -186,43 +186,3 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 
     assert mock_create_client.call_count == 4  # no cheating
     assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
-
-
-async def test_should_refresh_models():
-    """
-    Test the should_refresh_models method with different refresh_models configurations.
-
-    This test verifies that:
-    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
-    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
-    """
-
-    # Test case 1: refresh_models is True, api_token is None
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
-    adapter1 = VLLMInferenceAdapter(config=config1)
-    result1 = await adapter1.should_refresh_models()
-    assert result1 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 2: refresh_models is True, api_token is empty string
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
-    adapter2 = VLLMInferenceAdapter(config=config2)
-    result2 = await adapter2.should_refresh_models()
-    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 3: refresh_models is True, api_token is "fake" (default)
-    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
-    adapter3 = VLLMInferenceAdapter(config=config3)
-    result3 = await adapter3.should_refresh_models()
-    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 4: refresh_models is True, api_token is real token
-    config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
-    adapter4 = VLLMInferenceAdapter(config=config4)
-    result4 = await adapter4.should_refresh_models()
-    assert result4 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 5: refresh_models is False, api_token is real token
-    config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
-    adapter5 = VLLMInferenceAdapter(config=config5)
-    result5 = await adapter5.should_refresh_models()
-    assert result5 is False, "should_refresh_models should return False when refresh_models is False"
@@ -44,11 +44,12 @@ def mixin():
     config = RemoteInferenceProviderConfig()
     mixin_instance = OpenAIMixinImpl(config=config)
 
-    # just enough to satisfy _get_provider_model_id calls
-    mock_model_store = MagicMock()
+    # Mock model_store with async methods
+    mock_model_store = AsyncMock()
     mock_model = MagicMock()
     mock_model.provider_resource_id = "test-provider-resource-id"
     mock_model_store.get_model = AsyncMock(return_value=mock_model)
+    mock_model_store.has_model = AsyncMock(return_value=False)  # Default to False, tests can override
     mixin_instance.model_store = mock_model_store
 
     return mixin_instance
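The switch from MagicMock to AsyncMock matters because the model_store methods (get_model, has_model) are awaited; calling a plain MagicMock returns a non-awaitable object. A quick standalone illustration:

    import asyncio
    from unittest.mock import AsyncMock, MagicMock

    async def main():
        async_store = AsyncMock()
        async_store.has_model.return_value = False
        assert await async_store.has_model("m") is False  # awaitable, as the adapter expects

        sync_store = MagicMock()
        try:
            await sync_store.has_model("m")
        except TypeError as exc:
            print(f"MagicMock is not awaitable: {exc}")

    asyncio.run(main())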
@@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
 
         assert len(mixin._model_cache) == 3
 
+    async def test_check_model_availability_with_pre_registered_model(
+        self, mixin, mock_client_with_models, mock_client_context
+    ):
+        """Test that check_model_availability returns True for pre-registered models in model_store"""
+        # Mock model_store.has_model to return True for a specific model
+        mock_model_store = AsyncMock()
+        mock_model_store.has_model = AsyncMock(return_value=True)
+        mixin.model_store = mock_model_store
+
+        # Test that pre-registered model is found without calling the provider's API
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert await mixin.check_model_availability("pre-registered-model")
+            # Should not call the provider's list_models since model was found in store
+            mock_client_with_models.models.list.assert_not_called()
+            mock_model_store.has_model.assert_called_once_with("pre-registered-model")
+
+    async def test_check_model_availability_fallback_to_provider_when_not_in_store(
+        self, mixin, mock_client_with_models, mock_client_context
+    ):
+        """Test that check_model_availability falls back to provider when model not in store"""
+        # Mock model_store.has_model to return False
+        mock_model_store = AsyncMock()
+        mock_model_store.has_model = AsyncMock(return_value=False)
+        mixin.model_store = mock_model_store
+
+        # Test that it falls back to provider's model cache
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert await mixin.check_model_availability("some-mock-model-id")
+            # Should call the provider's list_models since model was not found in store
+            mock_client_with_models.models.list.assert_called_once()
+            mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
+
 
 class TestOpenAIMixinCacheBehavior:
     """Test cases for cache behavior and edge cases"""
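Taken together, these two tests specify a lookup order: consult the model store first, and only hit the provider's model listing on a miss. A sketch of a check_model_availability consistent with that contract (the _model_cache population shown here is an assumption based on the surrounding tests, not the verbatim mixin code):

    async def check_model_availability(self, model: str) -> bool:
        # Pre-registered models short-circuit the provider call entirely.
        if self.model_store is not None and await self.model_store.has_model(model):
            return True
        # Fall back to the provider's model listing, cached after the first call.
        if not self._model_cache:
            self._model_cache = {m.id: m async for m in self.client.models.list()}
        return model in self._model_cache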
@@ -466,10 +501,16 @@ class TestOpenAIMixinModelRegistration:
         assert result is None
 
     async def test_should_refresh_models(self, mixin):
-        """Test should_refresh_models method (should always return False)"""
+        """Test should_refresh_models method returns config value"""
+        # Default config has refresh_models=False
         result = await mixin.should_refresh_models()
         assert result is False
 
+        config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True)
+        mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh)
+        result_with_refresh = await mixin_with_refresh.should_refresh_models()
+        assert result_with_refresh is True
+
     async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
         """Test that errors from provider API are properly propagated during registration"""
        model = Model(
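This hunk, together with the deleted vLLM-specific test earlier, points at one consolidated behavior: should_refresh_models simply reflects the config flag, with no per-adapter api_token heuristics. The mixin method is presumably a one-liner along these lines (a sketch, assuming the config exposes the refresh_models flag the test constructs):

    async def should_refresh_models(self) -> bool:
        # Config-driven: no per-adapter overrides, no api_token special cases.
        return self.config.refresh_models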
@@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
 
 
 @pytest.fixture
-async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
+async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = SQLiteVectorIOConfig(
         db_path=sqlite_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = SQLiteVecVectorIOAdapter(
         config=config,
@@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
 
 
 @pytest.fixture
-async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
+async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
     config = MilvusVectorIOConfig(
         db_path=milvus_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = MilvusVectorIOAdapter(
         config=config,
@@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
 
 
 @pytest.fixture
-async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
+async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = ChromaVectorIOConfig(
         db_path=chroma_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = ChromaVectorIOAdapter(
         config=config,
@@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory):
 
 
 @pytest.fixture
-async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
+async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     import uuid
 
     config = QdrantVectorIOConfig(
         db_path=qdrant_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = QdrantVectorIOAdapter(
         config=config,
@@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
 
 
 @pytest.fixture
-async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
+async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = PGVectorVectorIOConfig(
         host="localhost",
         port=5432,
         db="test_db",
         user="test_user",
         password="test_password",
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
 
     adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path):
 
 
 @pytest.fixture
-async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
+async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     import pytest_socket
     import weaviate
 
@@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
     config = WeaviateVectorIOConfig(
         weaviate_cluster_url="localhost:8080",
         weaviate_api_key=None,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = WeaviateVectorIOAdapter(
         config=config,
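Every fixture in this file now takes a shared unique_kvstore_config instead of instantiating a default SqliteKVStoreConfig(), so adapters running in the same session stop colliding on one registry database. The fixture itself is outside this diff; a plausible sketch (the file-naming scheme is an assumption):

    import uuid

    import pytest

    from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

    @pytest.fixture
    def unique_kvstore_config(tmp_path_factory):
        # One sqlite file per test, so no two adapters share registry state.
        db_path = tmp_path_factory.getbasetemp() / f"kvstore_{uuid.uuid4().hex}.db"
        return SqliteKVStoreConfig(db_path=str(db_path))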
@@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
         provider_resource_id="test_vector_db_2",
         provider_id="baz",  # Same provider_id
     )
-    await cached_disk_dist_registry.register(duplicate_vector_db)
+
+    # Now we expect a ValueError to be raised for duplicate registration
+    with pytest.raises(
+        ValueError,
+        match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.",
+    ):
+        await cached_disk_dist_registry.register(duplicate_vector_db)
 
+    # Verify the original registration is still intact
     result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
     assert result is not None
     assert result.embedding_model == original_vector_db.embedding_model  # Original values preserved
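The test encodes the registry's contract: re-registering under an already-registered provider id raises instead of silently overwriting, and the original entry survives untouched. A sketch of the guard inside register that would produce the matched message (illustrative; the persistence helper is hypothetical):

    async def register(self, obj) -> None:
        existing = await self.get(obj.type, obj.identifier)
        if existing is not None and existing.provider_id == obj.provider_id:
            raise ValueError(
                f"Provider '{obj.provider_id}' is already registered for {obj.type} '{obj.identifier}'. "
                "Unregister the existing provider first before registering it again."
            )
        await self._store(obj)  # hypothetical persistence helper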