diff --git a/README.md b/README.md
index b2b2d12d9..5dfe3577a 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
 [![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
 
-[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
+[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
 
 ### ✨🎉 Llama 4 Support 🎉✨
 We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 91e57dbbe..f4367d09b 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -415,6 +415,7 @@ class Agents(Protocol):
 
         :returns: If stream=False, returns a Turn object. If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
         """
+        ...
 
     @webmethod(
         route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@@ -606,3 +607,4 @@ class Agents(Protocol):
         :param model: The underlying LLM used for completions.
         :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
         """
+        ...
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 23ca89a94..38699d3f5 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -95,6 +95,7 @@ class Eval(Protocol):
         :param benchmark_config: The configuration for the benchmark.
         :return: The job that was created to run the evaluation.
         """
+        ...
 
     @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
     async def evaluate_rows(
@@ -112,6 +113,7 @@ class Eval(Protocol):
         :param benchmark_config: The configuration for the benchmark.
         :return: EvaluateResponse object containing generations and scores
         """
+        ...
 
     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
     async def job_status(self, benchmark_id: str, job_id: str) -> Job:
@@ -140,3 +142,4 @@ class Eval(Protocol):
         :param job_id: The ID of the job to get the result of.
         :return: The result of the job.
         """
+        ...
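The `+        ...` additions above follow the usual `typing.Protocol` convention: a protocol method is a declaration, not an implementation, and an explicit `Ellipsis` body makes that intent unambiguous to readers and tooling instead of leaving the docstring as the method's only statement. A minimal sketch of the pattern, using a hypothetical `Greeter` protocol rather than anything from this repo:

```python
from typing import Protocol


class Greeter(Protocol):
    def greet(self, name: str) -> str:
        """Return a greeting for `name`.

        The trailing `...` marks this as a declaration-only stub; concrete
        classes that structurally match the protocol supply the real body.
        """
        ...


class EnglishGreeter:
    def greet(self, name: str) -> str:
        return f"Hello, {name}!"


def welcome(greeter: Greeter, name: str) -> str:
    # EnglishGreeter satisfies Greeter structurally; no inheritance needed.
    return greeter.greet(name)


print(welcome(EnglishGreeter(), "Ada"))
```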
diff --git a/llama_stack/distribution/providers.py b/llama_stack/distribution/providers.py
index 157fd1e0e..29b7109dd 100644
--- a/llama_stack/distribution/providers.py
+++ b/llama_stack/distribution/providers.py
@@ -99,7 +99,7 @@ class ProviderImpl(Providers):
             try:
                 health = await asyncio.wait_for(impl.health(), timeout=timeout)
                 return api_name, health
-            except asyncio.TimeoutError:
+            except (asyncio.TimeoutError, TimeoutError):
                 return (
                     api_name,
                     HealthResponse(
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 7e80a067f..371d34904 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -630,7 +630,7 @@ class InferenceRouter(Inference):
                     continue
                 health = await asyncio.wait_for(impl.health(), timeout=timeout)
                 health_statuses[provider_id] = health
-            except asyncio.TimeoutError:
+            except (asyncio.TimeoutError, TimeoutError):
                 health_statuses[provider_id] = HealthResponse(
                     status=HealthStatus.ERROR,
                     message=f"Health check timed out after {timeout} seconds",
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 85c576753..e34a62b00 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -114,7 +114,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
         return HTTPException(status_code=400, detail=str(exc))
     elif isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
-    elif isinstance(exc, TimeoutError):
+    elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
@@ -139,7 +139,7 @@ async def shutdown(app):
                 await asyncio.wait_for(impl.shutdown(), timeout=5)
             else:
                 logger.warning("No shutdown method for %s", impl_name)
-        except asyncio.TimeoutError:
+        except (asyncio.TimeoutError, TimeoutError):
            logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
         except (Exception, asyncio.CancelledError) as e:
            logger.exception("Failed to shutdown %s: %s", impl_name, {e})
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index c290e89ec..b67bfcf4a 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -106,7 +106,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         query_config: RAGQueryConfig | None = None,
     ) -> RAGQueryResult:
         if not vector_db_ids:
-            return RAGQueryResult(content=None)
+            raise ValueError(
+                "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
+            )
 
         query_config = query_config or RAGQueryConfig()
         query = await generate_rag_query(
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 5381a48ef..a919963ab 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -26,8 +26,7 @@ from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
 
 log = logging.getLogger(__name__)
 
-
-ChromaClientType = chromadb.AsyncHttpClient | chromadb.PersistentClient
+ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
 
 
 # this is a helper to allow us to use async and non-async chroma clients interchangeably
diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py
new file mode 100644
index 000000000..b9fd8cca4
--- /dev/null
+++ b/tests/unit/rag/test_rag_query.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
+
+
+class TestRagQuery:
+    @pytest.mark.asyncio
+    async def test_query_raises_on_empty_vector_db_ids(self):
+        rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
+        with pytest.raises(ValueError):
+            await rag_tool.query(content=MagicMock(), vector_db_ids=[])
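Two version notes on the recurring timeout change, hedged since the diff itself does not state them: on Python < 3.11, `asyncio.wait_for` raises `asyncio.TimeoutError`, which is a different class from the builtin `TimeoutError`; from 3.11 onward the two names refer to the same class, so `except (asyncio.TimeoutError, TimeoutError)` is the spelling that behaves identically on every supported version. The `asyncio.TimeoutError | TimeoutError` union in `translate_exception` additionally relies on PEP 604 unions inside `isinstance`, which requires Python 3.10+. A self-contained sketch of the pattern (the `slow_health_check` name is illustrative, not from this repo):

```python
import asyncio


async def slow_health_check() -> str:
    # Stand-in for a provider's health() call that takes too long.
    await asyncio.sleep(10)
    return "OK"


async def main() -> None:
    try:
        await asyncio.wait_for(slow_health_check(), timeout=0.1)
    except (asyncio.TimeoutError, TimeoutError):
        # On < 3.11 the first entry matches; on >= 3.11 both names are the
        # same class, so the tuple is harmlessly redundant.
        print("health check timed out")


asyncio.run(main())
```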