This commit is contained in:
Xi Yan 2024-12-02 13:06:36 -08:00
parent 9bceb1912e
commit 7f2ed9622c
4 changed files with 8 additions and 12 deletions

View file

@ -22,14 +22,6 @@ class LlamaStackApi:
}, },
) )
def list_scoring_functions(self):
"""List all available scoring functions"""
return self.client.scoring_functions.list()
def list_models(self):
"""List all available judge models"""
return self.client.models.list()
def run_scoring( def run_scoring(
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict] self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
): ):

View file

@ -40,7 +40,7 @@ def application_evaluation_page():
# Select Scoring Functions to Run Evaluation On # Select Scoring Functions to Run Evaluation On
st.subheader("Select Scoring Functions") st.subheader("Select Scoring Functions")
scoring_functions = llama_stack_api.list_scoring_functions() scoring_functions = llama_stack_api.client.scoring_functions.list()
scoring_functions = {sf.identifier: sf for sf in scoring_functions} scoring_functions = {sf.identifier: sf for sf in scoring_functions}
scoring_functions_names = list(scoring_functions.keys()) scoring_functions_names = list(scoring_functions.keys())
selected_scoring_functions = st.multiselect( selected_scoring_functions = st.multiselect(
@ -49,7 +49,7 @@ def application_evaluation_page():
help="Choose one or more scoring functions.", help="Choose one or more scoring functions.",
) )
available_models = llama_stack_api.list_models() available_models = llama_stack_api.client.models.list()
available_models = [m.identifier for m in available_models] available_models = [m.identifier for m in available_models]
scoring_params = {} scoring_params = {}

View file

@ -10,7 +10,7 @@ from modules.api import llama_stack_api
# Sidebar configurations # Sidebar configurations
with st.sidebar: with st.sidebar:
st.header("Configuration") st.header("Configuration")
available_models = llama_stack_api.list_models() available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models] available_models = [model.identifier for model in available_models]
selected_model = st.selectbox( selected_model = st.selectbox(
"Choose a model", "Choose a model",

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
llm_as_judge_base = ScoringFn( llm_as_judge_base = ScoringFn(
@ -14,4 +14,8 @@ llm_as_judge_base = ScoringFn(
return_type=NumberType(), return_type=NumberType(),
provider_id="llm-as-judge", provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-base", provider_resource_id="llm-as-judge-base",
params=LLMAsJudgeScoringFnParams(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template="Enter custom LLM as Judge Prompt Template",
),
) )