forked from phoenix-oss/llama-stack-mirror
feat: Enable ingestion of precomputed embeddings (#2317)
This commit is contained in:
parent
31ce208bda
commit
f328436831
9 changed files with 366 additions and 15 deletions
|
@ -19,8 +19,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
"""
|
||||
A chunk of content that can be inserted into a vector database.
|
||||
:param content: The content of the chunk, which can be interleaved text, images, or other types.
|
||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||
:param metadata: Metadata associated with the chunk, such as document ID, source, or other relevant information.
|
||||
"""
|
||||
|
||||
content: InterleavedContent
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
embedding: list[float] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -50,7 +58,10 @@ class VectorIO(Protocol):
|
|||
"""Insert chunks into a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
||||
:param chunks: The chunks to insert.
|
||||
:param chunks: The chunks to insert. Each `Chunk` should contain content which can be interleaved text, images, or other types.
|
||||
`metadata`: `dict[str, Any]` and `embedding`: `List[float]` are optional.
|
||||
If `metadata` is provided, you configure how Llama Stack formats the chunk during generation.
|
||||
If `embedding` is not provided, it will be computed later.
|
||||
:param ttl_seconds: The time to live of the chunks.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -146,7 +146,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
]
|
||||
for i, chunk in enumerate(chunks):
|
||||
metadata = chunk.metadata
|
||||
tokens += metadata["token_count"]
|
||||
tokens += metadata.get("token_count", 0)
|
||||
tokens += metadata.get("metadata_token_count", 0)
|
||||
|
||||
if tokens > query_config.max_tokens_in_context:
|
||||
|
|
|
@ -171,6 +171,22 @@ def make_overlapped_chunks(
|
|||
return chunks
|
||||
|
||||
|
||||
def _validate_embedding(embedding: NDArray, index: int, expected_dimension: int):
|
||||
"""Helper method to validate embedding format and dimensions"""
|
||||
if not isinstance(embedding, (list | np.ndarray)):
|
||||
raise ValueError(f"Embedding at index {index} must be a list or numpy array, got {type(embedding)}")
|
||||
|
||||
if isinstance(embedding, np.ndarray):
|
||||
if not np.issubdtype(embedding.dtype, np.number):
|
||||
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||
else:
|
||||
if not all(isinstance(e, (float | int | np.number)) for e in embedding):
|
||||
raise ValueError(f"Embedding at index {index} contains non-numeric values")
|
||||
|
||||
if len(embedding) != expected_dimension:
|
||||
raise ValueError(f"Embedding at index {index} has dimension {len(embedding)}, expected {expected_dimension}")
|
||||
|
||||
|
||||
class EmbeddingIndex(ABC):
|
||||
@abstractmethod
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
|
@ -199,11 +215,22 @@ class VectorDBWithIndex:
|
|||
self,
|
||||
chunks: list[Chunk],
|
||||
) -> None:
|
||||
embeddings_response = await self.inference_api.embeddings(
|
||||
self.vector_db.embedding_model, [x.content for x in chunks]
|
||||
)
|
||||
embeddings = np.array(embeddings_response.embeddings)
|
||||
chunks_to_embed = []
|
||||
for i, c in enumerate(chunks):
|
||||
if c.embedding is None:
|
||||
chunks_to_embed.append(c)
|
||||
else:
|
||||
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||
|
||||
if chunks_to_embed:
|
||||
resp = await self.inference_api.embeddings(
|
||||
self.vector_db.embedding_model,
|
||||
[c.content for c in chunks_to_embed],
|
||||
)
|
||||
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
|
||||
c.embedding = embedding
|
||||
|
||||
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
|
||||
await self.index.add_chunks(chunks, embeddings)
|
||||
|
||||
async def query_chunks(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue