fix playground for v1 (#799)

# What does this PR do?

- update playground call sites for the v1 API changes (providers listing, streaming delta text, memory bank registration, and agent `toolgroups`)

## Test Plan

```bash
cd llama_stack/distribution/ui
streamlit run app.py
```


https://github.com/user-attachments/assets/eace11c6-600a-42dc-b4e7-6948a706509f




## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.

Commit 9d574f4aee (parent b2ac29b9da) by Xi Yan, 2025-01-16 19:32:07 -08:00, committed by GitHub.
5 changed files with 35 additions and 25 deletions

@@ -14,6 +14,6 @@ def datasets():
datasets_info = {
d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
}
if len(datasets_info) > 0:
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
st.json(datasets_info[selected_dataset], expanded=True)

@@ -16,6 +16,7 @@ def eval_tasks():
d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
}
if len(eval_tasks_info) > 0:
selected_eval_task = st.selectbox(
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
)
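
The `key="eval_task_inspect"` argument added here gives the selectbox an explicit Streamlit widget key, presumably so it does not collide with another selectbox of the same label elsewhere on the page. A standalone sketch of that Streamlit behavior (labels and options below are illustrative, not from the playground):

```python
import streamlit as st

# Two widgets that would otherwise hash to the same identity need distinct keys;
# without `key=...`, Streamlit raises a DuplicateWidgetID error for the second one.
inspect_choice = st.selectbox("Select an eval task", ["task-a", "task-b"], key="eval_task_inspect")
run_choice = st.selectbox("Select an eval task", ["task-a", "task-b"], key="eval_task_run")
```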

@@ -10,11 +10,17 @@ from modules.api import llama_stack_api
def providers():
st.header("🔍 API Providers")
apis_providers_info = llama_stack_api.client.providers.list()
# selected_api = st.selectbox("Select an API", list(apis_providers_info.keys()))
for api in apis_providers_info.keys():
apis_providers_lst = llama_stack_api.client.providers.list()
api_to_providers = {}
for api_provider in apis_providers_lst:
if api_provider.api in api_to_providers:
api_to_providers[api_provider.api].append(api_provider)
else:
api_to_providers[api_provider.api] = [api_provider]
for api in api_to_providers.keys():
st.markdown(f"###### {api}")
st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500)
st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
providers()
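
For context, the hunk above adapts to v1 returning a flat list from `client.providers.list()`, where each entry carries `.api` and `.provider_id`, instead of a mapping keyed by API. A minimal sketch of the same grouping using `collections.defaultdict`, an equivalent idiom rather than the committed code (`llama_stack_api.client` is the playground's preconfigured client):

```python
from collections import defaultdict

import streamlit as st
from modules.api import llama_stack_api

# Group the flat v1 provider list by API so each API renders as its own table.
api_to_providers = defaultdict(list)
for provider in llama_stack_api.client.providers.list():
    api_to_providers[provider.api].append(provider)

for api, providers in api_to_providers.items():
    st.markdown(f"###### {api}")
    st.dataframe([p.to_dict() for p in providers], width=500)
```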

@@ -121,7 +121,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
if stream:
for chunk in response:
if chunk.event.event_type == "progress":
full_response += chunk.event.delta
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "▌")
message_placeholder.markdown(full_response)
else:


@@ -44,14 +44,21 @@ def rag_chat_page():
]
providers = llama_stack_api.client.providers.list()
memory_provider = None
for x in providers:
if x.api == "memory":
memory_provider = x.provider_id
llama_stack_api.client.memory_banks.register(
memory_bank_id=memory_bank_name, # Use the user-provided name
params={
"memory_bank_type": "vector",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size_in_tokens": 512,
"overlap_size_in_tokens": 64,
},
provider_id=providers["memory"][0].provider_id,
provider_id=memory_provider,
)
# insert documents using the custom bank name
@@ -69,9 +76,6 @@ def rag_chat_page():
"Select Memory Banks",
memory_banks,
)
memory_bank_configs = [
{"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks
]
available_models = llama_stack_api.client.models.list()
available_models = [
@@ -133,14 +137,13 @@ def rag_chat_page():
sampling_params={
"strategy": strategy,
},
tools=[
{
"type": "memory",
"memory_bank_configs": memory_bank_configs,
"query_generator_config": {"type": "default", "sep": " "},
"max_tokens_in_context": 4096,
"max_chunks": 10,
}
toolgroups=[
dict(
name="builtin::memory",
args={
"memory_bank_ids": [bank_id for bank_id in selected_memory_banks],
},
)
],
tool_choice="auto",
tool_prompt_format="json",
@@ -179,7 +182,7 @@ def rag_chat_page():
retrieval_response = ""
for log in EventLogger().log(response):
log.print()
if log.role == "memory_retrieval":
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
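
Putting the RAG-page hunks together, the v1 flow is: pick a provider that serves the `memory` API from the flat provider list, register the vector bank with that provider, and hand the selected banks to the agent through the `builtin::memory` toolgroup instead of the old `tools=[{"type": "memory", ...}]` config. A consolidated sketch under those assumptions, where `memory_bank_name` and `selected_memory_banks` stand in for the values the user picks in the page:

```python
from modules.api import llama_stack_api  # playground helper wrapping the client

client = llama_stack_api.client

# 1. Find a provider that serves the "memory" API in the flat v1 provider list.
memory_provider = None
for p in client.providers.list():
    if p.api == "memory":
        memory_provider = p.provider_id

# 2. Register a vector memory bank with that provider.
client.memory_banks.register(
    memory_bank_id=memory_bank_name,  # user-provided bank name
    params={
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
    },
    provider_id=memory_provider,
)

# 3. Point the agent at the selected banks via the builtin memory toolgroup.
toolgroups = [
    dict(
        name="builtin::memory",
        args={"memory_bank_ids": list(selected_memory_banks)},
    )
]
```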