forked from phoenix-oss/llama-stack-mirror
chore: mypy violations cleanup for inline::{telemetry,tool_runtime,vector_io} (#1711)
# What does this PR do? Clean up mypy violations for inline::{telemetry,tool_runtime,vector_io}. This also makes API accept a tool call result without any content (like RAG tool already may produce). Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
355134f51d
commit
515c16e352
15 changed files with 51 additions and 44 deletions
|
@ -15,11 +15,13 @@ import faiss
|
|||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
|
@ -35,16 +37,14 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
|
|||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
chunk_by_index: Dict[int, str]
|
||||
|
||||
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
|
||||
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
self.index = faiss.IndexFlatL2(dimension)
|
||||
self.chunk_by_index = {}
|
||||
self.chunk_by_index: dict[int, Chunk] = {}
|
||||
self.kvstore = kvstore
|
||||
self.bank_id = bank_id
|
||||
|
||||
@classmethod
|
||||
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
|
||||
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
|
||||
instance = cls(dimension, kvstore, bank_id)
|
||||
await instance.initialize()
|
||||
return instance
|
||||
|
@ -114,11 +114,11 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache = {}
|
||||
self.kvstore = None
|
||||
self.cache: dict[str, VectorDBWithIndex] = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
@ -144,6 +144,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
|
@ -161,6 +163,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
return [i.vector_db for i in self.cache.values()]
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
assert self.kvstore is not None
|
||||
|
||||
if vector_db_id not in self.cache:
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
|
|
@ -15,9 +15,10 @@ import numpy as np
|
|||
import sqlite_vec
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -78,6 +79,8 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
|
||||
If any insert fails, the transaction is rolled back to maintain consistency.
|
||||
"""
|
||||
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
|
||||
|
||||
cur = self.connection.cursor()
|
||||
try:
|
||||
# Start transaction
|
||||
|
@ -89,6 +92,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
metadata_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
|
||||
for chunk in batch_chunks
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert metadata (ON CONFLICT to avoid duplicates)
|
||||
cur.executemany(
|
||||
|
@ -103,6 +107,7 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
embedding_data = [
|
||||
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
|
||||
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
|
||||
if isinstance(chunk.content, str)
|
||||
]
|
||||
# Insert embeddings in batch
|
||||
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)
|
||||
|
@ -154,7 +159,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
|
||||
"""
|
||||
|
||||
def __init__(self, config, inference_api: Api.inference) -> None:
|
||||
def __init__(self, config, inference_api: Inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache: Dict[str, VectorDBWithIndex] = {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue