fix playground for v1 (#799)

# What does this PR do?

- Update playground callsites for the v1 API changes:
  - guard the dataset and eval-task selectboxes when the server returns no entries
  - group the now-flat `providers.list()` result by API before rendering
  - read streamed chat text from `chunk.event.delta.text`
  - discover the memory provider from `providers.list()` when registering a memory bank
  - switch the agent config from inline memory `tools` to `toolgroups` with `builtin::memory`
  - match the `tool_execution` log role when surfacing RAG retrieval output

## Test Plan

```
cd llama_stack/distribution/ui
streamlit run app.py
```


https://github.com/user-attachments/assets/eace11c6-600a-42dc-b4e7-6948a706509f




## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
Xi Yan 2025-01-16 19:32:07 -08:00, committed by GitHub
commit 9d574f4aee (parent b2ac29b9da)
5 changed files with 35 additions and 25 deletions


```diff
@@ -14,6 +14,6 @@ def datasets():
     datasets_info = {
         d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
     }
-
-    selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
-    st.json(datasets_info[selected_dataset], expanded=True)
+    if len(datasets_info) > 0:
+        selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
+        st.json(datasets_info[selected_dataset], expanded=True)
```
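For context on why the guard matters: with an empty options list, `st.selectbox` returns `None`, so the old unconditional `datasets_info[selected_dataset]` lookup raised a `KeyError` on a fresh server with no datasets registered. A minimal standalone sketch of the pattern (the fallback message is hypothetical, not part of this PR); the eval tasks page below gets the identical treatment:

```python
import streamlit as st

datasets_info = {}  # pretend datasets.list() returned nothing

# With no options, st.selectbox would return None and the dict lookup
# below would raise KeyError -- hence the length guard.
if len(datasets_info) > 0:
    selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
    st.json(datasets_info[selected_dataset], expanded=True)
else:
    st.info("No datasets registered yet.")  # hypothetical fallback, not in the PR
```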


```diff
@@ -16,7 +16,8 @@ def eval_tasks():
     eval_tasks_info = {
         d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
     }
-    selected_eval_task = st.selectbox(
-        "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
-    )
-    st.json(eval_tasks_info[selected_eval_task], expanded=True)
+    if len(eval_tasks_info) > 0:
+        selected_eval_task = st.selectbox(
+            "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
+        )
+        st.json(eval_tasks_info[selected_eval_task], expanded=True)
```


```diff
@@ -10,11 +10,17 @@ from modules.api import llama_stack_api
 def providers():
     st.header("🔍 API Providers")
 
-    apis_providers_info = llama_stack_api.client.providers.list()
-    # selected_api = st.selectbox("Select an API", list(apis_providers_info.keys()))
-    for api in apis_providers_info.keys():
+    apis_providers_lst = llama_stack_api.client.providers.list()
+    api_to_providers = {}
+    for api_provider in apis_providers_lst:
+        if api_provider.api in api_to_providers:
+            api_to_providers[api_provider.api].append(api_provider)
+        else:
+            api_to_providers[api_provider.api] = [api_provider]
+
+    for api in api_to_providers.keys():
         st.markdown(f"###### {api}")
-        st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500)
+        st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
 
 
 providers()
```
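The grouping above exists because v1's `providers.list()` returns a flat list of entries (each exposing `.api` and `.provider_id`) rather than a dict keyed by API. An equivalent way to write it, sketched with `collections.defaultdict`; `ProviderEntry` is a hypothetical stand-in for the client's entry type, since the diff only relies on the `.api` attribute and `.to_dict()`:

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class ProviderEntry:  # hypothetical stand-in for the v1 client's entry type
    api: str
    provider_id: str


def group_by_api(entries):
    """Group a flat provider listing by its .api field."""
    api_to_providers = defaultdict(list)
    for entry in entries:
        api_to_providers[entry.api].append(entry)
    return dict(api_to_providers)


# usage sketch with made-up entries
entries = [
    ProviderEntry(api="inference", provider_id="meta-reference"),
    ProviderEntry(api="memory", provider_id="faiss"),
    ProviderEntry(api="memory", provider_id="chromadb"),
]
assert [p.provider_id for p in group_by_api(entries)["memory"]] == ["faiss", "chromadb"]
```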


```diff
@@ -121,7 +121,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
         if stream:
             for chunk in response:
                 if chunk.event.event_type == "progress":
-                    full_response += chunk.event.delta
+                    full_response += chunk.event.delta.text
                 message_placeholder.markdown(full_response + "▌")
             message_placeholder.markdown(full_response)
         else:
```
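This one-line change reflects that v1 streams structured deltas: `chunk.event.delta` is now an object whose text lives in `.text`, not a bare string. A self-contained sketch of the accumulation logic, with fabricated chunks shaped like the v1 progress events (the `SimpleNamespace` objects are stand-ins, not the client's real types):

```python
from types import SimpleNamespace


def accumulate_progress_text(chunks):
    """Collect the text of 'progress' events from a v1-style stream."""
    full_response = ""
    for chunk in chunks:
        if chunk.event.event_type == "progress":
            full_response += chunk.event.delta.text  # v1: delta is an object
    return full_response


# usage sketch with fake chunks mimicking the v1 event shape
fake = [
    SimpleNamespace(event=SimpleNamespace(event_type="progress",
                                          delta=SimpleNamespace(text="Hello, "))),
    SimpleNamespace(event=SimpleNamespace(event_type="progress",
                                          delta=SimpleNamespace(text="world."))),
]
assert accumulate_progress_text(fake) == "Hello, world."
```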


```diff
@@ -44,14 +44,21 @@ def rag_chat_page():
         ]
 
         providers = llama_stack_api.client.providers.list()
+        memory_provider = None
+        for x in providers:
+            if x.api == "memory":
+                memory_provider = x.provider_id
+
         llama_stack_api.client.memory_banks.register(
             memory_bank_id=memory_bank_name,  # Use the user-provided name
             params={
-                "memory_bank_type": "vector",
                 "embedding_model": "all-MiniLM-L6-v2",
                 "chunk_size_in_tokens": 512,
                 "overlap_size_in_tokens": 64,
             },
-            provider_id=providers["memory"][0].provider_id,
+            provider_id=memory_provider,
         )
 
         # insert documents using the custom bank name
```
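The old `providers["memory"][0]` indexing no longer works on the flat v1 list, hence the lookup loop. A more compact equivalent, sketched as a generator expression (one nuance: the diff's loop keeps the *last* matching provider while `next()` keeps the *first*; with a single memory provider configured the two behave identically):

```python
# Equivalent lookup sketch: pick the provider entry whose api == "memory".
# Works against the same hypothetical ProviderEntry stand-in as above.
def find_memory_provider(providers):
    return next((p.provider_id for p in providers if p.api == "memory"), None)
```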
```diff
@@ -69,9 +76,6 @@ def rag_chat_page():
         "Select Memory Banks",
         memory_banks,
     )
-    memory_bank_configs = [
-        {"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks
-    ]
 
     available_models = llama_stack_api.client.models.list()
     available_models = [
```
```diff
@@ -133,14 +137,13 @@ def rag_chat_page():
             sampling_params={
                 "strategy": strategy,
             },
-            tools=[
-                {
-                    "type": "memory",
-                    "memory_bank_configs": memory_bank_configs,
-                    "query_generator_config": {"type": "default", "sep": " "},
-                    "max_tokens_in_context": 4096,
-                    "max_chunks": 10,
-                }
+            toolgroups=[
+                dict(
+                    name="builtin::memory",
+                    args={
+                        "memory_bank_ids": [bank_id for bank_id in selected_memory_banks],
+                    },
+                )
             ],
             tool_choice="auto",
             tool_prompt_format="json",
```
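Pulling the pieces together: v1 agents take `toolgroups` (a named builtin group plus its `args`) instead of the old inline memory `tools` config, which is why `memory_bank_configs` could be deleted in the earlier hunk. A hedged sketch of the resulting config shape; `selected_memory_banks` and `strategy` are stand-in values for the page's widget state, not the playground's actual variables:

```python
selected_memory_banks = ["my_docs"]  # stand-in for the multiselect value
strategy = {"type": "greedy"}        # stand-in sampling strategy

agent_config = dict(
    sampling_params={"strategy": strategy},
    toolgroups=[
        dict(
            name="builtin::memory",
            args={"memory_bank_ids": list(selected_memory_banks)},
        )
    ],
    tool_choice="auto",
    tool_prompt_format="json",
)
```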
```diff
@@ -179,7 +182,7 @@ def rag_chat_page():
         retrieval_response = ""
         for log in EventLogger().log(response):
             log.print()
-            if log.role == "memory_retrieval":
+            if log.role == "tool_execution":
                 retrieval_response += log.content.replace("====", "").strip()
                 retrieval_message_placeholder.info(retrieval_response)
             else:
```