Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file, as well as fix and ignore some additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
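
A `ruff.toml` for a setup like this is small. The sketch below is a hypothetical minimal example, assuming a 120-character line length and a conventional rule selection; it is not the exact file added in this commit:

    # ruff.toml -- minimal sketch, not the committed file
    line-length = 120

    [lint]
    select = [
        "E",  # pycodestyle errors
        "F",  # pyflakes
        "I",  # import sorting (isort rules)
    ]
    ignore = [
        "E501",  # line length is left to the formatter
    ]

With a config along these lines, running `ruff check --fix .` and `ruff format .` (the hooks typically wired into pre-commit via https://github.com/astral-sh/ruff-pre-commit) produces the kind of rewrites shown in the diff below, such as collapsing multi-line signatures that fit within the line limit.
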
@@ -11,9 +11,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import ChromaRemoteImplConfig


-async def get_adapter_impl(
-    config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
-):
+async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]):
     from .chroma import ChromaVectorIOAdapter

     impl = ChromaVectorIOAdapter(config, deps[Api.inference])

@@ -42,9 +42,9 @@ class ChromaIndex(EmbeddingIndex):
         self.collection = collection

     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
-        assert len(chunks) == len(
-            embeddings
-        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        assert len(chunks) == len(embeddings), (
+            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        )

         ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)]
         await maybe_await(
@@ -55,9 +55,7 @@ class ChromaIndex(EmbeddingIndex):
             )
         )

-    async def query(
-        self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryChunksResponse:
+    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = await maybe_await(
             self.collection.query(
                 query_embeddings=[embedding.tolist()],
@@ -109,9 +107,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             if parsed.path and parsed.path != "/":
                 raise ValueError("URL should not contain a path")

-            self.client = await chromadb.AsyncHttpClient(
-                host=parsed.hostname, port=parsed.port
-            )
+            self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
         else:
             log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
             self.client = chromadb.PersistentClient(path=self.config.db_path)
@@ -157,9 +153,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

         return await index.query_chunks(query, params)

-    async def _get_and_cache_vector_db_index(
-        self, vector_db_id: str
-    ) -> VectorDBWithIndex:
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]

@@ -169,8 +163,6 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(vector_db_id))
         if not collection:
             raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
-        index = VectorDBWithIndex(
-            vector_db, ChromaIndex(self.client, collection), self.inference_api
-        )
+        index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
         self.cache[vector_db_id] = index
         return index

@@ -71,9 +71,9 @@ class PGVectorIndex(EmbeddingIndex):
         )

     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
-        assert len(chunks) == len(
-            embeddings
-        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        assert len(chunks) == len(embeddings), (
+            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        )

         values = []
         for i, chunk in enumerate(chunks):
@@ -94,9 +94,7 @@ class PGVectorIndex(EmbeddingIndex):
             )
         execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")

-    async def query(
-        self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryChunksResponse:
+    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         self.cursor.execute(
             f"""
         SELECT document, embedding <-> %s::vector AS distance
@@ -166,9 +164,7 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
         upsert_models(self.cursor, [(vector_db.identifier, vector_db)])

         index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
-        self.cache[vector_db.identifier] = VectorDBWithIndex(
-            vector_db, index, self.inference_api
-        )
+        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

     async def unregister_vector_db(self, vector_db_id: str) -> None:
         await self.cache[vector_db_id].index.delete()
@@ -192,15 +188,11 @@ class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
         index = await self._get_and_cache_vector_db_index(vector_db_id)
         return await index.query_chunks(query, params)

-    async def _get_and_cache_vector_db_index(
-        self, vector_db_id: str
-    ) -> VectorDBWithIndex:
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]

         vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
         index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
-        self.cache[vector_db_id] = VectorDBWithIndex(
-            vector_db, index, self.inference_api
-        )
+        self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
         return self.cache[vector_db_id]

@@ -43,16 +43,14 @@ class QdrantIndex(EmbeddingIndex):
         self.collection_name = collection_name

     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
-        assert len(chunks) == len(
-            embeddings
-        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        assert len(chunks) == len(embeddings), (
+            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        )

         if not await self.client.collection_exists(self.collection_name):
             await self.client.create_collection(
                 self.collection_name,
-                vectors_config=models.VectorParams(
-                    size=len(embeddings[0]), distance=models.Distance.COSINE
-                ),
+                vectors_config=models.VectorParams(size=len(embeddings[0]), distance=models.Distance.COSINE),
             )

         points = []
@@ -62,16 +60,13 @@ class QdrantIndex(EmbeddingIndex):
                 PointStruct(
                     id=convert_id(chunk_id),
                     vector=embedding,
-                    payload={"chunk_content": chunk.model_dump()}
-                    | {CHUNK_ID_KEY: chunk_id},
+                    payload={"chunk_content": chunk.model_dump()} | {CHUNK_ID_KEY: chunk_id},
                 )
             )

         await self.client.upsert(collection_name=self.collection_name, points=points)

-    async def query(
-        self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryChunksResponse:
+    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,
@@ -124,9 +119,7 @@ class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):

         self.cache[vector_db.identifier] = index

-    async def _get_and_cache_vector_db_index(
-        self, vector_db_id: str
-    ) -> Optional[VectorDBWithIndex]:
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]

@@ -35,9 +35,9 @@ class WeaviateIndex(EmbeddingIndex):
         self.collection_name = collection_name

     async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
-        assert len(chunks) == len(
-            embeddings
-        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        assert len(chunks) == len(embeddings), (
+            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        )

         data_objects = []
         for i, chunk in enumerate(chunks):
@@ -56,9 +56,7 @@ class WeaviateIndex(EmbeddingIndex):
         # TODO: make this async friendly
         collection.data.insert_many(data_objects)

-    async def query(
-        self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryChunksResponse:
+    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)

         results = collection.query.near_vector(
@@ -85,9 +83,7 @@ class WeaviateIndex(EmbeddingIndex):

     async def delete(self, chunk_ids: List[str]) -> None:
         collection = self.client.collections.get(self.collection_name)
-        collection.data.delete_many(
-            where=Filter.by_property("id").contains_any(chunk_ids)
-        )
+        collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))


 class WeaviateMemoryAdapter(
@@ -149,9 +145,7 @@ class WeaviateMemoryAdapter(
             self.inference_api,
         )

-    async def _get_and_cache_vector_db_index(
-        self, vector_db_id: str
-    ) -> Optional[VectorDBWithIndex]:
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]:
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]
