[bugfix] fix client-sdk tests for v1 (#777)

# What does this PR do?

- as title, as API have been updated

## Test Plan

```
LLAMA_STACK_BASE_URL="http://localhost:5000" pytest -v tests/client-sdk/
```

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Xi Yan 2025-01-15 16:06:57 -08:00 committed by GitHub
parent 8fd9bcb8cd
commit 965644ce68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 19 additions and 18 deletions

View file

@ -39,7 +39,6 @@ def text_model_id(llama_stack_client):
for model in llama_stack_client.models.list().data
if model.identifier.startswith("meta-llama") and "405" not in model.identifier
]
print(available_models)
assert len(available_models) > 0
return available_models[0]
@ -268,10 +267,6 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
stream=True,
)
tool_invocation_content = extract_tool_invocation_content(response)
print(
"!!!!tool_invocation_content",
tool_invocation_content,
)
assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"

View file

@ -15,7 +15,8 @@ from llama_stack_client.types.memory_insert_params import Document
@pytest.fixture(scope="function")
def empty_memory_bank_registry(llama_stack_client):
memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
]
for memory_bank_id in memory_banks:
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
@ -35,7 +36,8 @@ def single_entry_memory_bank_registry(llama_stack_client, empty_memory_bank_regi
provider_id="faiss",
)
memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
]
return memory_banks
@ -104,7 +106,8 @@ def test_memory_bank_retrieve(llama_stack_client, empty_memory_bank_registry):
def test_memory_bank_list(llama_stack_client, empty_memory_bank_registry):
memory_banks_after_register = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
]
assert len(memory_banks_after_register) == 0
@ -124,14 +127,16 @@ def test_memory_bank_register(llama_stack_client, empty_memory_bank_registry):
)
memory_banks_after_register = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
]
assert memory_banks_after_register == [memory_bank_id]
def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_registry):
memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
]
assert len(memory_banks) == 1
@ -139,7 +144,8 @@ def test_memory_bank_unregister(llama_stack_client, single_entry_memory_bank_reg
llama_stack_client.memory_banks.unregister(memory_bank_id=memory_bank_id)
memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
]
assert len(memory_banks) == 0
@ -195,11 +201,10 @@ def test_memory_bank_insert_inline_and_query(
def test_memory_bank_insert_from_url_and_query(
llama_stack_client, empty_memory_bank_registry
):
providers = llama_stack_client.providers.list()
assert "memory" in providers
assert len(providers["memory"]) > 0
providers = llama_stack_client.providers.list().memory
assert len(providers) > 0
memory_provider_id = providers["memory"][0].provider_id
memory_provider_id = providers[0]["provider_id"]
memory_bank_id = "test_bank"
llama_stack_client.memory_banks.register(
@ -215,7 +220,8 @@ def test_memory_bank_insert_from_url_and_query(
# list to check memory bank is successfully registered
available_memory_banks = [
memory_bank.identifier for memory_bank in llama_stack_client.memory_banks.list()
memory_bank.identifier
for memory_bank in llama_stack_client.memory_banks.list().data
]
assert memory_bank_id in available_memory_banks

View file

@ -30,7 +30,7 @@ def data_url_from_image(file_path):
@pytest.fixture(scope="session")
def available_shields(llama_stack_client):
return [shield.identifier for shield in llama_stack_client.shields.list()]
return [shield.identifier for shield in llama_stack_client.shields.list().data]
@pytest.fixture(scope="session")
@ -54,7 +54,7 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session")
def model_providers(llama_stack_client):
return set(
[x.provider_id for x in llama_stack_client.providers.list()["inference"]]
[x["provider_id"] for x in llama_stack_client.providers.list().inference]
)