Merge branch 'main' into rag-metadata-support

This commit is contained in:
Francisco Arceo 2025-05-12 10:10:28 -06:00 committed by GitHub
commit 2e70782e63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 33 additions and 8 deletions

View file

@ -7,7 +7,7 @@
[![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
[![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
### ✨🎉 Llama 4 Support 🎉✨
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.

View file

@ -415,6 +415,7 @@ class Agents(Protocol):
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@ -606,3 +607,4 @@ class Agents(Protocol):
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
"""
...

View file

@ -95,6 +95,7 @@ class Eval(Protocol):
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
@ -112,6 +113,7 @@ class Eval(Protocol):
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
@ -140,3 +142,4 @@ class Eval(Protocol):
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""
...

View file

@ -99,7 +99,7 @@ class ProviderImpl(Providers):
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except asyncio.TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
return (
api_name,
HealthResponse(

View file

@ -630,7 +630,7 @@ class InferenceRouter(Inference):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except asyncio.TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",

View file

@ -114,7 +114,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=400, detail=str(exc))
elif isinstance(exc, PermissionError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, TimeoutError):
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
@ -139,7 +139,7 @@ async def shutdown(app):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
except (asyncio.TimeoutError, TimeoutError):
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})

View file

@ -106,7 +106,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
if not vector_db_ids:
return RAGQueryResult(content=None)
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(

View file

@ -26,8 +26,7 @@ from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = logging.getLogger(__name__)
ChromaClientType = chromadb.AsyncHttpClient | chromadb.PersistentClient
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
# this is a helper to allow us to use async and non-async chroma clients interchangeably

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock
import pytest
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
class TestRagQuery:
@pytest.mark.asyncio
async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])