diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py
index 69d7e9b6f..716be936a 100644
--- a/llama_stack/core/routing_tables/models.py
+++ b/llama_stack/core/routing_tables/models.py
@@ -67,6 +67,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             raise ValueError(f"Provider {model.provider_id} not found in the routing table")
         return self.impls_by_provider_id[model.provider_id]
 
+    async def has_model(self, model_id: str) -> bool:
+        """
+        Check if a model exists in the routing table.
+
+        :param model_id: The model identifier to check
+        :return: True if the model exists, False otherwise
+        """
+        try:
+            await lookup_model(self, model_id)
+            return True
+        except ModelNotFoundError:
+            return False
+
     async def register_model(
         self,
         model_id: str,
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 0bb524f5c..8a662e6db 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -97,6 +97,8 @@ class StreamingResponseOrchestrator:
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
         # Track final messages after all tool executions
         self.final_messages: list[OpenAIMessageParam] = []
+        # Mapping of file_id -> filename used to build citation annotations
+        self.citation_files: dict[str, str] = {}
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
         # Initialize output messages
@@ -126,6 +128,7 @@ class StreamingResponseOrchestrator:
             # Text is the default response format for chat completion so don't need to pass it
             # (some providers don't support non-empty response_format when tools are present)
             response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
+            logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
             completion_result = await self.inference_api.openai_chat_completion(
                 model=self.ctx.model,
                 messages=messages,
@@ -160,7 +163,7 @@ class StreamingResponseOrchestrator:
             # Handle choices with no tool calls
             for choice in current_response.choices:
                 if not (choice.message.tool_calls and self.ctx.response_tools):
-                    output_messages.append(await convert_chat_choice_to_response_message(choice))
+                    output_messages.append(await convert_chat_choice_to_response_message(choice, self.citation_files))
 
             # Execute tool calls and coordinate results
             async for stream_event in self._coordinate_tool_execution(
@@ -211,6 +214,8 @@ class StreamingResponseOrchestrator:
 
         for choice in current_response.choices:
             next_turn_messages.append(choice.message)
+            logger.debug(f"Choice message content: {choice.message.content}")
+            logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
 
             if choice.message.tool_calls and self.ctx.response_tools:
                 for tool_call in choice.message.tool_calls:
@@ -470,6 +475,8 @@ class StreamingResponseOrchestrator:
             tool_call_log = result.final_output_message
             tool_response_message = result.final_input_message
             self.sequence_number = result.sequence_number
+            if result.citation_files:
+                self.citation_files.update(result.citation_files)
 
             if tool_call_log:
                 output_messages.append(tool_call_log)
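The new `has_model` wraps `lookup_model` in a non-throwing existence check, which is what `check_model_availability` (later in this diff) needs. A minimal usage sketch; the `resolve_model` helper and its candidate list are hypothetical, not part of this PR:

```python
# Hypothetical helper: pick the first candidate that is registered.
async def resolve_model(table: "ModelsRoutingTable", candidates: list[str]) -> str | None:
    for model_id in candidates:
        # has_model never raises; lookup_model would throw ModelNotFoundError
        if await table.has_model(model_id):
            return model_id
    return None
```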
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
index b028c018b..b33b47454 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py
@@ -94,7 +94,10 @@ class ToolExecutor:
 
         # Yield the final result
         yield ToolExecutionResult(
-            sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
+            sequence_number=sequence_number,
+            final_output_message=output_message,
+            final_input_message=input_message,
+            citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )
 
     async def _execute_knowledge_search_via_vector_store(
@@ -129,8 +132,6 @@ class ToolExecutor:
         for results in all_results:
             search_results.extend(results)
 
-        # Convert search results to tool result format matching memory.py
-        # Format the results as interleaved content similar to memory.py
         content_items = []
         content_items.append(
             TextContentItem(
@@ -138,27 +139,58 @@ class ToolExecutor:
             )
         )
 
+        unique_files = set()
         for i, result_item in enumerate(search_results):
             chunk_text = result_item.content[0].text if result_item.content else ""
-            metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
+            # Get file_id from attributes if result_item.file_id is empty
+            file_id = result_item.file_id or (
+                result_item.attributes.get("document_id") if result_item.attributes else None
+            )
+            metadata_text = f"document_id: {file_id}, score: {result_item.score}"
             if result_item.attributes:
                 metadata_text += f", attributes: {result_item.attributes}"
-            text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
+
+            text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
             content_items.append(TextContentItem(text=text_content))
+            unique_files.add(file_id)
 
         content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+
+        citation_instruction = ""
+        if unique_files:
+            citation_instruction = (
+                " Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
+                "Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
+            )
+
         content_items.append(
             TextContentItem(
-                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
+                text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
            )
         )
 
+        # Handle missing attributes on results written by older versions
+        citation_files = {}
+        for result in search_results:
+            file_id = result.file_id
+            if not file_id and result.attributes:
+                file_id = result.attributes.get("document_id")
+
+            filename = result.filename
+            if not filename and result.attributes:
+                filename = result.attributes.get("filename")
+            if not filename:
+                filename = "unknown"
+
+            citation_files[file_id] = filename
+
         return ToolInvocationResult(
             content=content_items,
             metadata={
                 "document_ids": [r.file_id for r in search_results],
                 "chunks": [r.content[0].text if r.content else "" for r in search_results],
                 "scores": [r.score for r in search_results],
+                "citation_files": citation_files,
             },
         )
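For context, the loop above renders each retrieved chunk with an explicit cite marker so the model can echo it back verbatim. A standalone sketch of roughly what one formatted entry looks like, with an invented file ID and score:

```python
# Illustration only: reproduces the text layout built above for a single result.
file_id = "file-abc123"  # hypothetical ID
chunk_text = "Paris is the capital of France."
metadata_text = f"document_id: {file_id}, score: 0.87"
print(f"[1] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n")
# [1] document_id: file-abc123, score: 0.87 (cite as <|file-abc123|>)
# Paris is the capital of France.
```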
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
index d3b5a16bd..fd5f44242 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/types.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py
@@ -27,6 +27,7 @@ class ToolExecutionResult(BaseModel):
     sequence_number: int
     final_output_message: OpenAIResponseOutput | None = None
     final_input_message: OpenAIMessageParam | None = None
+    citation_files: dict[str, str] | None = None
 
 
 @dataclass
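Because `ToolExecutionResult` is a Pydantic model, the new `citation_files` field defaults to `None`, so existing constructors keep working unchanged. A minimal sketch with invented values:

```python
from llama_stack.providers.inline.agents.meta_reference.responses.types import (
    ToolExecutionResult,
)

# Old call sites: the field is absent and defaults to None.
assert ToolExecutionResult(sequence_number=7).citation_files is None

# New call sites thread the file_id -> filename mapping through to the orchestrator.
result = ToolExecutionResult(
    sequence_number=8,
    citation_files={"file-abc123": "doc1.pdf"},
)
```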
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 310a88298..5b013b9c4 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import re
 import uuid
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContent,
@@ -45,7 +47,9 @@ from llama_stack.apis.inference import (
 )
 
 
-async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+async def convert_chat_choice_to_response_message(
+    choice: OpenAIChoice, citation_files: dict[str, str] | None = None
+) -> OpenAIResponseMessage:
     """Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
     output_content = ""
     if isinstance(choice.message.content, str):
@@ -57,9 +61,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
             f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
         )
 
+    annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
+
     return OpenAIResponseMessage(
         id=f"msg_{uuid.uuid4()}",
-        content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+        content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
         status="completed",
         role="assistant",
     )
@@ -200,6 +206,53 @@ async def get_message_type_by_role(role: str):
     return role_to_type.get(role)
 
 
+def _extract_citations_from_text(
+    text: str, citation_files: dict[str, str]
+) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
+    """Extract citation markers from text and create annotations
+
+    Args:
+        text: The text containing citation markers like <|file-Cn3MSNn72ENTiiq11Qda4A|>
+        citation_files: Dictionary mapping file_id to filename
+
+    Returns:
+        Tuple of (annotations_list, clean_text_without_markers)
+    """
+    file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
+
+    annotations = []
+    parts = []
+    total_len = 0
+    last_end = 0
+
+    for m in file_id_regex.finditer(text):
+        # Segment before the marker
+        prefix = text[last_end : m.start()]
+
+        # Drop one trailing space if present (the marker sits at the end of a sentence)
+        if prefix.endswith(" "):
+            prefix = prefix[:-1]
+
+        parts.append(prefix)
+        total_len += len(prefix)
+
+        fid = m.group(1)
+        if fid in citation_files:
+            annotations.append(
+                OpenAIResponseAnnotationFileCitation(
+                    file_id=fid,
+                    filename=citation_files[fid],
+                    index=total_len,  # index points at the punctuation after the marker
+                )
+            )
+
+        last_end = m.end()
+
+    parts.append(text[last_end:])
+    cleaned_text = "".join(parts)
+    return annotations, cleaned_text
+
+
 def is_function_tool_call(
     tool_call: OpenAIChatCompletionToolCall,
     tools: list[OpenAIResponseInputTool],
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index c8499a9b8..aac86a056 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -331,5 +331,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
 
         return ToolInvocationResult(
             content=result.content or [],
-            metadata=result.metadata,
+            metadata={
+                **(result.metadata or {}),
+                "citation_files": getattr(result, "citation_files", None),
+            },
         )
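To make the index semantics of the extraction helper above concrete: the marker and one preceding space are removed, so each annotation's `index` lands on the character right after the citation, normally the sentence's punctuation. A small worked example (file ID borrowed from the unit test later in this diff):

```python
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    _extract_citations_from_text,
)

text = "New source <|file-abc123|>."
annotations, clean = _extract_citations_from_text(text, {"file-abc123": "doc1.pdf"})

assert clean == "New source."
assert annotations[0].index == 10  # len("New source"), the "." in the cleaned text
assert clean[annotations[0].index] == "."
```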
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 405c134e5..5a456c7c9 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -225,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        # Cleanup if needed
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def health(self) -> HealthResponse:
         """
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 26231a9b7..a433257b2 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -434,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        # nothing to do since we don't maintain a persistent connection
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def list_vector_dbs(self) -> list[VectorDB]:
         return [v.vector_db for v in self.cache.values()]
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 511123d6e..331e5432e 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -167,7 +167,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self.openai_vector_stores = await self._load_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        pass
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 0acc90595..029eacfe3 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -349,6 +349,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
     async def shutdown(self) -> None:
         self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index dfdfef6eb..21c388b1d 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -390,6 +390,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         if self.conn is not None:
             self.conn.close()
             log.info("Connection to PGVector database server closed")
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(self, vector_db: VectorDB) -> None:
         # Persist vector DB metadata in the KV store
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index 6b386840c..021938afd 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -191,6 +191,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
     async def shutdown(self) -> None:
         await self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
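All of these `shutdown` overrides follow the same cooperative pattern: close provider-specific resources first, then delegate via `super().shutdown()` so the MRO reaches `OpenAIVectorStoreMixin.shutdown` (shown later in this diff). A generic, self-contained sketch of the pattern, with placeholder classes:

```python
import asyncio


class Mixin:
    async def shutdown(self) -> None:
        print("mixin cleanup (cancel background tasks)")


class Adapter(Mixin):
    async def shutdown(self) -> None:
        print("close client connection")  # provider-specific teardown first
        await super().shutdown()          # then the shared mixin cleanup


asyncio.run(Adapter().shutdown())
# close client connection
# mixin cleanup (cancel background tasks)
```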
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index 54ac6f8d3..21df3bc45 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -347,6 +347,8 @@ class WeaviateVectorIOAdapter(
     async def shutdown(self) -> None:
         for client in self.client_cache.values():
             client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
 
     async def register_vector_db(
         self,
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 3c5c5b4de..cba7508a2 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -474,11 +474,17 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
 
     async def check_model_availability(self, model: str) -> bool:
         """
-        Check if a specific model is available from the provider's /v1/models.
+        Check if a specific model is available from the provider's /v1/models or pre-registered.
 
         :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
+        :return: True if the model is available dynamically or pre-registered, False otherwise.
         """
+        # First check if the model is pre-registered in the model store
+        if hasattr(self, "model_store") and self.model_store:
+            if await self.model_store.has_model(model):
+                return True
+
+        # Then check the provider's dynamic model cache
         if not self._model_cache:
             await self.list_models()
         return model in self._model_cache
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index 0d0aa25a4..2a5177f93 100644
--- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -293,6 +293,19 @@ class OpenAIVectorStoreMixin(ABC):
             await self._resume_incomplete_batches()
         self._last_file_batch_cleanup_time = 0
 
+    async def shutdown(self) -> None:
+        """Clean up mixin resources including background tasks."""
+        # Cancel any running file batch tasks gracefully
+        if hasattr(self, "_file_batch_tasks"):
+            tasks_to_cancel = list(self._file_batch_tasks.items())
+            for _, task in tasks_to_cancel:
+                if not task.done():
+                    task.cancel()
+                    try:
+                        await task
+                    except asyncio.CancelledError:
+                        pass
+
     @abstractmethod
     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
         """Delete chunks from a vector store."""
@@ -587,7 +600,7 @@ class OpenAIVectorStoreMixin(ABC):
                 content = self._chunk_to_vector_store_content(chunk)
 
                 response_data_item = VectorStoreSearchResponse(
-                    file_id=chunk.metadata.get("file_id", ""),
+                    file_id=chunk.metadata.get("document_id", ""),
                     filename=chunk.metadata.get("filename", ""),
                     score=score,
                     attributes=chunk.metadata,
@@ -746,12 +759,15 @@ class OpenAIVectorStoreMixin(ABC):
 
             content = content_from_data_and_mime_type(content_response.body, mime_type)
 
+            chunk_attributes = attributes.copy()
+            chunk_attributes["filename"] = file_response.filename
+
             chunks = make_overlapped_chunks(
                 file_id,
                 content,
                 max_chunk_size_tokens,
                 chunk_overlap_tokens,
-                attributes,
+                chunk_attributes,
             )
             if not chunks:
                 vector_store_file_object.status = "failed"
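The mixin `shutdown` above uses the standard asyncio idiom for stopping background work: cancel the task, then await it and swallow the resulting `CancelledError` so cancellation has actually landed before shutdown returns. A self-contained sketch of that idiom, with a stand-in worker:

```python
import asyncio


async def main() -> None:
    async def file_batch_worker() -> None:
        await asyncio.sleep(3600)  # stand-in for long-running batch processing

    task = asyncio.create_task(file_batch_worker())
    await asyncio.sleep(0)  # let the worker start

    if not task.done():
        task.cancel()
        try:
            await task  # wait for the cancellation to propagate
        except asyncio.CancelledError:
            pass  # expected: the task was cancelled on purpose


asyncio.run(main())
```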
@@ -16,10 +16,19 @@
 
 set -Eeuo pipefail
 
-CONTAINER_RUNTIME=${CONTAINER_RUNTIME:-docker}
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+if command -v podman &> /dev/null; then
+    CONTAINER_RUNTIME="podman"
+elif command -v docker &> /dev/null; then
+    CONTAINER_RUNTIME="docker"
+else
+    echo "🚨 Neither Podman nor Docker could be found"
+    echo "Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
+    exit 1
+fi
 
-echo "🚀 Setting up telemetry stack for Llama Stack using Podman..."
+echo "🚀 Setting up telemetry stack for Llama Stack using $CONTAINER_RUNTIME..."
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 
 if ! command -v "$CONTAINER_RUNTIME" &> /dev/null; then
     echo "🚨 $CONTAINER_RUNTIME could not be found"
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index 54a9dd72e..a1c3d1e95 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -201,6 +201,12 @@ async def test_models_routing_table(cached_disk_dist_registry):
     non_existent = await table.get_object_by_identifier("model", "non-existent-model")
     assert non_existent is None
 
+    # Test has_model
+    assert await table.has_model("test_provider/test-model")
+    assert await table.has_model("test_provider/test-model-2")
+    assert not await table.has_model("non-existent-model")
+    assert not await table.has_model("test_provider/non-existent-model")
+
     await table.unregister_model(model_id="test_provider/test-model")
     await table.unregister_model(model_id="test_provider/test-model-2")
 
diff --git a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
index 187540f82..2698b88c8 100644
--- a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
+++ b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py
@@ -8,6 +8,7 @@ import pytest
 
 from llama_stack.apis.agents.openai_responses import (
+    OpenAIResponseAnnotationFileCitation,
     OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
@@ -35,6 +36,7 @@ from llama_stack.apis.inference import (
     OpenAIUserMessageParam,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
+    _extract_citations_from_text,
     convert_chat_choice_to_response_message,
     convert_response_content_to_chat_content,
     convert_response_input_to_chat_messages,
@@ -340,3 +342,26 @@ class TestIsFunctionToolCall:
 
         result = is_function_tool_call(tool_call, tools)
         assert result is False
+
+
+class TestExtractCitationsFromText:
+    def test_extract_citations_and_annotations(self):
+        text = "Start [not-a-file]. New source <|file-abc123|>. "
+        text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation."
+        file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"}
+
+        annotations, cleaned_text = _extract_citations_from_text(text, file_mapping)
+
+        expected_annotations = [
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30),
+            OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44),
+            OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59),
+        ]
+        expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation."
+
+        assert cleaned_text == expected_clean_text
+        assert annotations == expected_annotations
+        # OpenAI cites at the end of the sentence
+        assert cleaned_text[expected_annotations[0].index] == "."
+        assert cleaned_text[expected_annotations[1].index] == "?"
+        assert cleaned_text[expected_annotations[2].index] == "!"
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
index 2e3a62ca6..ad9406951 100644
--- a/tests/unit/providers/utils/inference/test_openai_mixin.py
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -44,11 +44,12 @@ def mixin():
     config = RemoteInferenceProviderConfig()
     mixin_instance = OpenAIMixinImpl(config=config)
 
-    # just enough to satisfy _get_provider_model_id calls
-    mock_model_store = MagicMock()
+    # Mock model_store with async methods
+    mock_model_store = AsyncMock()
     mock_model = MagicMock()
     mock_model.provider_resource_id = "test-provider-resource-id"
     mock_model_store.get_model = AsyncMock(return_value=mock_model)
+    mock_model_store.has_model = AsyncMock(return_value=False)  # Default to False; tests can override
     mixin_instance.model_store = mock_model_store
 
     return mixin_instance
@@ -189,6 +190,40 @@ class TestOpenAIMixinCheckModelAvailability:
 
         assert len(mixin._model_cache) == 3
 
+    async def test_check_model_availability_with_pre_registered_model(
+        self, mixin, mock_client_with_models, mock_client_context
+    ):
+        """Test that check_model_availability returns True for models pre-registered in the model store"""
+        # Mock model_store.has_model to return True for a specific model
+        mock_model_store = AsyncMock()
+        mock_model_store.has_model = AsyncMock(return_value=True)
+        mixin.model_store = mock_model_store
+
+        # The pre-registered model is found without calling the provider's API
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert await mixin.check_model_availability("pre-registered-model")
+            # Should not call the provider's list_models since the model was found in the store
+            mock_client_with_models.models.list.assert_not_called()
+            mock_model_store.has_model.assert_called_once_with("pre-registered-model")
+
+    async def test_check_model_availability_fallback_to_provider_when_not_in_store(
+        self, mixin, mock_client_with_models, mock_client_context
+    ):
+        """Test that check_model_availability falls back to the provider when the model is not in the store"""
+        # Mock model_store.has_model to return False
+        mock_model_store = AsyncMock()
+        mock_model_store.has_model = AsyncMock(return_value=False)
+        mixin.model_store = mock_model_store
+
+        # Falls back to the provider's model cache
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert await mixin.check_model_availability("some-mock-model-id")
+            # Should call the provider's list_models since the model was not found in the store
+            mock_client_with_models.models.list.assert_called_once()
+            mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
+
 
 class TestOpenAIMixinCacheBehavior:
     """Test cases for cache behavior and edge cases"""
diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py
index 70ace695e..d122f9323 100644
--- a/tests/unit/providers/vector_io/conftest.py
+++ b/tests/unit/providers/vector_io/conftest.py
@@ -145,10 +145,10 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
 
 
 @pytest.fixture
-async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
+async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = SQLiteVectorIOConfig(
         db_path=sqlite_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = SQLiteVecVectorIOAdapter(
         config=config,
@@ -187,10 +187,10 @@ async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
 
 
 @pytest.fixture
-async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
+async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
     config = MilvusVectorIOConfig(
         db_path=milvus_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = MilvusVectorIOAdapter(
         config=config,
@@ -264,10 +264,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
 
 
 @pytest.fixture
-async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
+async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = ChromaVectorIOConfig(
         db_path=chroma_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = ChromaVectorIOAdapter(
         config=config,
@@ -296,12 +296,12 @@ def qdrant_vec_db_path(tmp_path_factory):
 
 
 @pytest.fixture
-async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
+async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     import uuid
 
     config = QdrantVectorIOConfig(
         db_path=qdrant_vec_db_path,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = QdrantVectorIOAdapter(
         config=config,
@@ -386,14 +386,14 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
 
 
 @pytest.fixture
-async def pgvector_vec_adapter(mock_inference_api, embedding_dimension):
+async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
     config = PGVectorVectorIOConfig(
         host="localhost",
         port=5432,
         db="test_db",
         user="test_user",
         password="test_password",
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
 
     adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@@ -476,7 +476,7 @@ async def weaviate_vec_index(weaviate_vec_db_path):
 
 
 @pytest.fixture
-async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embedding_dimension):
+async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
     import pytest_socket
     import weaviate
 
@@ -492,7 +492,7 @@ async def weaviate_vec_adapter(weaviate_vec_db_path, mock_inference_api, embeddi
     config = WeaviateVectorIOConfig(
         weaviate_cluster_url="localhost:8080",
         weaviate_api_key=None,
-        kvstore=SqliteKVStoreConfig(),
+        kvstore=unique_kvstore_config,
     )
     adapter = WeaviateVectorIOAdapter(
         config=config,
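The fixtures above now take a shared `unique_kvstore_config` instead of constructing a default `SqliteKVStoreConfig()` inline, so each test gets an isolated KV store rather than all adapters colliding on the same default database path. The fixture definition is not part of this diff; purely as an assumption, a plausible shape would be something like the following (import path and `db_path` field name assumed):

```python
import uuid

import pytest

from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


@pytest.fixture
def unique_kvstore_config(tmp_path_factory):
    # One SQLite file per test, so adapters never share KV state.
    db_path = tmp_path_factory.mktemp("kv") / f"{uuid.uuid4().hex}.db"
    return SqliteKVStoreConfig(db_path=str(db_path))
```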