chore: mypy violations cleanup for inline::{telemetry,tool_runtime,vector_io} (#1711)

# What does this PR do?

Clean up mypy violations for inline::{telemetry,tool_runtime,vector_io}.
This also makes the API accept a tool call result without any content (as
the RAG tool may already produce).
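
For illustration, a minimal sketch (not part of this diff) of what the relaxed schema allows; the error values below are hypothetical:

```python
from llama_stack.apis.tools import ToolInvocationResult

# `content` is now Optional and defaults to None, so a tool can report an
# error (or an empty result) without fabricating content.
result = ToolInvocationResult(
    error_code=1,                       # hypothetical error code
    error_message="no chunks matched",  # hypothetical message
)
assert result.content is None
```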

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-20 13:01:10 -04:00 committed by GitHub
parent 355134f51d
commit 515c16e352
15 changed files with 51 additions and 44 deletions


@@ -8069,9 +8069,6 @@
           }
         },
         "additionalProperties": false,
-        "required": [
-          "content"
-        ],
         "title": "ToolInvocationResult"
       },
       "IterrowsResponse": {


@@ -5529,8 +5529,6 @@ components:
           - type: array
           - type: object
       additionalProperties: false
-      required:
-        - content
       title: ToolInvocationResult
     IterrowsResponse:
       type: object


@@ -69,7 +69,7 @@ class ToolGroup(Resource):
 @json_schema_type
 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     error_message: Optional[str] = None
     error_code: Optional[int] = None
     metadata: Optional[Dict[str, Any]] = None

@@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
-    tool_store: ToolStore
-    rag_tool: RAGToolRuntime
+    tool_store: ToolStore | None = None
+    rag_tool: RAGToolRuntime | None = None

     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")


@@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class VectorIO(Protocol):
-    vector_db_store: VectorDBStore
+    vector_db_store: VectorDBStore | None = None

     # this will just block now until chunks are inserted, but it should
     # probably return a Job instance which can be polled for completion


@@ -6,12 +6,14 @@
 from typing import Any, Dict

+from llama_stack.distribution.datatypes import Api
+
 from .config import TelemetryConfig, TelemetrySink

 __all__ = ["TelemetryConfig", "TelemetrySink"]

-async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
     from .telemetry import TelemetryAdapter

     impl = TelemetryAdapter(config, deps)


@@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
         """Shutdown the processor."""
         pass

-    def force_flush(self, timeout_millis: float = None) -> bool:
+    def force_flush(self, timeout_millis: float | None = None) -> bool:
         """Force flush any pending spans."""
         return True


@@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
 from .config import TelemetryConfig, TelemetrySink

-_GLOBAL_STORAGE = {
+_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "active_spans": {},
     "counters": {},
     "gauges": {},

@@ -70,7 +70,7 @@ def is_tracing_enabled(tracer):
 class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
-    def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
+    def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
         self.config = config
         self.datasetio_api = deps.get(Api.datasetio)
         self.meter = None

@@ -146,7 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
                 "message": event.message,
                 "severity": event.severity.value,
                 "__ttl__": ttl_seconds,
-                **event.attributes,
+                **(event.attributes or {}),
             },
             timestamp=timestamp_ns,
         )

@@ -154,6 +154,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")

     def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["counters"]:
             _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                 name=name,

@@ -163,6 +164,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         return _GLOBAL_STORAGE["counters"][name]

     def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["gauges"]:
             _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                 name=name,

@@ -182,6 +184,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         up_down_counter.add(event.value, attributes=event.attributes)

     def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["up_down_counters"]:
             _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
                 name=name,


@@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
     )

-_subprocess.Popen = popen_not_allowed
+_subprocess.Popen = popen_not_allowed  # type: ignore

 import atexit as _atexit

@@ -104,7 +104,7 @@ def _open_connections():
     return _NETWORK_CONNECTIONS

-_builtins._open_connections = _open_connections
+_builtins._open_connections = _open_connections  # type: ignore

 @_atexit.register


@@ -161,9 +161,9 @@ _set_seeds()\
 def process_matplotlib_response(response, matplotlib_dump_dir: str):
     image_data = response["image_data"]
     # Convert the base64 string to a bytes object
-    images = [base64.b64decode(d["image_base64"]) for d in image_data]
+    images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
     # Create a list of PIL images from the bytes objects
-    images = [Image.open(BytesIO(img)) for img in images]
+    images = [Image.open(BytesIO(img)) for img in images_raw]
     # Create a list of image paths
     image_paths = []
     for i, img in enumerate(images):


@@ -11,7 +11,7 @@ from llama_stack.providers.datatypes import Api
 from .config import RagToolRuntimeConfig

-async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
     from .memory import MemoryToolRuntimeImpl

     impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])


@@ -15,6 +15,7 @@ from pydantic import TypeAdapter
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
 from llama_stack.apis.inference import Inference

@@ -23,6 +24,7 @@ from llama_stack.apis.tools import (
     RAGQueryConfig,
     RAGQueryResult,
     RAGToolRuntime,
+    Tool,
     ToolDef,
     ToolInvocationResult,
     ToolParameter,

@@ -62,6 +64,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def shutdown(self):
         pass

+    async def register_tool(self, tool: Tool) -> None:
+        pass
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        return
+
     async def insert(
         self,
         documents: List[RAGDocument],

@@ -121,11 +129,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             return RAGQueryResult(content=None)

         # sort by score
-        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
+        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
         chunks = chunks[: query_config.max_chunks]

         tokens = 0
-        picked = [
+        picked: list[InterleavedContentItem] = [
             TextContentItem(
                 text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
             )


@@ -15,11 +15,13 @@ import faiss
 import numpy as np
 from numpy.typing import NDArray

-from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,

@@ -35,16 +37,14 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
 class FaissIndex(EmbeddingIndex):
-    chunk_by_index: Dict[int, str]
-
-    def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
+    def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         self.index = faiss.IndexFlatL2(dimension)
-        self.chunk_by_index = {}
+        self.chunk_by_index: dict[int, Chunk] = {}
         self.kvstore = kvstore
         self.bank_id = bank_id

     @classmethod
-    async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
+    async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         instance = cls(dimension, kvstore, bank_id)
         await instance.initialize()
         return instance

@@ -114,11 +114,11 @@ class FaissIndex(EmbeddingIndex):
 class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
-        self.cache = {}
-        self.kvstore = None
+        self.cache: dict[str, VectorDBWithIndex] = {}
+        self.kvstore: KVStore | None = None

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)

@@ -144,6 +144,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self,
         vector_db: VectorDB,
     ) -> None:
+        assert self.kvstore is not None
+
         key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(
             key=key,

@@ -161,6 +163,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         return [i.vector_db for i in self.cache.values()]

     async def unregister_vector_db(self, vector_db_id: str) -> None:
+        assert self.kvstore is not None
+
         if vector_db_id not in self.cache:
             logger.warning(f"Vector DB {vector_db_id} not found")
             return


@@ -15,9 +15,10 @@ import numpy as np
 import sqlite_vec
 from numpy.typing import NDArray

+from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

 logger = logging.getLogger(__name__)

@@ -78,6 +79,8 @@ class SQLiteVecIndex(EmbeddingIndex):
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
         If any insert fails, the transaction is rolled back to maintain consistency.
         """
+        assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
+
         cur = self.connection.cursor()
         try:
             # Start transaction

@@ -89,6 +92,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                 metadata_data = [
                     (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
                     for chunk in batch_chunks
+                    if isinstance(chunk.content, str)
                 ]
                 # Insert metadata (ON CONFLICT to avoid duplicates)
                 cur.executemany(

@@ -103,6 +107,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                 embedding_data = [
                     (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
                     for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
+                    if isinstance(chunk.content, str)
                 ]
                 # Insert embeddings in batch
                 cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)

@@ -154,7 +159,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """

-    def __init__(self, config, inference_api: Api.inference) -> None:
+    def __init__(self, config, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
         self.cache: Dict[str, VectorDBWithIndex] = {}


@@ -13,7 +13,7 @@ from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
 class TelemetryDatasetMixin:
     """Mixin class that provides dataset-related functionality for telemetry providers."""

-    datasetio_api: DatasetIO
+    datasetio_api: DatasetIO | None

     async def save_spans_to_dataset(
         self,


@@ -235,16 +235,6 @@ exclude = [
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
-    "^llama_stack/providers/inline/telemetry/meta_reference/console_span_processor\\.py$",
-    "^llama_stack/providers/inline/telemetry/meta_reference/telemetry\\.py$",
-    "^llama_stack/providers/inline/telemetry/sample/",
-    "^llama_stack/providers/inline/tool_runtime/code_interpreter/",
-    "^llama_stack/providers/inline/tool_runtime/rag/",
-    "^llama_stack/providers/inline/vector_io/chroma/",
-    "^llama_stack/providers/inline/vector_io/faiss/",
-    "^llama_stack/providers/inline/vector_io/milvus/",
-    "^llama_stack/providers/inline/vector_io/qdrant/",
-    "^llama_stack/providers/inline/vector_io/sqlite_vec/",
     "^llama_stack/providers/remote/agents/sample/",
     "^llama_stack/providers/remote/datasetio/huggingface/",
     "^llama_stack/providers/remote/inference/anthropic/",