diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 98b495de2..c81f9b33d 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -8069,9 +8069,6 @@
                     }
                 },
                 "additionalProperties": false,
-                "required": [
-                    "content"
-                ],
                 "title": "ToolInvocationResult"
             },
             "IterrowsResponse": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 321dfe8e0..8ea0e1b9c 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -5529,8 +5529,6 @@ components:
             - type: array
             - type: object
       additionalProperties: false
-      required:
-        - content
       title: ToolInvocationResult
     IterrowsResponse:
       type: object
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index a4d84edbe..e0744a75e 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -69,7 +69,7 @@ class ToolGroup(Resource):
 
 @json_schema_type
 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     error_message: Optional[str] = None
     error_code: Optional[int] = None
     metadata: Optional[Dict[str, Any]] = None
@@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
-    tool_store: ToolStore
+    tool_store: ToolStore | None = None
 
-    rag_tool: RAGToolRuntime
+    rag_tool: RAGToolRuntime | None = None
 
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 2bbb3bce8..ab0a4a20a 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class VectorIO(Protocol):
-    vector_db_store: VectorDBStore
+    vector_db_store: VectorDBStore | None = None
 
     # this will just block now until chunks are inserted, but it should
     # probably return a Job instance which can be polled for completion
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
index 2905e2f6a..23468c5d0 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
@@ -6,12 +6,14 @@
 
 from typing import Any, Dict
 
+from llama_stack.distribution.datatypes import Api
+
 from .config import TelemetryConfig, TelemetrySink
 
 __all__ = ["TelemetryConfig", "TelemetrySink"]
 
 
-async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
     from .telemetry import TelemetryAdapter
 
     impl = TelemetryAdapter(config, deps)
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py
index 42b538876..b909d32ef 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py
@@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
         """Shutdown the processor."""
         pass
 
-    def force_flush(self, timeout_millis: float = None) -> bool:
+    def force_flush(self, timeout_millis: float | None = None) -> bool:
         """Force flush any pending spans."""
         return True
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index 4cdb420b2..766bc0fc0 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
 
 from .config import TelemetryConfig, TelemetrySink
 
-_GLOBAL_STORAGE = {
+_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "active_spans": {},
     "counters": {},
     "gauges": {},
@@ -70,7 +70,7 @@ def is_tracing_enabled(tracer):
 
 
 class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
-    def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
+    def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
         self.config = config
         self.datasetio_api = deps.get(Api.datasetio)
         self.meter = None
@@ -146,7 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
                     "message": event.message,
                     "severity": event.severity.value,
                     "__ttl__": ttl_seconds,
-                    **event.attributes,
+                    **(event.attributes or {}),
                 },
                 timestamp=timestamp_ns,
             )
@@ -154,6 +154,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
 
     def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["counters"]:
             _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                 name=name,
@@ -163,6 +164,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         return _GLOBAL_STORAGE["counters"][name]
 
     def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["gauges"]:
             _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                 name=name,
@@ -182,6 +184,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         up_down_counter.add(event.value, attributes=event.attributes)
 
     def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["up_down_counters"]:
             _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
                 name=name,
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py
index 1850d69f7..9c5f642ea 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py
@@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
     )
 
 
-_subprocess.Popen = popen_not_allowed
+_subprocess.Popen = popen_not_allowed  # type: ignore
 
 
 import atexit as _atexit
@@ -104,7 +104,7 @@ def _open_connections():
     return _NETWORK_CONNECTIONS
 
 
-_builtins._open_connections = _open_connections
+_builtins._open_connections = _open_connections  # type: ignore
 
 
 @_atexit.register
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py
index 810591c1c..6106cf741 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py
@@ -161,9 +161,9 @@ _set_seeds()\
 def process_matplotlib_response(response, matplotlib_dump_dir: str):
     image_data = response["image_data"]
     # Convert the base64 string to a bytes object
-    images = [base64.b64decode(d["image_base64"]) for d in image_data]
+    images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
     # Create a list of PIL images from the bytes objects
-    images = [Image.open(BytesIO(img)) for img in images]
+    images = [Image.open(BytesIO(img)) for img in images_raw]
     # Create a list of image paths
     image_paths = []
     for i, img in enumerate(images):
diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py
index 15118c9df..0ef3c35e9 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py
@@ -11,7 +11,7 @@ from llama_stack.providers.datatypes import Api
 from .config import RagToolRuntimeConfig
 
 
-async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
     from .memory import MemoryToolRuntimeImpl
 
     impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 4b3f7d9e7..8dd846c6f 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -15,6 +15,7 @@ from pydantic import TypeAdapter
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
 from llama_stack.apis.inference import Inference
@@ -23,6 +24,7 @@ from llama_stack.apis.tools import (
     RAGQueryConfig,
     RAGQueryResult,
     RAGToolRuntime,
+    Tool,
     ToolDef,
     ToolInvocationResult,
     ToolParameter,
@@ -62,6 +64,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def shutdown(self):
         pass
 
+    async def register_tool(self, tool: Tool) -> None:
+        pass
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        return
+
     async def insert(
         self,
         documents: List[RAGDocument],
@@ -121,11 +129,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             return RAGQueryResult(content=None)
 
         # sort by score
-        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
+        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
         chunks = chunks[: query_config.max_chunks]
 
         tokens = 0
-        picked = [
+        picked: list[InterleavedContentItem] = [
             TextContentItem(
                 text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
             )
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 0c8718cb8..20c795650 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -15,11 +15,13 @@ import faiss
 import numpy as np
 from numpy.typing import NDArray
 
-from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -35,16 +37,14 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
 
 
 class FaissIndex(EmbeddingIndex):
-    chunk_by_index: Dict[int, str]
-
-    def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
+    def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         self.index = faiss.IndexFlatL2(dimension)
-        self.chunk_by_index = {}
+        self.chunk_by_index: dict[int, Chunk] = {}
         self.kvstore = kvstore
         self.bank_id = bank_id
 
     @classmethod
-    async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
+    async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         instance = cls(dimension, kvstore, bank_id)
         await instance.initialize()
         return instance
@@ -114,11 +114,11 @@ class FaissIndex(EmbeddingIndex):
 
 
 class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
-        self.cache = {}
-        self.kvstore = None
+        self.cache: dict[str, VectorDBWithIndex] = {}
+        self.kvstore: KVStore | None = None
 
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -144,6 +144,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self,
         vector_db: VectorDB,
     ) -> None:
+        assert self.kvstore is not None
+
         key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(
             key=key,
@@ -161,6 +163,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         return [i.vector_db for i in self.cache.values()]
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:
+        assert self.kvstore is not None
+
         if vector_db_id not in self.cache:
             logger.warning(f"Vector DB {vector_db_id} not found")
             return
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 17865c93e..b8f6f602f 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -15,9 +15,10 @@ import numpy as np
 import sqlite_vec
 from numpy.typing import NDArray
 
+from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
 
 logger = logging.getLogger(__name__)
@@ -78,6 +79,8 @@ class SQLiteVecIndex(EmbeddingIndex):
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
         If any insert fails, the transaction is rolled back to maintain consistency.
""" + assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks" + cur = self.connection.cursor() try: # Start transaction @@ -89,6 +92,7 @@ class SQLiteVecIndex(EmbeddingIndex): metadata_data = [ (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json()) for chunk in batch_chunks + if isinstance(chunk.content, str) ] # Insert metadata (ON CONFLICT to avoid duplicates) cur.executemany( @@ -103,6 +107,7 @@ class SQLiteVecIndex(EmbeddingIndex): embedding_data = [ (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist())) for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) + if isinstance(chunk.content, str) ] # Insert embeddings in batch cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data) @@ -154,7 +159,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex). """ - def __init__(self, config, inference_api: Api.inference) -> None: + def __init__(self, config, inference_api: Inference) -> None: self.config = config self.inference_api = inference_api self.cache: Dict[str, VectorDBWithIndex] = {} diff --git a/llama_stack/providers/utils/telemetry/dataset_mixin.py b/llama_stack/providers/utils/telemetry/dataset_mixin.py index 0cb695956..34c612133 100644 --- a/llama_stack/providers/utils/telemetry/dataset_mixin.py +++ b/llama_stack/providers/utils/telemetry/dataset_mixin.py @@ -13,7 +13,7 @@ from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span class TelemetryDatasetMixin: """Mixin class that provides dataset-related functionality for telemetry providers.""" - datasetio_api: DatasetIO + datasetio_api: DatasetIO | None async def save_spans_to_dataset( self, diff --git a/pyproject.toml b/pyproject.toml index 107150cee..fb42f6725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,16 +235,6 @@ exclude = [ "^llama_stack/providers/inline/scoring/basic/", "^llama_stack/providers/inline/scoring/braintrust/", "^llama_stack/providers/inline/scoring/llm_as_judge/", - "^llama_stack/providers/inline/telemetry/meta_reference/console_span_processor\\.py$", - "^llama_stack/providers/inline/telemetry/meta_reference/telemetry\\.py$", - "^llama_stack/providers/inline/telemetry/sample/", - "^llama_stack/providers/inline/tool_runtime/code_interpreter/", - "^llama_stack/providers/inline/tool_runtime/rag/", - "^llama_stack/providers/inline/vector_io/chroma/", - "^llama_stack/providers/inline/vector_io/faiss/", - "^llama_stack/providers/inline/vector_io/milvus/", - "^llama_stack/providers/inline/vector_io/qdrant/", - "^llama_stack/providers/inline/vector_io/sqlite_vec/", "^llama_stack/providers/remote/agents/sample/", "^llama_stack/providers/remote/datasetio/huggingface/", "^llama_stack/providers/remote/inference/anthropic/",