fix chat rag providers

Xi Yan 2025-01-16 18:53:23 -08:00
parent 9f14382d82
commit b66d2bb492
3 changed files with 27 additions and 18 deletions

@@ -10,11 +10,17 @@ from modules.api import llama_stack_api
 def providers():
     st.header("🔍 API Providers")
-    apis_providers_info = llama_stack_api.client.providers.list()
-    # selected_api = st.selectbox("Select an API", list(apis_providers_info.keys()))
-    for api in apis_providers_info.keys():
+    apis_providers_lst = llama_stack_api.client.providers.list()
+    api_to_providers = {}
+    for api_provider in apis_providers_lst:
+        if api_provider.api in api_to_providers:
+            api_to_providers[api_provider.api].append(api_provider)
+        else:
+            api_to_providers[api_provider.api] = [api_provider]
+
+    for api in api_to_providers.keys():
         st.markdown(f"###### {api}")
-        st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500)
+        st.dataframe([x.to_dict() for x in api_to_providers[api]], width=500)
 
 providers()
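The providers page previously indexed the result of providers.list() as a dict keyed by API; the endpoint now returns a flat list of provider objects, so the page rebuilds the API-to-providers mapping itself. A minimal sketch of an equivalent grouping, assuming only that each entry exposes an .api attribute as shown in the diff above:

    from collections import defaultdict

    def group_providers_by_api(providers_lst):
        # Group the flat provider list by the API each provider serves.
        api_to_providers = defaultdict(list)
        for provider in providers_lst:
            api_to_providers[provider.api].append(provider)
        return dict(api_to_providers)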

@@ -121,7 +121,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
         if stream:
             for chunk in response:
                 if chunk.event.event_type == "progress":
-                    full_response += chunk.event.delta
+                    full_response += chunk.event.delta.text
                 message_placeholder.markdown(full_response + "")
             message_placeholder.markdown(full_response)
         else:
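The streaming fix reflects that the progress event's delta evidently carries its text under a .text field rather than being a plain string. A small hypothetical helper (not part of the UI code) that tolerates both shapes during a client upgrade:

    def delta_text(delta):
        # Older clients yielded the chunk text as a plain string; newer
        # ones yield an object carrying it in a .text field.
        if isinstance(delta, str):
            return delta
        return getattr(delta, "text", "")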

@@ -44,14 +44,21 @@ def rag_chat_page():
         ]
 
         providers = llama_stack_api.client.providers.list()
+        memory_provider = None
+        for x in providers:
+            if x.api == "memory":
+                memory_provider = x.provider_id
+
         llama_stack_api.client.memory_banks.register(
             memory_bank_id=memory_bank_name,  # Use the user-provided name
             params={
+                "memory_bank_type": "vector",
                 "embedding_model": "all-MiniLM-L6-v2",
                 "chunk_size_in_tokens": 512,
                 "overlap_size_in_tokens": 64,
             },
-            provider_id=providers["memory"][0].provider_id,
+            provider_id=memory_provider,
         )
 
         # insert documents using the custom bank name
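Because providers.list() no longer returns a dict, the rag page scans the flat list for the provider registered under the "memory" API and passes its provider_id when registering the bank. An equivalent hypothetical one-liner, assuming the same .api and .provider_id attributes used in the diff:

    # Pick the first provider that serves the "memory" API, or None if absent.
    memory_provider = next(
        (p.provider_id for p in providers if p.api == "memory"), None
    )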
@@ -69,9 +76,6 @@ def rag_chat_page():
         "Select Memory Banks",
         memory_banks,
     )
-    memory_bank_configs = [
-        {"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks
-    ]
 
     available_models = llama_stack_api.client.models.list()
     available_models = [
@@ -133,14 +137,13 @@ def rag_chat_page():
         sampling_params={
             "strategy": strategy,
         },
-        tools=[
-            {
-                "type": "memory",
-                "memory_bank_configs": memory_bank_configs,
-                "query_generator_config": {"type": "default", "sep": " "},
-                "max_tokens_in_context": 4096,
-                "max_chunks": 10,
-            }
+        toolgroups=[
+            dict(
+                name="builtin::memory",
+                args={
+                    "memory_bank_ids": [bank_id for bank_id in selected_memory_banks],
+                },
+            )
         ],
         tool_choice="auto",
         tool_prompt_format="json",
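The agent configuration drops the inline "tools" memory block in favor of toolgroups, handing the user-selected bank IDs to the built-in memory tool group. A hedged sketch of how such an entry might be assembled (selected_memory_banks is assumed to be the multiselect value from earlier in the page):

    def memory_toolgroup(selected_memory_banks):
        # Build the toolgroups entry for the built-in memory tool, forwarding
        # the memory bank IDs the user picked in the sidebar.
        return dict(
            name="builtin::memory",
            args={"memory_bank_ids": list(selected_memory_banks)},
        )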
@@ -179,7 +182,7 @@ def rag_chat_page():
             retrieval_response = ""
             for log in EventLogger().log(response):
                 log.print()
-                if log.role == "memory_retrieval":
+                if log.role == "tool_execution":
                     retrieval_response += log.content.replace("====", "").strip()
                     retrieval_message_placeholder.info(retrieval_response)
                 else: