mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-25 11:48:06 +00:00
Merge branch 'main' into delete-embeddings
This commit is contained in:
commit
38cd41dd21
50 changed files with 1752 additions and 641 deletions
|
|
@ -115,6 +115,19 @@ class ProviderSpec(BaseModel):
|
|||
description="If this provider is deprecated and does NOT work, specify the error message here",
|
||||
)
|
||||
|
||||
module: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
|
||||
Example: `module: ramalama_stack`
|
||||
""",
|
||||
)
|
||||
|
||||
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: list[str] = Field(default_factory=list)
|
||||
|
||||
|
|
@ -135,7 +148,7 @@ class AdapterSpec(BaseModel):
|
|||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
default_factory=str,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
|
|
@ -173,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack
|
|||
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
|
||||
""",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
# module field is inherited from ProviderSpec
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
|
@ -223,9 +229,7 @@ API responses, specify the adapter here.
|
|||
def container_image(self) -> str | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
return self.adapter.module
|
||||
# module field is inherited from ProviderSpec
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
|
|
@ -246,6 +250,7 @@ def remote_provider_spec(
|
|||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
module=adapter.module,
|
||||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
optional_api_dependencies=optional_api_dependencies or [],
|
||||
|
|
|
|||
|
|
@ -57,12 +57,15 @@ class ChromaIndex(EmbeddingIndex):
|
|||
self.collection = collection
|
||||
self.kvstore = kvstore
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
|
||||
ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)]
|
||||
ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
|
||||
await maybe_await(
|
||||
self.collection.add(
|
||||
documents=[chunk.model_dump_json() for chunk in chunks],
|
||||
|
|
@ -140,9 +143,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.client = None
|
||||
self.cache = {}
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
self.vector_db_store = self.kvstore
|
||||
|
||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
url = self.config.url.rstrip("/")
|
||||
|
|
@ -175,6 +181,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
)
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
del self.cache[vector_db_id]
|
||||
|
||||
|
|
@ -185,6 +195,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
|
|
@ -196,6 +208,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
) -> QueryChunksResponse:
|
||||
index = await self._get_and_cache_vector_db_index(vector_db_id)
|
||||
|
||||
if index is None:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
|
||||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue