forked from phoenix-oss/llama-stack-mirror
## What does this PR do? This is a long-pending change and particularly important to get done now. Specifically: - we cannot "localize" (aka download) any URLs from media attachments anywhere near our modeling code. it must be done within llama-stack. - `PIL.Image` is infesting all our APIs via `ImageMedia -> InterleavedTextMedia` and that cannot be right at all. Anything in the API surface must be "naturally serializable". We need a standard `{ type: "image", image_url: "<...>" }` which is more extensible - `UserMessage`, `SystemMessage`, etc. are moved completely to llama-stack from the llama-models repository. See https://github.com/meta-llama/llama-models/pull/244 for the corresponding PR in llama-models. ## Test Plan ```bash cd llama_stack/providers/tests pytest -s -v -k "fireworks or ollama or together" inference/test_vision_inference.py pytest -s -v -k "(fireworks or ollama or together) and llama_3b" inference/test_text_inference.py pytest -s -v -k chroma memory/test_memory.py \ --env EMBEDDING_DIMENSION=384 --env CHROMA_DB_PATH=/tmp/foobar pytest -s -v -k fireworks agents/test_agents.py \ --safety-shield=meta-llama/Llama-Guard-3-8B \ --inference-model=meta-llama/Llama-3.1-8B-Instruct ``` Updated the client sdk (see PR ...), installed the SDK in the same environment and then ran the SDK tests: ```bash cd tests/client-sdk LLAMA_STACK_CONFIG=together pytest -s -v agents/test_agents.py LLAMA_STACK_CONFIG=ollama pytest -s -v memory/test_memory.py # this one needed a bit of hacking in the run.yaml to ensure I could register the vision model correctly INFERENCE_MODEL=llama3.2-vision:latest LLAMA_STACK_CONFIG=ollama pytest -s -v inference/test_inference.py ```
219 lines
7.1 KiB
Python
219 lines
7.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import logging
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import faiss
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
|
|
|
from llama_stack.apis.memory import * # noqa: F403
|
|
from llama_stack.apis.inference import InterleavedContent
|
|
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
|
|
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
|
from llama_stack.providers.utils.memory.vector_store import (
|
|
BankWithIndex,
|
|
EmbeddingIndex,
|
|
)
|
|
|
|
from .config import FaissImplConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
|
|
FAISS_INDEX_PREFIX = "faiss_index:v2::"
|
|
|
|
|
|
class FaissIndex(EmbeddingIndex):
|
|
id_by_index: Dict[int, str]
|
|
chunk_by_index: Dict[int, str]
|
|
|
|
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
|
|
self.index = faiss.IndexFlatL2(dimension)
|
|
self.id_by_index = {}
|
|
self.chunk_by_index = {}
|
|
self.kvstore = kvstore
|
|
self.bank_id = bank_id
|
|
|
|
@classmethod
|
|
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
|
|
instance = cls(dimension, kvstore, bank_id)
|
|
await instance.initialize()
|
|
return instance
|
|
|
|
async def initialize(self) -> None:
|
|
if not self.kvstore:
|
|
return
|
|
|
|
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
|
stored_data = await self.kvstore.get(index_key)
|
|
|
|
if stored_data:
|
|
data = json.loads(stored_data)
|
|
self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
|
|
self.chunk_by_index = {
|
|
int(k): Chunk.model_validate_json(v)
|
|
for k, v in data["chunk_by_index"].items()
|
|
}
|
|
|
|
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
|
|
self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
|
|
|
|
async def _save_index(self):
|
|
if not self.kvstore or not self.bank_id:
|
|
return
|
|
|
|
np_index = faiss.serialize_index(self.index)
|
|
buffer = io.BytesIO()
|
|
np.savetxt(buffer, np_index)
|
|
data = {
|
|
"id_by_index": self.id_by_index,
|
|
"chunk_by_index": {
|
|
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
|
|
},
|
|
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
|
}
|
|
|
|
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
|
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
|
|
|
async def delete(self):
|
|
if not self.kvstore or not self.bank_id:
|
|
return
|
|
|
|
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
|
|
|
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
|
# Add dimension check
|
|
embedding_dim = (
|
|
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
|
|
)
|
|
if embedding_dim != self.index.d:
|
|
raise ValueError(
|
|
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
|
|
)
|
|
|
|
indexlen = len(self.id_by_index)
|
|
for i, chunk in enumerate(chunks):
|
|
self.chunk_by_index[indexlen + i] = chunk
|
|
self.id_by_index[indexlen + i] = chunk.document_id
|
|
|
|
self.index.add(np.array(embeddings).astype(np.float32))
|
|
|
|
# Save updated index
|
|
await self._save_index()
|
|
|
|
async def query(
|
|
self, embedding: NDArray, k: int, score_threshold: float
|
|
) -> QueryDocumentsResponse:
|
|
distances, indices = self.index.search(
|
|
embedding.reshape(1, -1).astype(np.float32), k
|
|
)
|
|
|
|
chunks = []
|
|
scores = []
|
|
for d, i in zip(distances[0], indices[0]):
|
|
if i < 0:
|
|
continue
|
|
chunks.append(self.chunk_by_index[int(i)])
|
|
scores.append(1.0 / float(d))
|
|
|
|
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
|
|
|
|
|
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
|
|
self.config = config
|
|
self.inference_api = inference_api
|
|
self.cache = {}
|
|
self.kvstore = None
|
|
|
|
async def initialize(self) -> None:
|
|
self.kvstore = await kvstore_impl(self.config.kvstore)
|
|
# Load existing banks from kvstore
|
|
start_key = MEMORY_BANKS_PREFIX
|
|
end_key = f"{MEMORY_BANKS_PREFIX}\xff"
|
|
stored_banks = await self.kvstore.range(start_key, end_key)
|
|
|
|
for bank_data in stored_banks:
|
|
bank = VectorMemoryBank.model_validate_json(bank_data)
|
|
index = BankWithIndex(
|
|
bank,
|
|
await FaissIndex.create(
|
|
bank.embedding_dimension, self.kvstore, bank.identifier
|
|
),
|
|
self.inference_api,
|
|
)
|
|
self.cache[bank.identifier] = index
|
|
|
|
async def shutdown(self) -> None:
|
|
# Cleanup if needed
|
|
pass
|
|
|
|
async def register_memory_bank(
|
|
self,
|
|
memory_bank: MemoryBank,
|
|
) -> None:
|
|
assert (
|
|
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
|
), f"Only vector banks are supported {memory_bank.type}"
|
|
|
|
# Store in kvstore
|
|
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
|
|
await self.kvstore.set(
|
|
key=key,
|
|
value=memory_bank.model_dump_json(),
|
|
)
|
|
|
|
# Store in cache
|
|
self.cache[memory_bank.identifier] = BankWithIndex(
|
|
memory_bank,
|
|
await FaissIndex.create(
|
|
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
|
|
),
|
|
self.inference_api,
|
|
)
|
|
|
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
|
return [i.bank for i in self.cache.values()]
|
|
|
|
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
|
await self.cache[memory_bank_id].index.delete()
|
|
del self.cache[memory_bank_id]
|
|
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
|
|
|
|
async def insert_documents(
|
|
self,
|
|
bank_id: str,
|
|
documents: List[MemoryBankDocument],
|
|
ttl_seconds: Optional[int] = None,
|
|
) -> None:
|
|
index = self.cache.get(bank_id)
|
|
if index is None:
|
|
raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
|
|
|
|
await index.insert_documents(documents)
|
|
|
|
async def query_documents(
|
|
self,
|
|
bank_id: str,
|
|
query: InterleavedContent,
|
|
params: Optional[Dict[str, Any]] = None,
|
|
) -> QueryDocumentsResponse:
|
|
index = self.cache.get(bank_id)
|
|
if index is None:
|
|
raise ValueError(f"Bank {bank_id} not found")
|
|
|
|
return await index.query_documents(query, params)
|