feat: Enable ingestion of precomputed embeddings (#2317)

Authored by Francisco Arceo on 2025-05-31 04:03:37 -06:00; committed by GitHub
parent 31ce208bda
commit f328436831
9 changed files with 366 additions and 15 deletions

llama_stack/apis/vector_io/vector_io.py

@@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 class Chunk(BaseModel):
+    """
+    A chunk of content that can be inserted into a vector database.
+    :param content: The content of the chunk, which can be interleaved text, images, or other types.
+    :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
+    :param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
+    """
+
     content: InterleavedContent
     metadata: dict[str, Any] = Field(default_factory=dict)
+    embedding: list[float] | None = None

 @json_schema_type
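
With the new optional `embedding` field, callers can attach precomputed vectors directly to chunks. A minimal sketch of constructing one chunk each way, assuming `Chunk` is imported from `llama_stack.apis.vector_io` and the target vector DB was registered with a 4-dimensional embedding model (both assumptions for illustration):

from llama_stack.apis.vector_io import Chunk

# Chunk with a precomputed embedding; its length must match the
# vector DB's embedding_dimension (assumed to be 4 here).
precomputed = Chunk(
    content="Deep learning uses multi-layer neural networks.",
    metadata={"document_id": "doc-1"},
    embedding=[0.01, -0.23, 0.17, 0.08],
)

# Chunk without an embedding; it will be computed at insert time
# with the vector DB's configured embedding model.
deferred = Chunk(
    content="Transformers rely on self-attention.",
    metadata={"document_id": "doc-2"},
)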
@@ -50,7 +58,10 @@ class VectorIO(Protocol):
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
If `embedding` is not provided, it will be computed later.
:param ttl_seconds: The time to live of the chunks.
"""
...
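
For illustration, a hedged sketch of exercising this endpoint from the Python client; the `LlamaStackClient` setup and the `my_documents` vector DB are assumptions, not part of this diff:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

# Mixed batch: the first chunk ships its own vector, the second is
# embedded server-side because `embedding` is omitted.
client.vector_io.insert(
    vector_db_id="my_documents",  # hypothetical, previously registered
    chunks=[
        {
            "content": "Chunk with a precomputed embedding.",
            "metadata": {"document_id": "doc-1"},
            "embedding": [0.12, -0.08, 0.33, 0.41],
        },
        {
            "content": "Chunk embedded at insert time.",
            "metadata": {"document_id": "doc-2"},
        },
    ],
)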

llama_stack/providers/inline/tool_runtime/rag/memory.py

@@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime
         ]
         for i, chunk in enumerate(chunks):
             metadata = chunk.metadata
-            tokens += metadata["token_count"]
+            tokens += metadata.get("token_count", 0)
+            tokens += metadata.get("metadata_token_count", 0)

             if tokens > query_config.max_tokens_in_context:
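
The switch to `.get()` matters because chunks ingested with precomputed embeddings may never have passed through the chunking path that records token counts. A small sketch of the failure the old code would hit (hypothetical metadata for illustration):

metadata = {"document_id": "doc-1"}  # no token_count recorded

tokens = 0
# metadata["token_count"] would raise KeyError here; .get() falls
# back to 0 so context budgeting keeps working for such chunks.
tokens += metadata.get("token_count", 0)
tokens += metadata.get("metadata_token_count", 0)
assert tokens == 0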

llama_stack/providers/utils/memory/vector_store.py

@@ -171,6 +171,22 @@ def make_overlapped_chunks(
     return chunks


+def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
+    """Helper method to validate embedding format and dimensions"""
+    if not isinstance(embedding, (list | np.ndarray)):
+        raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
+
+    if isinstance(embedding, np.ndarray):
+        if not np.issubdtype(embedding.dtype, np.number):
+            raise ValueError(f"Embedding at index {index} contains non-numeric values")
+    else:
+        if not all(isinstance(e, (float | int | np.number)) for e in embedding):
+            raise ValueError(f"Embedding at index {index} contains non-numeric values")
+
+    if len(embedding) != expected_dimension:
+        raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
+
+
 class EmbeddingIndex(ABC):
     @abstractmethod
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
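
A quick sketch of how `_validate_embedding` behaves, assuming a vector DB registered with `embedding_dimension=4` (the inputs below are hypothetical):

import numpy as np

# Both a plain list of floats and a numeric ndarray are accepted.
_validate_embedding([0.1, 0.2, 0.3, 0.4], index=0, expected_dimension=4)
_validate_embedding(np.zeros(4, dtype=np.float32), index=1, expected_dimension=4)

try:
    _validate_embedding([0.1, "x", 0.3, 0.4], index=2, expected_dimension=4)
except ValueError as e:
    print(e)  # Embedding at index 2 contains non-numeric values

try:
    _validate_embedding([0.1, 0.2], index=3, expected_dimension=4)
except ValueError as e:
    print(e)  # Embedding at index 3 has dimension 2, expected 4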
@@ -199,11 +215,22 @@ class VectorDBWithIndex:
         self,
         chunks: list[Chunk],
     ) -> None:
-        embeddings_response = await self.inference_api.embeddings(
-            self.vector_db.embedding_model, [x.content for x in chunks]
-        )
-        embeddings = np.array(embeddings_response.embeddings)
+        chunks_to_embed = []
+        for i, c in enumerate(chunks):
+            if c.embedding is None:
+                chunks_to_embed.append(c)
+            else:
+                _validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
+
+        if chunks_to_embed:
+            resp = await self.inference_api.embeddings(
+                self.vector_db.embedding_model,
+                [c.content for c in chunks_to_embed],
+            )
+            for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
+                c.embedding = embedding
+
+        embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
         await self.index.add_chunks(chunks, embeddings)

     async def query_chunks(
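
Putting the pieces together: precomputed vectors are validated against the DB's `embedding_dimension`, only the remaining chunks are sent to the embedding model, and both sources end up in a single float32 matrix for the index. Note that `zip(..., strict=False)` silently tolerates a length mismatch between the request and the provider's response. A hedged end-to-end sketch, assuming `vdb` is a `VectorDBWithIndex` whose vector DB uses 4-dimensional embeddings:

import numpy as np
from llama_stack.apis.vector_io import Chunk

chunks = [
    Chunk(content="ships its own vector", embedding=[0.1, 0.2, 0.3, 0.4]),
    Chunk(content="embedded at insert time"),
]

async def demo(vdb):
    # Only the second chunk reaches the inference API; the first is
    # validated and reused as-is.
    await vdb.insert_chunks(chunks)
    matrix = np.array([c.embedding for c in chunks], dtype=np.float32)
    assert matrix.shape == (2, 4)  # one row per chunk, mixed sources

# asyncio.run(demo(vdb))  # hypothetical VectorDBWithIndex instance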