forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR fixes some of the issues with our telemetry setup to enable logs to be delivered to OpenTelemetry and Jaeger. Main fixes: 1) Updates the OpenTelemetry provider to use the latest OTLP exporters instead of the deprecated ones. 2) Adds a tracing middleware, which injects a trace into each HTTP request that the server receives; this becomes the root trace. Previously, we did this in the create_dynamic_route method, which is not the actual execution flow but more of a configuration step, and this caused the traces to end prematurely. Through the middleware, we plug in the trace start and end at the right location. 3) We manage our own methods to create traces and spans, and this does not fit well with the OpenTelemetry SDK, since it does not provide a way to take in traces and spans that are already created; it expects us to use the SDK to create them. For now, I have a hacky approach of just maintaining a map from our internal telemetry objects to the OpenTelemetry-specific ones. This is not the ideal solution. I will explore other ways to get around this issue. For now, to have something that works, I am going to keep this as is. Addresses: #509
209 lines
6.6 KiB
Python
209 lines
6.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import logging
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import faiss
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
|
|
from llama_stack.apis.memory import * # noqa: F403
|
|
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
|
|
|
from llama_stack.providers.utils.memory.vector_store import (
|
|
ALL_MINILM_L6_V2_DIMENSION,
|
|
BankWithIndex,
|
|
EmbeddingIndex,
|
|
)
|
|
from llama_stack.providers.utils.telemetry import tracing
|
|
|
|
from .config import FaissImplConfig
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Key prefix under which registered memory banks are stored in the kvstore;
# the bank identifier is appended to form the full key.
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
|
|
|
|
|
|
class FaissIndex(EmbeddingIndex):
    """FAISS-backed embedding index with optional persistence to a kvstore.

    Vectors are held in a flat L2 ``faiss`` index. Two dictionaries keyed by
    insertion position map every vector back to its chunk and to the id of
    the document it came from. When both a kvstore and a bank id are given,
    the complete state is written to the kvstore after each ``add_chunks``
    call and restored by ``initialize``.
    """

    # Insertion position -> id of the document the chunk belongs to.
    id_by_index: Dict[int, str]
    # Insertion position -> the stored chunk itself.
    chunk_by_index: Dict[int, Chunk]

    def __init__(self, dimension: int, kvstore=None, bank_id: Optional[str] = None):
        """Create an empty in-memory index.

        Args:
            dimension: Dimensionality of the vectors the index will hold.
            kvstore: Optional store used for persistence; ``None`` disables it.
            bank_id: Optional bank identifier used to build the storage key.
        """
        self.index = faiss.IndexFlatL2(dimension)
        self.id_by_index = {}
        self.chunk_by_index = {}
        self.kvstore = kvstore
        self.bank_id = bank_id

    @classmethod
    async def create(cls, dimension: int, kvstore=None, bank_id: Optional[str] = None):
        """Build an index and load any previously persisted state for it."""
        instance = cls(dimension, kvstore, bank_id)
        await instance.initialize()
        return instance

    def _index_key(self) -> str:
        """Return the kvstore key under which this bank's index is stored."""
        return f"faiss_index:v1::{self.bank_id}"

    async def initialize(self) -> None:
        """Restore the index and its lookup tables from the kvstore, if present.

        Does nothing when persistence is not configured or nothing is stored.
        """
        # Same guard as _save_index()/delete(): without a bank id nothing is
        # ever written, and the key would otherwise end in the text "None"
        # and could pick up the data of a bank that really has that name.
        if not self.kvstore or not self.bank_id:
            return

        stored_data = await self.kvstore.get(self._index_key())
        if not stored_data:
            return

        data = json.loads(stored_data)
        # JSON object keys are always strings; convert them back to ints.
        self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
        self.chunk_by_index = {
            int(k): Chunk.model_validate_json(v)
            for k, v in data["chunk_by_index"].items()
        }

        buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
        # loadtxt reads both the integer text written by _save_index() and
        # the float-formatted text written by earlier versions of this code.
        self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))

    async def _save_index(self):
        """Persist the index and lookup tables to the kvstore.

        Does nothing unless both a kvstore and a bank id are configured.
        """
        if not self.kvstore or not self.bank_id:
            return

        np_index = faiss.serialize_index(self.index)
        buffer = io.BytesIO()
        # The serialized index is an array of bytes; write each one as a plain
        # integer. savetxt's default float format spends ~25 characters per
        # byte and forces loadtxt to parse floats into an integer dtype.
        np.savetxt(buffer, np_index, fmt="%d")
        data = {
            "id_by_index": self.id_by_index,
            "chunk_by_index": {
                k: v.model_dump_json() for k, v in self.chunk_by_index.items()
            },
            "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
        }

        await self.kvstore.set(key=self._index_key(), value=json.dumps(data))

    async def delete(self):
        """Remove this bank's persisted index from the kvstore, if configured."""
        if not self.kvstore or not self.bank_id:
            return

        await self.kvstore.delete(self._index_key())

    @tracing.span(name="add_chunks")
    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """Add chunks and their embeddings to the index, then persist it.

        Args:
            chunks: Chunks to store; ``chunks[i]`` is paired with
                ``embeddings[i]``.
            embeddings: One embedding per chunk.
        """
        if not chunks:
            # Nothing to add; avoids handing an empty array to faiss.
            return

        # New entries are numbered after the ones already stored.
        indexlen = len(self.id_by_index)
        for i, chunk in enumerate(chunks):
            self.chunk_by_index[indexlen + i] = chunk
            self.id_by_index[indexlen + i] = chunk.document_id

        self.index.add(np.array(embeddings).astype(np.float32))

        # Save updated index
        await self._save_index()

    async def query(
        self, embedding: NDArray, k: int, score_threshold: float
    ) -> QueryDocumentsResponse:
        """Return up to ``k`` stored chunks nearest to ``embedding``.

        Each score is the reciprocal of the distance reported by the index,
        so closer chunks score higher; a distance of zero scores ``inf``.

        Args:
            embedding: Query vector; it is reshaped to a single row.
            k: Maximum number of neighbours to return.
            score_threshold: Accepted for interface compatibility; not applied
                by this implementation.
        """
        distances, indices = self.index.search(
            embedding.reshape(1, -1).astype(np.float32), k
        )

        chunks = []
        scores = []
        for d, i in zip(distances[0], indices[0]):
            if i < 0:
                # Negative indices do not refer to a stored entry.
                continue
            chunks.append(self.chunk_by_index[int(i)])
            distance = float(d)
            # An exact match has distance 0; avoid ZeroDivisionError.
            scores.append(1.0 / distance if distance != 0.0 else float("inf"))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
|
|
|
|
|
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
    """Memory provider that serves vector memory banks from FAISS indexes.

    Bank definitions are persisted in a kvstore under ``MEMORY_BANKS_PREFIX``
    and mirrored in an in-memory cache keyed by bank identifier.
    """

    def __init__(self, config: FaissImplConfig) -> None:
        self.config = config
        # Bank identifier -> BankWithIndex for every known bank.
        self.cache = {}
        # Created in initialize().
        self.kvstore = None

    async def _create_bank_index(self, bank) -> BankWithIndex:
        """Pair ``bank`` with a FaissIndex backed by this provider's kvstore."""
        return BankWithIndex(
            bank=bank,
            index=await FaissIndex.create(
                ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
            ),
        )

    async def initialize(self) -> None:
        """Open the kvstore and rebuild the cache from the banks stored in it."""
        self.kvstore = await kvstore_impl(self.config.kvstore)
        # Load existing banks from kvstore
        start_key = MEMORY_BANKS_PREFIX
        end_key = f"{MEMORY_BANKS_PREFIX}\xff"
        stored_banks = await self.kvstore.range(start_key, end_key)

        for bank_data in stored_banks:
            bank = VectorMemoryBank.model_validate_json(bank_data)
            self.cache[bank.identifier] = await self._create_bank_index(bank)

    async def shutdown(self) -> None:
        """No-op; there is currently nothing to clean up."""
        # Cleanup if needed
        pass

    async def register_memory_bank(
        self,
        memory_bank: MemoryBank,
    ) -> None:
        """Persist a vector memory bank and create its index in the cache.

        Raises:
            AssertionError: If ``memory_bank`` is not a vector bank.
        """
        # Report the attribute that is actually being checked.
        assert (
            memory_bank.memory_bank_type == MemoryBankType.vector.value
        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"

        # Store in kvstore
        key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
        await self.kvstore.set(
            key=key,
            value=memory_bank.model_dump_json(),
        )

        # Store in cache
        self.cache[memory_bank.identifier] = await self._create_bank_index(
            memory_bank
        )

    async def list_memory_banks(self) -> List[MemoryBank]:
        """Return every bank currently held in the cache."""
        return [i.bank for i in self.cache.values()]

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        """Delete a bank's persisted index, its cache entry and its kvstore record.

        Raises:
            KeyError: If ``memory_bank_id`` is not in the cache.
        """
        await self.cache[memory_bank_id].index.delete()
        del self.cache[memory_bank_id]
        await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        """Insert ``documents`` into the bank identified by ``bank_id``.

        ``ttl_seconds`` is accepted but not used by this method.

        Raises:
            ValueError: If no bank with ``bank_id`` is registered.
        """
        index = self.cache.get(bank_id)
        if index is None:
            raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")

        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        """Run ``query`` against the bank identified by ``bank_id``.

        Raises:
            ValueError: If no bank with ``bank_id`` is registered.
        """
        index = self.cache.get(bank_id)
        if index is None:
            raise ValueError(f"Bank {bank_id} not found")

        return await index.query_documents(query, params)
|