mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-01 19:44:30 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -58,11 +58,6 @@ def define_eval_candidate_2():
|
|||
|
||||
# Sampling Parameters
|
||||
st.markdown("##### Sampling Parameters")
|
||||
strategy = st.selectbox(
|
||||
"Strategy",
|
||||
["greedy", "top_p", "top_k"],
|
||||
index=0,
|
||||
)
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
|
|
@ -95,13 +90,20 @@ def define_eval_candidate_2():
|
|||
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
|
||||
)
|
||||
if candidate_type == "model":
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
eval_candidate = {
|
||||
"type": "model",
|
||||
"model": selected_model,
|
||||
"sampling_params": {
|
||||
"strategy": strategy,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -95,6 +95,15 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
|
|||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
response = llama_stack_api.client.inference.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
|
|
@ -103,8 +112,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
|
|||
model_id=selected_model,
|
||||
stream=stream,
|
||||
sampling_params={
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"strategy": strategy,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -118,13 +118,20 @@ def rag_chat_page():
|
|||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
if temperature > 0.0:
|
||||
strategy = {
|
||||
"type": "top_p",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
agent_config = AgentConfig(
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": "greedy",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue