llama-stack-mirror/tests/client-sdk/memory/test_memory.py
Xi Yan 78e2bfbe7a
[tests] add client-sdk pytests & delete client.py (#638)
# What does this PR do?

**Why**
- Clean up examples that we will not maintain, and reduce the surface area
to a minimal set of showcases

**What**
- Delete `client.py` in `/apis/*`
- Move all scripts to unit tests
  - Future SDK syncs will only require running the pytests

**Side notes**
- `bwrap` is not available on macOS, so `code_interpreter` will not work there

## Test Plan

```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk
```
<img width="725" alt="image"
src="https://github.com/user-attachments/assets/36bfe537-628d-43c3-8479-dcfcfe2e4035"
/>
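The test plan points the suite at a running server through `LLAMA_STACK_BASE_URL`, and `test_memory_bank` receives a `llama_stack_client` fixture that this file does not define. Below is a minimal `conftest.py` sketch of such a fixture. It assumes the `LlamaStackClient(base_url=...)` constructor from the `llama_stack_client` package; `base_url_from_env` is a hypothetical helper, and the repository's actual conftest may be organized differently.

```python
# conftest.py: hypothetical sketch, not the repository's actual file
import os
from typing import Mapping, Optional

import pytest

ENV_VAR = "LLAMA_STACK_BASE_URL"


def base_url_from_env(environ: Mapping[str, str] = os.environ) -> Optional[str]:
    """Hypothetical helper: return the server URL, or None if unset or blank."""
    value = environ.get(ENV_VAR, "").strip()
    return value or None


@pytest.fixture(scope="session")
def llama_stack_client():
    base_url = base_url_from_env()
    if base_url is None:
        # Skip the requesting tests instead of failing when no server is configured
        pytest.skip(f"{ENV_VAR} is not set")
    # Imported lazily so test collection still works without the SDK installed
    from llama_stack_client import LlamaStackClient

    return LlamaStackClient(base_url=base_url)
```

A session-scoped fixture creates one client for the whole run, which suits tests that all talk to the same server.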




## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2024-12-16 12:04:56 -08:00

72 lines
2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client.types.memory_insert_params import Document


def test_memory_bank(llama_stack_client):
    providers = llama_stack_client.providers.list()
    if "memory" not in providers:
        pytest.skip("No memory provider available")

    # get memory provider id
    assert len(providers["memory"]) > 0

    memory_provider_id = providers["memory"][0].provider_id
    memory_bank_id = "test_bank"

    llama_stack_client.memory_banks.register(
        memory_bank_id=memory_bank_id,
        params={
            "embedding_model": "all-MiniLM-L6-v2",
            "chunk_size_in_tokens": 512,
            "overlap_size_in_tokens": 64,
        },
        provider_id=memory_provider_id,
    )

    # list to check memory bank is successfully registered
    available_memory_banks = [
        memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
    ]
    assert memory_bank_id in available_memory_banks

    # add documents to memory bank
    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
    ]
    documents = [
        Document(
            document_id=f"num-{i}",
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
            metadata={},
        )
        for i, url in enumerate(urls)
    ]

    llama_stack_client.memory.insert(
        bank_id=memory_bank_id,
        documents=documents,
    )

    # query documents
    response = llama_stack_client.memory.query(
        bank_id=memory_bank_id,
        query=[
            "How do I use lora",
        ],
    )

    assert len(response.chunks) > 0
    assert len(response.chunks) == len(response.scores)

    contents = [chunk.content for chunk in response.chunks]
    assert "lora" in contents[0].lower()
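The test registers the bank with `chunk_size_in_tokens=512` and `overlap_size_in_tokens=64`, then expects `response.chunks` and `response.scores` to be parallel lists with the best match first. The toy sketch below mimics that contract so the assertions are easier to reason about. It is an illustration only, not the memory provider's implementation: it assumes fixed-size sliding-window chunking, uses whitespace-split words in place of model tokens, and uses a bag-of-words cosine score in place of `all-MiniLM-L6-v2` embeddings. `chunk_tokens` and `query` are hypothetical names.

```python
import math
import re
from collections import Counter
from typing import List, Sequence, Tuple


def chunk_tokens(tokens: Sequence[str], chunk_size: int, overlap: int) -> List[List[str]]:
    """Split tokens into windows of chunk_size; neighbouring windows share `overlap` tokens."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    windows = []
    for start in range(0, len(tokens), step):
        windows.append(list(tokens[start:start + chunk_size]))
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the input
    return windows


def _terms(text: str) -> Counter:
    """Lowercased alphanumeric term counts: a crude stand-in for an embedding."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def _cosine(a: Counter, b: Counter) -> float:
    dot = sum(count * b[term] for term, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def query(chunks: Sequence[str], text: str, k: int = 3) -> Tuple[List[str], List[float]]:
    """Return the top-k chunks and their scores as parallel, best-first lists."""
    q = _terms(text)
    scored = [(chunk, _cosine(q, _terms(chunk))) for chunk in chunks]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    top = scored[:k]
    return [c for c, _ in top], [s for _, s in top]


if __name__ == "__main__":
    doc = ("LoRA adds low rank adapters to attention layers. Use lora to fine tune "
           "with less memory. Datasets are loaded from the hub. Chat templates format messages.")
    windows = chunk_tokens(doc.split(), chunk_size=8, overlap=2)
    chunks = [" ".join(w) for w in windows]
    top, scores = query(chunks, "How do I use lora", k=2)
    for chunk, score in zip(top, scores):
        print(f"{score:.3f}  {chunk}")
```

With 8-word windows and a 2-word overlap, the window containing both "Use" and "lora" ranks first, which is the same shape of check the test makes with `contents[0]`. The overlap keeps a phrase that straddles a window boundary intact in at least one chunk, at the cost of storing some tokens twice.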