mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00

chore: mypy violations cleanup for inline::{telemetry,tool_runtime,vector_io} (#1711)

# What does this PR do?

Clean up mypy violations for inline::{telemetry,tool_runtime,vector_io}. This also makes the API accept a tool call result without any content (as the RAG tool may already produce).

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
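A minimal sketch of the API change (the error values are illustrative; field names come from the `ToolInvocationResult` model changed below):

```python
from llama_stack.apis.tools import ToolInvocationResult

# Previously the schema required "content"; a result may now carry only an
# error or metadata, as the RAG tool can already produce.
result = ToolInvocationResult(error_code=1, error_message="no matching chunks")
assert result.content is None
```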
This commit is contained in: parent 355134f51d, commit 515c16e352

15 changed files with 51 additions and 44 deletions
docs/_static/llama-stack-spec.html (vendored): 3 changes
@@ -8069,9 +8069,6 @@
          }
        },
        "additionalProperties": false,
-       "required": [
-         "content"
-       ],
        "title": "ToolInvocationResult"
      },
      "IterrowsResponse": {
docs/_static/llama-stack-spec.yaml (vendored): 2 changes
@@ -5529,8 +5529,6 @@ components:
           - type: array
           - type: object
         additionalProperties: false
-        required:
-          - content
         title: ToolInvocationResult
       IterrowsResponse:
         type: object
@@ -69,7 +69,7 @@ class ToolGroup(Resource):

 @json_schema_type
 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     error_message: Optional[str] = None
     error_code: Optional[int] = None
     metadata: Optional[Dict[str, Any]] = None
@@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
-    tool_store: ToolStore
+    tool_store: ToolStore | None = None

-    rag_tool: RAGToolRuntime
+    rag_tool: RAGToolRuntime | None = None

     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
@@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class VectorIO(Protocol):
-    vector_db_store: VectorDBStore
+    vector_db_store: VectorDBStore | None = None

     # this will just block now until chunks are inserted, but it should
     # probably return a Job instance which can be polled for completion
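The protocol attributes above all follow the same pattern; a minimal sketch of why mypy wants the default (names simplified, not the real classes; the store is presumably wired in by the framework after construction):

```python
from typing import Protocol, runtime_checkable


class ToolStore(Protocol):
    async def get_tool(self, tool_id: str): ...


@runtime_checkable
class ToolRuntime(Protocol):
    # Declared Optional with a None default: implementations start without a
    # store, so a bare `tool_store: ToolStore` would be flagged as possibly
    # unset on conforming classes.
    tool_store: ToolStore | None = None
```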
@@ -6,12 +6,14 @@

 from typing import Any, Dict

+from llama_stack.distribution.datatypes import Api
+
 from .config import TelemetryConfig, TelemetrySink

 __all__ = ["TelemetryConfig", "TelemetrySink"]


-async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
     from .telemetry import TelemetryAdapter

     impl = TelemetryAdapter(config, deps)
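The deps mapping is keyed by the `Api` enum (the import added above), not by strings; a short sketch of the tightened signature in use, assuming `llama_stack` is importable:

```python
from typing import Any, Dict

from llama_stack.distribution.datatypes import Api


def pick_datasetio(deps: Dict[Api, Any]) -> Any:
    # Lookups use enum members, which Dict[str, Any] mistyped.
    return deps.get(Api.datasetio)
```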
@@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
         """Shutdown the processor."""
         pass

-    def force_flush(self, timeout_millis: float = None) -> bool:
+    def force_flush(self, timeout_millis: float | None = None) -> bool:
         """Force flush any pending spans."""
         return True
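The `force_flush` change reflects mypy's default `no_implicit_optional` behavior: a `None` default no longer implies an Optional type. A two-line sketch:

```python
# Rejected by modern mypy: None default without an Optional annotation.
# def force_flush(timeout_millis: float = None) -> bool: ...

# Accepted: the Optional-ness is spelled out in the annotation.
def force_flush(timeout_millis: float | None = None) -> bool:
    return True
```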
@@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace

 from .config import TelemetryConfig, TelemetrySink

-_GLOBAL_STORAGE = {
+_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
     "active_spans": {},
     "counters": {},
     "gauges": {},
@@ -70,7 +70,7 @@ def is_tracing_enabled(tracer):


 class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
-    def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
+    def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
         self.config = config
         self.datasetio_api = deps.get(Api.datasetio)
         self.meter = None
@@ -146,7 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
                 "message": event.message,
                 "severity": event.severity.value,
                 "__ttl__": ttl_seconds,
-                **event.attributes,
+                **(event.attributes or {}),
             },
             timestamp=timestamp_ns,
         )
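`**(event.attributes or {})` is the usual guard for unpacking an Optional mapping; a standalone sketch:

```python
attributes = None  # event.attributes is Optional in the telemetry API

# `**attributes` would raise TypeError (and fail mypy) when attributes is
# None; `or {}` substitutes an empty mapping instead.
body = {"__ttl__": 3600, **(attributes or {})}
print(body)  # {'__ttl__': 3600}
```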
@@ -154,6 +154,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
             print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")

     def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["counters"]:
             _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
                 name=name,
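The `assert self.meter is not None` lines added in this and the next two hunks are mypy narrowing guards; a simplified sketch (a `str` stands in for the real OpenTelemetry meter):

```python
from typing import Optional


class Adapter:
    def __init__(self) -> None:
        self.meter: Optional[str] = None  # populated during initialization

    def use_meter(self) -> str:
        # Narrows Optional[str] to str, so the attribute access type-checks.
        assert self.meter is not None
        return self.meter.upper()
```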
@@ -163,6 +164,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         return _GLOBAL_STORAGE["counters"][name]

     def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["gauges"]:
             _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
                 name=name,
@@ -182,6 +184,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         up_down_counter.add(event.value, attributes=event.attributes)

     def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
+        assert self.meter is not None
         if name not in _GLOBAL_STORAGE["up_down_counters"]:
             _GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
                 name=name,
@@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
     )


-_subprocess.Popen = popen_not_allowed
+_subprocess.Popen = popen_not_allowed  # type: ignore


 import atexit as _atexit
@@ -104,7 +104,7 @@ def _open_connections():
     return _NETWORK_CONNECTIONS


-_builtins._open_connections = _open_connections
+_builtins._open_connections = _open_connections  # type: ignore


 @_atexit.register
@@ -161,9 +161,9 @@ _set_seeds()\
 def process_matplotlib_response(response, matplotlib_dump_dir: str):
     image_data = response["image_data"]
     # Convert the base64 string to a bytes object
-    images = [base64.b64decode(d["image_base64"]) for d in image_data]
+    images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
     # Create a list of PIL images from the bytes objects
-    images = [Image.open(BytesIO(img)) for img in images]
+    images = [Image.open(BytesIO(img)) for img in images_raw]
     # Create a list of image paths
     image_paths = []
     for i, img in enumerate(images):
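The rename avoids re-binding one variable to two different types (a list of bytes, then a list of images), which mypy rejects as an incompatible assignment. A runnable sketch with a synthetic payload:

```python
import base64
from io import BytesIO

from PIL import Image

# Build one tiny valid payload for the demo.
buf = BytesIO()
Image.new("RGB", (1, 1)).save(buf, format="PNG")
image_data = [{"image_base64": base64.b64encode(buf.getvalue()).decode()}]

images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]  # list[bytes]
images = [Image.open(BytesIO(raw)) for raw in images_raw]  # list[Image.Image]
```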
@@ -11,7 +11,7 @@ from llama_stack.providers.datatypes import Api
 from .config import RagToolRuntimeConfig


-async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
+async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
     from .memory import MemoryToolRuntimeImpl

     impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
@@ -15,6 +15,7 @@ from pydantic import TypeAdapter
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
 from llama_stack.apis.inference import Inference
@@ -23,6 +24,7 @@ from llama_stack.apis.tools import (
     RAGQueryConfig,
     RAGQueryResult,
     RAGToolRuntime,
+    Tool,
     ToolDef,
     ToolInvocationResult,
     ToolParameter,
@@ -62,6 +64,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def shutdown(self):
         pass

+    async def register_tool(self, tool: Tool) -> None:
+        pass
+
+    async def unregister_tool(self, tool_id: str) -> None:
+        return
+
     async def insert(
         self,
         documents: List[RAGDocument],
@@ -121,11 +129,11 @@
             return RAGQueryResult(content=None)

         # sort by score
-        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
+        chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
         chunks = chunks[: query_config.max_chunks]

         tokens = 0
-        picked = [
+        picked: list[InterleavedContentItem] = [
             TextContentItem(
                 text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
             )
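The sorted/zip idiom above pairs chunks with scores, sorts by score descending, and unzips back into parallel tuples; mypy cannot follow the re-binding through `zip(*...)`, hence the ignore. A standalone sketch:

```python
chunks = ["a", "b", "c"]
scores = [0.1, 0.9, 0.5]

# Pair, sort by score (descending), unzip back into parallel tuples.
chunks, scores = zip(*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True))
print(chunks)  # ('b', 'c', 'a')
print(scores)  # (0.9, 0.5, 0.1)
```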
@@ -15,11 +15,13 @@ import faiss
 import numpy as np
 from numpy.typing import NDArray

-from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -35,16 +37,14 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"


 class FaissIndex(EmbeddingIndex):
-    chunk_by_index: Dict[int, str]
-
-    def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
+    def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         self.index = faiss.IndexFlatL2(dimension)
-        self.chunk_by_index = {}
+        self.chunk_by_index: dict[int, Chunk] = {}
         self.kvstore = kvstore
         self.bank_id = bank_id

     @classmethod
-    async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
+    async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         instance = cls(dimension, kvstore, bank_id)
         await instance.initialize()
         return instance
@@ -114,11 +114,11 @@ class FaissIndex(EmbeddingIndex):


 class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
-        self.cache = {}
-        self.kvstore = None
+        self.cache: dict[str, VectorDBWithIndex] = {}
+        self.kvstore: KVStore | None = None

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -144,6 +144,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self,
         vector_db: VectorDB,
     ) -> None:
+        assert self.kvstore is not None
+
         key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(
             key=key,
@@ -161,6 +163,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         return [i.vector_db for i in self.cache.values()]

     async def unregister_vector_db(self, vector_db_id: str) -> None:
+        assert self.kvstore is not None
+
         if vector_db_id not in self.cache:
             logger.warning(f"Vector DB {vector_db_id} not found")
             return
@@ -15,9 +15,10 @@ import numpy as np
 import sqlite_vec
 from numpy.typing import NDArray

+from llama_stack.apis.inference.inference import Inference
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

 logger = logging.getLogger(__name__)
@@ -78,6 +79,8 @@ class SQLiteVecIndex(EmbeddingIndex):
         embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
         If any insert fails, the transaction is rolled back to maintain consistency.
         """
+        assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
+
         cur = self.connection.cursor()
         try:
             # Start transaction
@@ -89,6 +92,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                 metadata_data = [
                     (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
                     for chunk in batch_chunks
+                    if isinstance(chunk.content, str)
                 ]
                 # Insert metadata (ON CONFLICT to avoid duplicates)
                 cur.executemany(
@@ -103,6 +107,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                 embedding_data = [
                     (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
                     for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
+                    if isinstance(chunk.content, str)
                 ]
                 # Insert embeddings in batch
                 cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
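Both comprehension guards narrow `chunk.content` for mypy; the assert added at the top of the method guarantees the filter never actually drops anything at runtime. A sketch with a simplified Chunk stand-in:

```python
from dataclasses import dataclass


@dataclass
class Chunk:  # simplified stand-in for the real pydantic model
    content: "str | list[str]"


chunks = [Chunk("alpha"), Chunk("beta")]
assert all(isinstance(c.content, str) for c in chunks)

# isinstance in the comprehension narrows the union so .upper() type-checks.
ids = [c.content.upper() for c in chunks if isinstance(c.content, str)]
print(ids)  # ['ALPHA', 'BETA']
```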
@@ -154,7 +159,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
     """

-    def __init__(self, config, inference_api: Api.inference) -> None:
+    def __init__(self, config, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
         self.cache: Dict[str, VectorDBWithIndex] = {}
@@ -13,7 +13,7 @@ from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
 class TelemetryDatasetMixin:
     """Mixin class that provides dataset-related functionality for telemetry providers."""

-    datasetio_api: DatasetIO
+    datasetio_api: DatasetIO | None

     async def save_spans_to_dataset(
         self,
@@ -235,16 +235,6 @@ exclude = [
     "^llama_stack/providers/inline/scoring/basic/",
     "^llama_stack/providers/inline/scoring/braintrust/",
     "^llama_stack/providers/inline/scoring/llm_as_judge/",
-    "^llama_stack/providers/inline/telemetry/meta_reference/console_span_processor\\.py$",
-    "^llama_stack/providers/inline/telemetry/meta_reference/telemetry\\.py$",
-    "^llama_stack/providers/inline/telemetry/sample/",
-    "^llama_stack/providers/inline/tool_runtime/code_interpreter/",
-    "^llama_stack/providers/inline/tool_runtime/rag/",
-    "^llama_stack/providers/inline/vector_io/chroma/",
-    "^llama_stack/providers/inline/vector_io/faiss/",
-    "^llama_stack/providers/inline/vector_io/milvus/",
-    "^llama_stack/providers/inline/vector_io/qdrant/",
-    "^llama_stack/providers/inline/vector_io/sqlite_vec/",
     "^llama_stack/providers/remote/agents/sample/",
     "^llama_stack/providers/remote/datasetio/huggingface/",
     "^llama_stack/providers/remote/inference/anthropic/",