Merge remote-tracking branch 'origin/main' into test_isolation_server

Ashwin Bharambe 2025-10-08 11:19:12 -07:00
commit 889b2716ef
107 changed files with 817 additions and 1298 deletions


@@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
    non_existent = await table.get_object_by_identifier("model", "non-existent-model")
    assert non_existent is None

    # Test has_model
    assert await table.has_model("test_provider/test-model")
    assert await table.has_model("test_provider/test-model-2")
    assert not await table.has_model("non-existent-model")
    assert not await table.has_model("test_provider/non-existent-model")

    await table.unregister_model(model_id="test_provider/test-model")
    await table.unregister_model(model_id="test_provider/test-model-2")


@@ -8,6 +8,7 @@
import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseAnnotationFileCitation,
    OpenAIResponseInputFunctionToolCallOutput,
    OpenAIResponseInputMessageContentImage,
    OpenAIResponseInputMessageContentText,
@@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
    OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    _extract_citations_from_text,
    convert_chat_choice_to_response_message,
    convert_response_content_to_chat_content,
    convert_response_input_to_chat_messages,
@@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
        result = is_function_tool_call(tool_call, tools)
        assert result is False


class TestExtractCitationsFromText:
    def test_extract_citations_and_annotations(self):
        text = "Start [not-a-file]. New source <|file-abc123|>. "
        text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}

        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)

        expected_annotations = [
            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
            OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
        ]
        expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."

        assert cleaned_text == expected_clean_text
        assert annotations == expected_annotations

        # OpenAI cites at the end of the sentence
        assert cleaned_text[expected_annotations[0].index] == "."
        assert cleaned_text[expected_annotations[1].index] == "?"
        assert cleaned_text[expected_annotations[2].index] == "!"
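For orientation, the behavior this test pins down can be sketched as: strip each <|file-...|> marker (and the whitespace before it), and for every id present in the file mapping emit an annotation whose index is the position of the character that follows the removed marker in the cleaned text, which is why the assertions land on ".", "?" and "!". A minimal sketch consistent with the assertions above, not the actual implementation in the responses utils module:

import re

from llama_stack.apis.agents.openai_responses import OpenAIResponseAnnotationFileCitation


def extract_citations_sketch(text: str, file_id_mapping: dict[str, str]):
    """Strip <|file-...|> markers; annotate the character that follows each removed marker."""
    pattern = re.compile(r"\s*<\|(file-[^|>]+)\|>")
    annotations = []
    cleaned = ""
    last_end = 0
    for match in pattern.finditer(text):
        cleaned += text[last_end:match.start()]
        file_id = match.group(1)
        if file_id in file_id_mapping:
            annotations.append(
                OpenAIResponseAnnotationFileCitation(
                    file_id=file_id,
                    filename=file_id_mapping[file_id],
                    index=len(cleaned),  # points at ".", "?" or "!" in the cleaned text
                )
            )
        last_end = match.end()
    cleaned += text[last_end:]
    return annotations, cleaned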


@@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter


@pytest.mark.parametrize(
@@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
        ):
            assert inference_adapter.client.api_key == api_key


@pytest.mark.parametrize(
    "config_cls,adapter_cls,provider_data_validator",
    [
        (
            WatsonXConfig,
            WatsonXInferenceAdapter,
            "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
        ),
    ],
)
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
    """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
    assumption that there is an OpenAI-compatible client object."""
    inference_adapter = adapter_cls(config=config_cls())

    inference_adapter.__provider_spec__ = MagicMock()
    inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

    for api_key in ["test1", "test2"]:
        with request_provider_data_context(
            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
        ):
            assert inference_adapter.get_api_key() == api_key
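The assertion relies on the adapter resolving its API key per request rather than per client object. A rough sketch of that resolution, assuming a get_request_provider_data()-style helper and a config-level api_key fallback (both assumptions; only get_api_key and provider_data_api_key_field appear in the test itself):

def get_api_key(self) -> str:
    # A per-request key supplied via the x-llamastack-provider-data header takes precedence...
    provider_data = self.get_request_provider_data()
    if provider_data is not None and getattr(provider_data, self.provider_data_api_key_field, None):
        return getattr(provider_data, self.provider_data_api_key_field)
    # ...otherwise fall back to whatever key the static provider config carries.
    return self.config.api_key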


@@ -186,43 +186,3 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
    assert mock_create_client.call_count == 4  # no cheating
    assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"


async def test_should_refresh_models():
    """
    Test the should_refresh_models method with different refresh_models configurations.

    This test verifies that:
    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
    """

    # Test case 1: refresh_models is True, api_token is None
    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
    adapter1 = VLLMInferenceAdapter(config=config1)
    result1 = await adapter1.should_refresh_models()
    assert result1 is True, "should_refresh_models should return True when refresh_models is True"

    # Test case 2: refresh_models is True, api_token is empty string
    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
    adapter2 = VLLMInferenceAdapter(config=config2)
    result2 = await adapter2.should_refresh_models()
    assert result2 is True, "should_refresh_models should return True when refresh_models is True"

    # Test case 3: refresh_models is True, api_token is "fake" (default)
    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
    adapter3 = VLLMInferenceAdapter(config=config3)
    result3 = await adapter3.should_refresh_models()
    assert result3 is True, "should_refresh_models should return True when refresh_models is True"

    # Test case 4: refresh_models is True, api_token is real token
    config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
    adapter4 = VLLMInferenceAdapter(config=config4)
    result4 = await adapter4.should_refresh_models()
    assert result4 is True, "should_refresh_models should return True when refresh_models is True"

    # Test case 5: refresh_models is False, api_token is real token
    config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
    adapter5 = VLLMInferenceAdapter(config=config5)
    result5 = await adapter5.should_refresh_models()
    assert result5 is False, "should_refresh_models should return False when refresh_models is False"


@@ -44,11 +44,12 @@ def mixin():
    config = RemoteInferenceProviderConfig()
    mixin_instance = OpenAIMixinImpl(config=config)

    # just enough to satisfy _get_provider_model_id calls
    mock_model_store = MagicMock()
    # Mock model_store with async methods
    mock_model_store = AsyncMock()
    mock_model = MagicMock()
    mock_model.provider_resource_id = "test-provider-resource-id"
    mock_model_store.get_model = AsyncMock(return_value=mock_model)
    mock_model_store.has_model = AsyncMock(return_value=False)  # Default to False, tests can override
    mixin_instance.model_store = mock_model_store

    return mixin_instance
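The MagicMock to AsyncMock switch matters because the mixin now awaits model_store methods such as has_model, and awaiting a plain MagicMock raises TypeError. A standalone illustration of that difference (not part of the diff):

import asyncio
from unittest.mock import AsyncMock, MagicMock


async def main():
    store = AsyncMock()
    store.has_model = AsyncMock(return_value=False)
    assert await store.has_model("some-model") is False  # awaitable, resolves to the return_value

    plain = MagicMock()
    try:
        await plain.has_model("some-model")  # a MagicMock result cannot be awaited
    except TypeError as exc:
        print(f"MagicMock fails when awaited: {exc}")


asyncio.run(main())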
@@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
        assert len(mixin._model_cache) == 3

    async def test_check_model_availability_with_pre_registered_model(
        self, mixin, mock_client_with_models, mock_client_context
    ):
        """Test that check_model_availability returns True for pre-registered models in model_store"""
        # Mock model_store.has_model to return True for a specific model
        mock_model_store = AsyncMock()
        mock_model_store.has_model = AsyncMock(return_value=True)
        mixin.model_store = mock_model_store

        # Test that pre-registered model is found without calling the provider's API
        with mock_client_context(mixin, mock_client_with_models):
            mock_client_with_models.models.list.assert_not_called()
            assert await mixin.check_model_availability("pre-registered-model")

            # Should not call the provider's list_models since model was found in store
            mock_client_with_models.models.list.assert_not_called()
            mock_model_store.has_model.assert_called_once_with("pre-registered-model")

    async def test_check_model_availability_fallback_to_provider_when_not_in_store(
        self, mixin, mock_client_with_models, mock_client_context
    ):
        """Test that check_model_availability falls back to provider when model not in store"""
        # Mock model_store.has_model to return False
        mock_model_store = AsyncMock()
        mock_model_store.has_model = AsyncMock(return_value=False)
        mixin.model_store = mock_model_store

        # Test that it falls back to provider's model cache
        with mock_client_context(mixin, mock_client_with_models):
            mock_client_with_models.models.list.assert_not_called()
            assert await mixin.check_model_availability("some-mock-model-id")

            # Should call the provider's list_models since model was not found in store
            mock_client_with_models.models.list.assert_called_once()
            mock_model_store.has_model.assert_called_once_with("some-mock-model-id")


class TestOpenAIMixinCacheBehavior:
    """Test cases for cache behavior and edge cases"""
@@ -466,10 +501,16 @@ class TestOpenAIMixinModelRegistration:
        assert result is None

    async def test_should_refresh_models(self, mixin):
        """Test should_refresh_models method (should always return False)"""
        """Test should_refresh_models method returns config value"""
        # Default config has refresh_models=False
        result = await mixin.should_refresh_models()
        assert result is False

        config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True)
        mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh)
        result_with_refresh = await mixin_with_refresh.should_refresh_models()
        assert result_with_refresh is True

    async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
        """Test that errors from provider API are properly propagated during registration"""
        model = Model(
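The updated expectation here (together with the deletion of the vLLM-specific should_refresh_models test earlier in this commit) reduces the method to reading the flag from the provider config, independent of api_token. Roughly, as the test asserts (a sketch, not necessarily the exact mixin code):

async def should_refresh_models(self) -> bool:
    # No longer hard-coded to False; RemoteInferenceProviderConfig.refresh_models decides.
    return self.config.refresh_models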


@@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):

@pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
    config = SQLiteVectorIOConfig(
        db_path=sqlite_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
        kvstore=unique_kvstore_config,
    )
    adapter = SQLiteVecVectorIOAdapter(
        config=config,

@@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):

@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
    config = MilvusVectorIOConfig(
        db_path=milvus_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
        kvstore=unique_kvstore_config,
    )
    adapter = MilvusVectorIOAdapter(
        config=config,

@@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):

@pytest.fixture
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
    config = ChromaVectorIOConfig(
        db_path=chroma_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
        kvstore=unique_kvstore_config,
    )
    adapter = ChromaVectorIOAdapter(
        config=config,

@@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory):

@pytest.fixture
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
    import uuid

    config = QdrantVectorIOConfig(
        db_path=qdrant_vec_db_path,
        kvstore=SqliteKVStoreConfig(),
        kvstore=unique_kvstore_config,
    )
    adapter = QdrantVectorIOAdapter(
        config=config,

@@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):

@pytest.fixture
async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
    config = PGVectorVectorIOConfig(
        host="localhost",
        port=5432,
        db="test_db",
        user="test_user",
        password="test_password",
        kvstore=SqliteKVStoreConfig(),
        kvstore=unique_kvstore_config,
    )
    adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)

@@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path):

@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
    import pytest_socket
    import weaviate

@@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
    config = WeaviateVectorIOConfig(
        weaviate_cluster_url="localhost:8080",
        weaviate_api_key=None,
        kvstore=SqliteKVStoreConfig(),
        kvstore=unique_kvstore_config,
    )
    adapter = WeaviateVectorIOAdapter(
        config=config,
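The unique_kvstore_config fixture these signatures now request is not shown in this diff. Given that every adapter previously shared a default SqliteKVStoreConfig(), a plausible shape is a per-test SQLite kvstore; this is an illustrative sketch only (the fixture name comes from the diff, the body and import path are assumptions):

import uuid

import pytest

from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


@pytest.fixture
def unique_kvstore_config(tmp_path_factory):
    # Give every test its own kvstore database so adapters do not share registry state.
    db_path = tmp_path_factory.getbasetemp() / f"kvstore_{uuid.uuid4().hex}.db"
    return SqliteKVStoreConfig(db_path=str(db_path))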


@@ -125,8 +125,15 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
        provider_resource_id="test_vector_db_2",
        provider_id="baz",  # Same provider_id
    )
    await cached_disk_dist_registry.register(duplicate_vector_db)

    # Now we expect a ValueError to be raised for duplicate registration
    with pytest.raises(
        ValueError,
        match=r"Provider 'baz' is already registered.*Unregister the existing provider first before registering it again.",
    ):
        await cached_disk_dist_registry.register(duplicate_vector_db)

    # Verify the original registration is still intact
    result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
    assert result is not None
    assert result.embedding_model == original_vector_db.embedding_model  # Original values preserved
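The test now relies on register() rejecting a second registration for an already-registered provider instead of silently overwriting it, while leaving the original entry intact. A minimal sketch of such a guard, assuming a get()-style lookup on the registry (method internals here are illustrative, not the actual DiskDistributionRegistry code):

async def register(self, obj) -> bool:
    existing = await self.get(obj.type, obj.identifier)
    if existing is not None and existing.provider_id == obj.provider_id:
        raise ValueError(
            f"Provider '{obj.provider_id}' is already registered for {obj.type} '{obj.identifier}'. "
            "Unregister the existing provider first before registering it again."
        )
    # Only reached for new objects, so the original registration is never clobbered.
    await self._store(obj)
    return True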