diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 98b495de2..c81f9b33d 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -8069,9 +8069,6 @@
}
},
"additionalProperties": false,
- "required": [
- "content"
- ],
"title": "ToolInvocationResult"
},
"IterrowsResponse": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 321dfe8e0..8ea0e1b9c 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -5529,8 +5529,6 @@ components:
- type: array
- type: object
additionalProperties: false
- required:
- - content
title: ToolInvocationResult
IterrowsResponse:
type: object
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index a4d84edbe..e0744a75e 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -69,7 +69,7 @@ class ToolGroup(Resource):

@json_schema_type
class ToolInvocationResult(BaseModel):
- content: InterleavedContent
+ content: Optional[InterleavedContent] = None
error_message: Optional[str] = None
error_code: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
@@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
@runtime_checkable
@trace_protocol
class ToolRuntime(Protocol):
- tool_store: ToolStore
+ tool_store: ToolStore | None = None

- rag_tool: RAGToolRuntime
+ rag_tool: RAGToolRuntime | None = None

# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
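
Note: the two spec hunks above follow from this model change: once `content` has a default, the OpenAPI generator no longer lists it under `required`. A minimal runnable sketch of the relaxed model, with `str` standing in for `InterleavedContent`:

    from typing import Any, Dict, Optional

    from pydantic import BaseModel

    class ToolInvocationResult(BaseModel):
        content: Optional[str] = None  # stand-in for InterleavedContent
        error_message: Optional[str] = None
        error_code: Optional[int] = None
        metadata: Optional[Dict[str, Any]] = None

    # An error-only result now validates; a required "content" used to forbid this.
    err = ToolInvocationResult(error_message="tool failed", error_code=1)
    assert err.content is None
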
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 2bbb3bce8..ab0a4a20a 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
@runtime_checkable
@trace_protocol
class VectorIO(Protocol):
- vector_db_store: VectorDBStore
+ vector_db_store: VectorDBStore | None = None

# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion
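
Note: a sketch of the protocol-attribute pattern used in the two hunks above, assuming (as in this codebase) that these attributes are wired up after construction; `Store`, `Service`, and `Impl` are illustrative names:

    from typing import Protocol, runtime_checkable

    class Store: ...

    @runtime_checkable
    class Service(Protocol):
        # Optional with a default: a concrete class that fills the store
        # in later still satisfies the protocol for mypy and isinstance().
        store: Store | None = None

    class Impl:
        def __init__(self) -> None:
            self.store: Store | None = None  # populated by a resolver later

    svc: Service = Impl()
    assert isinstance(svc, Service)
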
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
index 2905e2f6a..23468c5d0 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
@@ -6,12 +6,14 @@

from typing import Any, Dict

+from llama_stack.distribution.datatypes import Api
+
from .config import TelemetryConfig, TelemetrySink

__all__ = ["TelemetryConfig", "TelemetrySink"]


-async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
from .telemetry import TelemetryAdapter

impl = TelemetryAdapter(config, deps)
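
Note: keying `deps` by the `Api` enum rather than `str` is what lets lookups like `deps.get(Api.datasetio)` (used in `TelemetryAdapter.__init__` below) type-check. A small sketch with illustrative stand-ins, not the real resolver:

    from enum import Enum
    from typing import Any, Dict

    class Api(Enum):
        datasetio = "datasetio"

    def get_provider_impl(config: Any, deps: Dict[Api, Any]) -> Any:
        # The enum key matches the declared key type; a str key would not.
        return deps.get(Api.datasetio)

    impl = get_provider_impl(None, {Api.datasetio: object()})
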
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py
index 42b538876..b909d32ef 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py
@@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
"""Shutdown the processor."""
pass

- def force_flush(self, timeout_millis: float = None) -> bool:
+ def force_flush(self, timeout_millis: float | None = None) -> bool:
"""Force flush any pending spans."""
return True
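
Note: this is the PEP 484 implicit-Optional cleanup: a default of `None` does not make the annotation Optional, and mypy rejects `float = None` by default since 0.990 (`no_implicit_optional`). Sketch:

    def force_flush(timeout_millis: float | None = None) -> bool:
        # was: timeout_millis: float = None, which mypy now flags
        return timeout_millis is None or timeout_millis >= 0

    assert force_flush() and force_flush(500.0)
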
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index 4cdb420b2..766bc0fc0 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace

from .config import TelemetryConfig, TelemetrySink

-_GLOBAL_STORAGE = {
+_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {},
"counters": {},
"gauges": {},
@@ -70,7 +70,7 @@ def is_tracing_enabled(tracer):


class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
- def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
+ def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
@@ -146,7 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
- **event.attributes,
+ **(event.attributes or {}),
},
timestamp=timestamp_ns,
)
@@ -154,6 +154,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
+ assert self.meter is not None
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
@@ -163,6 +164,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["counters"][name]

def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
+ assert self.meter is not None
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
@@ -182,6 +184,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
up_down_counter.add(event.value, attributes=event.attributes)

def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
+ assert self.meter is not None
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,
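
Note: the three `assert self.meter is not None` guards are for type narrowing: `meter` starts as `None` in `__init__`, so its declared type is Optional and mypy would reject `self.meter.create_counter(...)` without a guard. A minimal sketch of the pattern, with an illustrative `Meter`:

    from typing import Optional

    class Meter:
        def create_counter(self, name: str) -> str:
            return f"counter:{name}"

    class Adapter:
        def __init__(self) -> None:
            self.meter: Optional[Meter] = None  # wired up after construction

        def _get_or_create_counter(self, name: str) -> str:
            assert self.meter is not None  # narrows Optional[Meter] to Meter
            return self.meter.create_counter(name)

    a = Adapter()
    a.meter = Meter()
    print(a._get_or_create_counter("requests"))
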
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py
index 1850d69f7..9c5f642ea 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py
@@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
)


-_subprocess.Popen = popen_not_allowed
+_subprocess.Popen = popen_not_allowed # type: ignore


import atexit as _atexit
@@ -104,7 +104,7 @@ def _open_connections():
return _NETWORK_CONNECTIONS


-_builtins._open_connections = _open_connections
+_builtins._open_connections = _open_connections # type: ignore


@_atexit.register
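
Note: both `# type: ignore` comments silence a deliberate type mismatch: the sandbox replaces `subprocess.Popen` (a class) with a plain function, and hangs a helper off `builtins`, neither of which type-checks. A sketch of the first, with an assumed error message:

    import subprocess as _subprocess

    def popen_not_allowed(*args, **kwargs):
        raise RuntimeError("subprocess is disabled in this sandbox")

    # Assigning a function over a class is intentional, so the checker is
    # silenced at the assignment instead of "fixing" the sandbox behavior.
    _subprocess.Popen = popen_not_allowed  # type: ignore
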
diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py
index 810591c1c..6106cf741 100644
--- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py
+++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_execution.py
@@ -161,9 +161,9 @@ _set_seeds()\
def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"]
# Convert the base64 string to a bytes object
- images = [base64.b64decode(d["image_base64"]) for d in image_data]
+ images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects
- images = [Image.open(BytesIO(img)) for img in images]
+ images = [Image.open(BytesIO(img)) for img in images_raw]
# Create a list of image paths
image_paths = []
for i, img in enumerate(images):
diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py
index 15118c9df..0ef3c35e9 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py
@@ -11,7 +11,7 @@ from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig


-async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl

impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index 4b3f7d9e7..8dd846c6f 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -15,6 +15,7 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
+ InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import Inference
@@ -23,6 +24,7 @@ from llama_stack.apis.tools import (
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
+ Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
@@ -62,6 +64,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self):
pass

+ async def register_tool(self, tool: Tool) -> None:
+ pass
+
+ async def unregister_tool(self, tool_id: str) -> None:
+ return
+
async def insert(
self,
documents: List[RAGDocument],
@@ -121,11 +129,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
return RAGQueryResult(content=None)

# sort by score
- chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
+ chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
chunks = chunks[: query_config.max_chunks]

tokens = 0
- picked = [
+ picked: list[InterleavedContentItem] = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
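
Note: the explicit `list[InterleavedContentItem]` annotation on `picked` is a variance fix: mypy would otherwise infer `list[TextContentItem]`, and since `list` is invariant that inferred type is rejected wherever a list of the base item type is expected. A stand-in sketch of the rule:

    class ContentItem: ...

    class TextItem(ContentItem): ...

    def consume(items: list[ContentItem]) -> int:
        return len(items)

    # Without the annotation mypy infers list[TextItem], which consume()
    # rejects because list is invariant in its element type.
    picked: list[ContentItem] = [TextItem()]
    print(consume(picked))
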
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 0c8718cb8..20c795650 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -15,11 +15,13 @@ import faiss
import numpy as np
from numpy.typing import NDArray

-from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@@ -35,16 +37,14 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"


class FaissIndex(EmbeddingIndex):
- chunk_by_index: Dict[int, str]
-
- def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
+ def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension)
- self.chunk_by_index = {}
+ self.chunk_by_index: dict[int, Chunk] = {}
self.kvstore = kvstore
self.bank_id = bank_id

@classmethod
- async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
+ async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance
@@ -114,11 +114,11 @@


class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
- def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
+ def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
- self.cache = {}
- self.kvstore = None
+ self.cache: dict[str, VectorDBWithIndex] = {}
+ self.kvstore: KVStore | None = None

async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -144,6 +144,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self,
vector_db: VectorDB,
) -> None:
+ assert self.kvstore is not None
+
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(
key=key,
@@ -161,6 +163,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return [i.vector_db for i in self.cache.values()]

async def unregister_vector_db(self, vector_db_id: str) -> None:
+ assert self.kvstore is not None
+
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
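
Note: two separate fixes meet in this file. First, `inference_api: Api.inference` annotated the parameter with an enum member, which is not a type, so it now uses the `Inference` protocol; the same fix recurs in sqlite_vec below. Second, the `assert self.kvstore is not None` guards narrow the Optional attribute that `initialize()` populates, as in the telemetry adapter above. Sketch of the annotation fix, with illustrative names:

    from enum import Enum
    from typing import Protocol

    class Api(Enum):
        inference = "inference"

    class Inference(Protocol):
        async def completion(self, prompt: str) -> str: ...

    class Adapter:
        # was: inference_api: Api.inference, an enum value, not a type
        def __init__(self, inference_api: Inference) -> None:
            self.inference_api = inference_api
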
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 17865c93e..b8f6f602f 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -15,9 +15,10 @@ import numpy as np
import sqlite_vec
from numpy.typing import NDArray

+from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

logger = logging.getLogger(__name__)
@@ -78,6 +79,8 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
If any insert fails, the transaction is rolled back to maintain consistency.
"""
+ assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
+
cur = self.connection.cursor()
try:
# Start transaction
@@ -89,6 +92,7 @@ class SQLiteVecIndex(EmbeddingIndex):
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
+ if isinstance(chunk.content, str)
]
# Insert metadata (ON CONFLICT to avoid duplicates)
cur.executemany(
@@ -103,6 +107,7 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
+ if isinstance(chunk.content, str)
]
# Insert embeddings in batch
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
@@ -154,7 +159,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
"""
- def __init__(self, config, inference_api: Api.inference) -> None:
+ def __init__(self, config, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache: Dict[str, VectorDBWithIndex] = {}
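
Note: `chunk.content` is the `InterleavedContent` union, but SQLite rows need plain text, so the code narrows twice: an up-front `assert` documents the contract, and the `isinstance` guards inside each comprehension let mypy treat `chunk.content` as `str` where one is required. A stand-in sketch:

    from dataclasses import dataclass

    @dataclass
    class Chunk:
        content: str | list[str]  # stand-in for the real union type

    chunks = [Chunk("alpha"), Chunk("beta")]
    assert all(isinstance(c.content, str) for c in chunks), "text chunks only"
    ids = [c.content.upper() for c in chunks if isinstance(c.content, str)]
    print(ids)  # ['ALPHA', 'BETA']
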
diff --git a/llama_stack/providers/utils/telemetry/dataset_mixin.py b/llama_stack/providers/utils/telemetry/dataset_mixin.py
index 0cb695956..34c612133 100644
--- a/llama_stack/providers/utils/telemetry/dataset_mixin.py
+++ b/llama_stack/providers/utils/telemetry/dataset_mixin.py
@@ -13,7 +13,7 @@ from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
class TelemetryDatasetMixin:
"""Mixin class that provides dataset-related functionality for telemetry providers."""
- datasetio_api: DatasetIO
+ datasetio_api: DatasetIO | None
async def save_spans_to_dataset(
self,
diff --git a/pyproject.toml b/pyproject.toml
index 107150cee..fb42f6725 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -235,16 +235,6 @@ exclude = [
"^llama_stack/providers/inline/scoring/basic/",
"^llama_stack/providers/inline/scoring/braintrust/",
"^llama_stack/providers/inline/scoring/llm_as_judge/",
- "^llama_stack/providers/inline/telemetry/meta_reference/console_span_processor\\.py$",
- "^llama_stack/providers/inline/telemetry/meta_reference/telemetry\\.py$",
- "^llama_stack/providers/inline/telemetry/sample/",
- "^llama_stack/providers/inline/tool_runtime/code_interpreter/",
- "^llama_stack/providers/inline/tool_runtime/rag/",
- "^llama_stack/providers/inline/vector_io/chroma/",
- "^llama_stack/providers/inline/vector_io/faiss/",
- "^llama_stack/providers/inline/vector_io/milvus/",
- "^llama_stack/providers/inline/vector_io/qdrant/",
- "^llama_stack/providers/inline/vector_io/sqlite_vec/",
"^llama_stack/providers/remote/agents/sample/",
"^llama_stack/providers/remote/datasetio/huggingface/",
"^llama_stack/providers/remote/inference/anthropic/",