llama-stack/llama_stack/providers/inline/memory/faiss/faiss.py
Xi Yan 3c72c034e6
[remove import *] clean up import *'s (#689)
# What does this PR do?

- as title, cleaning up `import *`'s
- upgrade tests to make them more robust to bad model outputs
- remove import *'s in llama_stack/apis/* (skip __init__ modules)
<img width="465" alt="image"
src="https://github.com/user-attachments/assets/d8339c13-3b40-4ba5-9c53-0d2329726ee2"
/>

- run `sh run_openapi_generator.sh`, no types gets affected

## Test Plan

### Providers Tests

**agents**
```
pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8
```

**inference**
```bash
# meta-reference
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

# together
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py 
```

**safety**
```
pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B
```

**memory**
```
pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
```

**scoring**
```
pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
```


**datasetio**
```
pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py
pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py
```


**eval**
```
pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py
```

### Client-SDK Tests
```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk
```

### llama-stack-apps
```
PORT=5000
LOCALHOST=localhost

python -m examples.agents.hello $LOCALHOST $PORT
python -m examples.agents.inflation $LOCALHOST $PORT
python -m examples.agents.podcast_transcript $LOCALHOST $PORT
python -m examples.agents.rag_as_attachments $LOCALHOST $PORT
python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT
python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT
python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT

# Vision model
python -m examples.interior_design_assistant.app
python -m examples.agent_store.app $LOCALHOST $PORT
```

### CLI
```
which llama
llama model prompt-format -m Llama3.2-11B-Vision-Instruct
llama model list
llama stack list-apis
llama stack list-providers inference

llama stack build --template ollama --image-type conda
```

### Distributions Tests
**ollama**
```
llama stack build --template ollama --image-type conda
ollama run llama3.2:1b-instruct-fp16
llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
```

**fireworks**
```
llama stack build --template fireworks --image-type conda
llama stack run ./llama_stack/templates/fireworks/run.yaml
```

**together**
```
llama stack build --template together --image-type conda
llama stack run ./llama_stack/templates/together/run.yaml
```

**tgi**
```
llama stack run ./llama_stack/templates/tgi/run.yaml --env TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
```

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2024-12-27 15:45:44 -08:00

222 lines
7.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
import json
import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import FaissImplConfig
logger = logging.getLogger(__name__)
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
FAISS_INDEX_PREFIX = "faiss_index:v2::"
class FaissIndex(EmbeddingIndex):
id_by_index: Dict[int, str]
chunk_by_index: Dict[int, str]
def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
self.index = faiss.IndexFlatL2(dimension)
self.id_by_index = {}
self.chunk_by_index = {}
self.kvstore = kvstore
self.bank_id = bank_id
@classmethod
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance
async def initialize(self) -> None:
if not self.kvstore:
return
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
stored_data = await self.kvstore.get(index_key)
if stored_data:
data = json.loads(stored_data)
self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
self.chunk_by_index = {
int(k): Chunk.model_validate_json(v)
for k, v in data["chunk_by_index"].items()
}
buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
self.index = faiss.deserialize_index(np.loadtxt(buffer, dtype=np.uint8))
async def _save_index(self):
if not self.kvstore or not self.bank_id:
return
np_index = faiss.serialize_index(self.index)
buffer = io.BytesIO()
np.savetxt(buffer, np_index)
data = {
"id_by_index": self.id_by_index,
"chunk_by_index": {
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
},
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
}
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data))
async def delete(self):
if not self.kvstore or not self.bank_id:
return
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
# Add dimension check
embedding_dim = (
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
)
if embedding_dim != self.index.d:
raise ValueError(
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
)
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk
self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32))
# Save updated index
await self._save_index()
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
distances, indices = self.index.search(
embedding.reshape(1, -1).astype(np.float32), k
)
chunks = []
scores = []
for d, i in zip(distances[0], indices[0]):
if i < 0:
continue
chunks.append(self.chunk_by_index[int(i)])
scores.append(1.0 / float(d))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache = {}
self.kvstore = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing banks from kvstore
start_key = MEMORY_BANKS_PREFIX
end_key = f"{MEMORY_BANKS_PREFIX}\xff"
stored_banks = await self.kvstore.range(start_key, end_key)
for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex(
bank,
await FaissIndex.create(
bank.embedding_dimension, self.kvstore, bank.identifier
),
self.inference_api,
)
self.cache[bank.identifier] = index
async def shutdown(self) -> None:
# Cleanup if needed
pass
async def register_memory_bank(
self,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
# Store in kvstore
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
await self.kvstore.set(
key=key,
value=memory_bank.model_dump_json(),
)
# Store in cache
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank,
await FaissIndex.create(
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
),
self.inference_api,
)
async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()]
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id)
if index is None:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)