llama-stack-mirror/llama_stack/providers/remote/memory/chroma/chroma.py
Xi Yan 3c72c034e6
[remove import *] clean up import *'s (#689)
# What does this PR do?

- as title, cleaning up `import *`'s (see the sketch after this list)
- upgrade tests to make them more robust to bad model outputs
- remove import *'s in llama_stack/apis/* (skip __init__ modules)
<img width="465" alt="image"
src="https://github.com/user-attachments/assets/d8339c13-3b40-4ba5-9c53-0d2329726ee2"
/>

- ran `sh run_openapi_generator.sh`; no types are affected
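
For reference, the cleanup follows the usual pattern of replacing a wildcard import with an explicit import list. A minimal sketch of the pattern (the module and names below are illustrative, not the exact edits in this PR):

```
# Before: the wildcard pulls every public name into the module namespace
# from llama_stack.apis.memory import *

# After: import only the names the module actually uses
from llama_stack.apis.memory import (
    Chunk,
    Memory,
    QueryDocumentsResponse,
)
```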

## Test Plan

### Providers Tests

**agents**
```
pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8
```

**inference**
```bash
# meta-reference
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

# together
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py 
```

**safety**
```
pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B
```

**memory**
```
pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
```

**scoring**
```
pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
```


**datasetio**
```
pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py
pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py
```


**eval**
```
pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py
```

### Client-SDK Tests
```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk
```

### llama-stack-apps
```
PORT=5000
LOCALHOST=localhost

python -m examples.agents.hello $LOCALHOST $PORT
python -m examples.agents.inflation $LOCALHOST $PORT
python -m examples.agents.podcast_transcript $LOCALHOST $PORT
python -m examples.agents.rag_as_attachments $LOCALHOST $PORT
python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT
python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT
python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT

# Vision model
python -m examples.interior_design_assistant.app
python -m examples.agent_store.app $LOCALHOST $PORT
```

### CLI
```
which llama
llama model prompt-format -m Llama3.2-11B-Vision-Instruct
llama model list
llama stack list-apis
llama stack list-providers inference

llama stack build --template ollama --image-type conda
```

### Distributions Tests
**ollama**
```
llama stack build --template ollama --image-type conda
ollama run llama3.2:1b-instruct-fp16
llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
```

**fireworks**
```
llama stack build --template fireworks --image-type conda
llama stack run ./llama_stack/templates/fireworks/run.yaml
```

**together**
```
llama stack build --template together --image-type conda
llama stack run ./llama_stack/templates/together/run.yaml
```

**tgi**
```
llama stack run ./llama_stack/templates/tgi/run.yaml --env TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
```

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2024-12-27 15:45:44 -08:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import chromadb
from numpy.typing import NDArray

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
    Chunk,
    Memory,
    MemoryBankDocument,
    QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
    BankWithIndex,
    EmbeddingIndex,
)

from .config import ChromaRemoteImplConfig

log = logging.getLogger(__name__)

ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]


# this is a helper to allow us to use async and non-async chroma clients interchangeably
async def maybe_await(result):
    if asyncio.iscoroutine(result):
        return await result
    return result


class ChromaIndex(EmbeddingIndex):
    def __init__(self, client: ChromaClientType, collection):
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        await maybe_await(
            self.collection.add(
                documents=[chunk.model_dump_json() for chunk in chunks],
                embeddings=embeddings,
                ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
            )
        )

    async def query(
        self, embedding: NDArray, k: int, score_threshold: float
    ) -> QueryDocumentsResponse:
        results = await maybe_await(
            self.collection.query(
                query_embeddings=[embedding.tolist()],
                n_results=k,
                include=["documents", "distances"],
            )
        )
        distances = results["distances"][0]
        documents = results["documents"][0]

        chunks = []
        scores = []
        for dist, doc in zip(distances, documents):
            try:
                doc = json.loads(doc)
                chunk = Chunk(**doc)
            except Exception:
                log.exception(f"Failed to parse document: {doc}")
                continue

            chunks.append(chunk)
            scores.append(1.0 / float(dist))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)

    async def delete(self):
        await maybe_await(self.client.delete_collection(self.collection.name))


class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
    def __init__(
        self,
        config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
        inference_api: Api.inference,
    ) -> None:
        log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
        self.config = config
        self.inference_api = inference_api

        self.client = None
        self.cache = {}

    async def initialize(self) -> None:
        if isinstance(self.config, ChromaRemoteImplConfig):
            log.info(f"Connecting to Chroma server at: {self.config.url}")
            url = self.config.url.rstrip("/")
            parsed = urlparse(url)

            if parsed.path and parsed.path != "/":
                raise ValueError("URL should not contain a path")

            self.client = await chromadb.AsyncHttpClient(
                host=parsed.hostname, port=parsed.port
            )
        else:
            log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
            self.client = chromadb.PersistentClient(path=self.config.db_path)

    async def shutdown(self) -> None:
        pass

    async def register_memory_bank(
        self,
        memory_bank: MemoryBank,
    ) -> None:
        assert (
            memory_bank.memory_bank_type == MemoryBankType.vector.value
        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"

        collection = await maybe_await(
            self.client.get_or_create_collection(
                name=memory_bank.identifier,
                metadata={"bank": memory_bank.model_dump_json()},
            )
        )
        self.cache[memory_bank.identifier] = BankWithIndex(
            memory_bank, ChromaIndex(self.client, collection), self.inference_api
        )

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        await self.cache[memory_bank_id].index.delete()
        del self.cache[memory_bank_id]

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)
        await index.insert_documents(documents)

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)
        return await index.query_documents(query, params)

    async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
        if bank_id in self.cache:
            return self.cache[bank_id]

        bank = await self.memory_bank_store.get_memory_bank(bank_id)
        if not bank:
            raise ValueError(f"Bank {bank_id} not found in Llama Stack")

        collection = await maybe_await(self.client.get_collection(bank_id))
        if not collection:
            raise ValueError(f"Bank {bank_id} not found in Chroma")

        index = BankWithIndex(
            bank, ChromaIndex(self.client, collection), self.inference_api
        )
        self.cache[bank_id] = index
        return index