diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/distribution/ui/page/distribution/datasets.py
index 44e314cde..b52356522 100644
--- a/llama_stack/distribution/ui/page/distribution/datasets.py
+++ b/llama_stack/distribution/ui/page/distribution/datasets.py
@@ -14,6 +14,6 @@ def datasets():
     datasets_info = {
         d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
     }
-
-    selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
-    st.json(datasets_info[selected_dataset], expanded=True)
+    if len(datasets_info) > 0:
+        selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
+        st.json(datasets_info[selected_dataset], expanded=True)
diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py
index 4957fb178..cc7912838 100644
--- a/llama_stack/distribution/ui/page/distribution/eval_tasks.py
+++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py
@@ -16,7 +16,8 @@ def eval_tasks():
         d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
     }
 
-    selected_eval_task = st.selectbox(
-        "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
-    )
-    st.json(eval_tasks_info[selected_eval_task], expanded=True)
+    if len(eval_tasks_info) > 0:
+        selected_eval_task = st.selectbox(
+            "Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
+        )
+        st.json(eval_tasks_info[selected_eval_task], expanded=True)
diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py
index 69f6bd771..9aeb7f2a5 100644
--- a/llama_stack/distribution/ui/page/distribution/providers.py
+++ b/llama_stack/distribution/ui/page/distribution/providers.py
@@ -10,11 +10,17 @@ from modules.api import llama_stack_api
 
 def providers():
     st.header("🔍 API Providers")
-    apis_providers_info = llama_stack_api.client.providers.list()
-    # selected_api = st.selectbox("Select an API", list(apis_providers_info.keys()))
-    for api in apis_providers_info.keys():
+    apis_providers_lst = llama_stack_api.client.providers.list()
+    api_to_providers = {}
+    for api_provider in apis_providers_lst:
+        if api_provider.api in api_to_providers:
+            api_to_providers[api_provider.api].append(api_provider)
+        else:
+            api_to_providers[api_provider.api] = [api_provider]
+
+    for api in api_to_providers.keys():
         st.markdown(f"###### {api}")
-        st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500)
+        st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
 
 
 providers()
diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py
index 5d91ec819..cb9990b7c 100644
--- a/llama_stack/distribution/ui/page/playground/chat.py
+++ b/llama_stack/distribution/ui/page/playground/chat.py
@@ -121,7 +121,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
         if stream:
             for chunk in response:
                 if chunk.event.event_type == "progress":
-                    full_response += chunk.event.delta
+                    full_response += chunk.event.delta.text
                 message_placeholder.markdown(full_response + "▌")
             message_placeholder.markdown(full_response)
         else:
diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py
index 3a2ba1270..11b05718d 100644
--- a/llama_stack/distribution/ui/page/playground/rag.py
+++ b/llama_stack/distribution/ui/page/playground/rag.py
@@ -44,14 +44,21 @@ def rag_chat_page():
                 ]
 
                 providers = llama_stack_api.client.providers.list()
+                memory_provider = None
+
+                for x in providers:
+                    if x.api == "memory":
+                        memory_provider = x.provider_id
+
                 llama_stack_api.client.memory_banks.register(
                     memory_bank_id=memory_bank_name,  # Use the user-provided name
                     params={
+                        "memory_bank_type": "vector",
                         "embedding_model": "all-MiniLM-L6-v2",
                         "chunk_size_in_tokens": 512,
                         "overlap_size_in_tokens": 64,
                     },
-                    provider_id=providers["memory"][0].provider_id,
+                    provider_id=memory_provider,
                 )
 
                 # insert documents using the custom bank name
@@ -69,9 +76,6 @@ def rag_chat_page():
             "Select Memory Banks",
             memory_banks,
         )
-        memory_bank_configs = [
-            {"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks
-        ]
 
         available_models = llama_stack_api.client.models.list()
         available_models = [
@@ -133,14 +137,13 @@ def rag_chat_page():
         sampling_params={
             "strategy": strategy,
         },
-        tools=[
-            {
-                "type": "memory",
-                "memory_bank_configs": memory_bank_configs,
-                "query_generator_config": {"type": "default", "sep": " "},
-                "max_tokens_in_context": 4096,
-                "max_chunks": 10,
-            }
+        toolgroups=[
+            dict(
+                name="builtin::memory",
+                args={
+                    "memory_bank_ids": [bank_id for bank_id in selected_memory_banks],
+                },
+            )
         ],
         tool_choice="auto",
         tool_prompt_format="json",
@@ -179,7 +182,7 @@ def rag_chat_page():
             retrieval_response = ""
             for log in EventLogger().log(response):
                 log.print()
-                if log.role == "memory_retrieval":
+                if log.role == "tool_execution":
                     retrieval_response += log.content.replace("====", "").strip()
                     retrieval_message_placeholder.info(retrieval_response)
                 else:
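
Note: both providers.py and rag.py now treat client.providers.list() as a flat list of provider objects rather than a dict keyed by API name. Below is a minimal sketch of the two access patterns introduced above, using a namedtuple as a stand-in for the client's provider type (assumed here to expose api and provider_id attributes; the sample values are illustrative only):

    from collections import defaultdict, namedtuple

    # Stand-in for the objects returned by llama_stack_api.client.providers.list();
    # an assumption for illustration, not the client's real type.
    Provider = namedtuple("Provider", ["api", "provider_id"])

    providers = [
        Provider(api="inference", provider_id="meta-reference"),
        Provider(api="memory", provider_id="faiss"),
        Provider(api="memory", provider_id="chromadb"),
    ]

    # providers.py pattern: group the flat list by API for display.
    api_to_providers = defaultdict(list)
    for p in providers:
        api_to_providers[p.api].append(p)

    # rag.py pattern: scan for a provider serving the "memory" API.
    memory_provider = None
    for p in providers:
        if p.api == "memory":
            memory_provider = p.provider_id  # keeps the last match, as in the patch

    assert [p.provider_id for p in api_to_providers["memory"]] == ["faiss", "chromadb"]
    assert memory_provider == "chromadb"

As written, the rag.py loop selects the last matching memory provider; adding a break after the assignment would select the first instead.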