feat: Enable ingestion of custom embeddings

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-05-29 20:58:41 -04:00 committed by Francisco Arceo
parent 2603f10f95
commit 73456878e5
8 changed files with 224 additions and 15 deletions

View file

@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
"""
content: InterleavedContent
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
@json_schema_type
@ -50,7 +58,10 @@ class VectorIO(Protocol):
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks.
"""
...

View file

@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
tokens += metadata["token_count"]
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context:

View file

@ -199,11 +199,16 @@ class VectorDBWithIndex:
self,
chunks: list[Chunk],
) -> None:
embeddings_response = await self.inference_api.embeddings(
self.vector_db.embedding_model, [x.content for x in chunks]
)
embeddings = np.array(embeddings_response.embeddings)
chunks_to_embed = [c for c in chunks if c.embedding is None]
if chunks_to_embed:
resp = await self.inference_api.embeddings(
self.vector_db.embedding_model,
[c.content for c in chunks_to_embed],
)
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
c.embedding = embedding
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)
async def query_chunks(