mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
fix playground for v1 (#799)
# What does this PR do? - update playground callsites for v1 api changes ## Test Plan ``` cd llama_stack/distribution/ui streamlit run app.py ``` https://github.com/user-attachments/assets/eace11c6-600a-42dc-b4e7-6948a706509f ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Ran pre-commit to handle lint / formatting issues. - [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent
b2ac29b9da
commit
9d574f4aee
5 changed files with 35 additions and 25 deletions
|
@ -14,6 +14,6 @@ def datasets():
|
|||
datasets_info = {
|
||||
d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
|
||||
}
|
||||
|
||||
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
|
||||
st.json(datasets_info[selected_dataset], expanded=True)
|
||||
if len(datasets_info) > 0:
|
||||
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
|
||||
st.json(datasets_info[selected_dataset], expanded=True)
|
||||
|
|
|
@ -16,7 +16,8 @@ def eval_tasks():
|
|||
d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
|
||||
}
|
||||
|
||||
selected_eval_task = st.selectbox(
|
||||
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
|
||||
)
|
||||
st.json(eval_tasks_info[selected_eval_task], expanded=True)
|
||||
if len(eval_tasks_info) > 0:
|
||||
selected_eval_task = st.selectbox(
|
||||
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
|
||||
)
|
||||
st.json(eval_tasks_info[selected_eval_task], expanded=True)
|
||||
|
|
|
@ -10,11 +10,17 @@ from modules.api import llama_stack_api
|
|||
|
||||
def providers():
|
||||
st.header("🔍 API Providers")
|
||||
apis_providers_info = llama_stack_api.client.providers.list()
|
||||
# selected_api = st.selectbox("Select an API", list(apis_providers_info.keys()))
|
||||
for api in apis_providers_info.keys():
|
||||
apis_providers_lst = llama_stack_api.client.providers.list()
|
||||
api_to_providers = {}
|
||||
for api_provider in apis_providers_lst:
|
||||
if api_provider.api in api_to_providers:
|
||||
api_to_providers[api_provider.api].append(api_provider)
|
||||
else:
|
||||
api_to_providers[api_provider.api] = [api_provider]
|
||||
|
||||
for api in api_to_providers.keys():
|
||||
st.markdown(f"###### {api}")
|
||||
st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500)
|
||||
st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
|
||||
|
||||
|
||||
providers()
|
||||
|
|
|
@ -121,7 +121,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
|
|||
if stream:
|
||||
for chunk in response:
|
||||
if chunk.event.event_type == "progress":
|
||||
full_response += chunk.event.delta
|
||||
full_response += chunk.event.delta.text
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
else:
|
||||
|
|
|
@ -44,14 +44,21 @@ def rag_chat_page():
|
|||
]
|
||||
|
||||
providers = llama_stack_api.client.providers.list()
|
||||
memory_provider = None
|
||||
|
||||
for x in providers:
|
||||
if x.api == "memory":
|
||||
memory_provider = x.provider_id
|
||||
|
||||
llama_stack_api.client.memory_banks.register(
|
||||
memory_bank_id=memory_bank_name, # Use the user-provided name
|
||||
params={
|
||||
"memory_bank_type": "vector",
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"chunk_size_in_tokens": 512,
|
||||
"overlap_size_in_tokens": 64,
|
||||
},
|
||||
provider_id=providers["memory"][0].provider_id,
|
||||
provider_id=memory_provider,
|
||||
)
|
||||
|
||||
# insert documents using the custom bank name
|
||||
|
@ -69,9 +76,6 @@ def rag_chat_page():
|
|||
"Select Memory Banks",
|
||||
memory_banks,
|
||||
)
|
||||
memory_bank_configs = [
|
||||
{"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks
|
||||
]
|
||||
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [
|
||||
|
@ -133,14 +137,13 @@ def rag_chat_page():
|
|||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
{
|
||||
"type": "memory",
|
||||
"memory_bank_configs": memory_bank_configs,
|
||||
"query_generator_config": {"type": "default", "sep": " "},
|
||||
"max_tokens_in_context": 4096,
|
||||
"max_chunks": 10,
|
||||
}
|
||||
toolgroups=[
|
||||
dict(
|
||||
name="builtin::memory",
|
||||
args={
|
||||
"memory_bank_ids": [bank_id for bank_id in selected_memory_banks],
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="json",
|
||||
|
@ -179,7 +182,7 @@ def rag_chat_page():
|
|||
retrieval_response = ""
|
||||
for log in EventLogger().log(response):
|
||||
log.print()
|
||||
if log.role == "memory_retrieval":
|
||||
if log.role == "tool_execution":
|
||||
retrieval_response += log.content.replace("====", "").strip()
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue