Author: Xi Yan
Date:   2024-10-07 17:27:06 -07:00
Parent: 4764762dd4
Commit: 5b7d24b1c3
3 changed files with 14 additions and 16 deletions

File 1 of 3:

@@ -44,7 +44,7 @@ async def run_main(host: str, port: int):
     # CustomDataset
     response = await client.run_evals(
-        "Llama3.1-8B-Instruct",
+        "Llama3.2-1B-Instruct",
         "mmlu-simple-eval-en",
         "mmlu",
     )
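
Note: the hunk above only swaps the model string passed to client.run_evals. As a minimal, self-contained sketch of how such a call site might look end to end; the StubEvalsClient class, endpoint shape, and port below are hypothetical stand-ins, since the diff shows nothing beyond the run_evals(model, dataset, task) call:

    import asyncio

    # Hypothetical stand-in for the evals client used in the example;
    # the real class name and transport are not visible in this diff.
    class StubEvalsClient:
        def __init__(self, base_url: str):
            self.base_url = base_url

        async def run_evals(self, model: str, dataset: str, task: str) -> dict:
            # A real client would POST these to the evals endpoint.
            return {"model": model, "dataset": dataset, "task": task}

    async def run_main(host: str, port: int):
        client = StubEvalsClient(f"http://{host}:{port}")
        response = await client.run_evals(
            "Llama3.2-1B-Instruct",   # model the example switches to in this commit
            "mmlu-simple-eval-en",    # dataset name from the hunk
            "mmlu",                   # task name from the hunk
        )
        print(response)

    if __name__ == "__main__":
        asyncio.run(run_main("localhost", 5000))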

File 2 of 3:

@@ -14,10 +14,6 @@ from llama_stack.providers.impls.meta_reference.evals.datas.dataset_registry import (
     get_dataset,
 )
-# from llama_stack.providers.impls.meta_reference.evals.tasks.task_registry import (
-#     get_task,
-# )
 from .config import MetaReferenceEvalsImplConfig

@@ -45,7 +41,9 @@ class MetaReferenceEvalsImpl(Evals):
         # TODO: replace w/ batch inference & async return eval job
         generation_outputs = []
+        print("generation start")
         for msg in x1[:5]:
+            print("generation for msg: ", msg)
             response = self.inference_api.chat_completion(
                 model=model,
                 messages=[msg],
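
This file drops a commented-out task_registry import and instruments the generation loop with prints. A self-contained sketch of what that loop does, assuming an awaitable chat_completion; the StubInferenceApi, response shape, and sample inputs are assumptions, since the hunk shows neither the Inference API surface nor how the response is consumed:

    import asyncio

    # Stand-in inference API; the real provider's signature and response
    # shape are assumptions beyond what the hunk shows.
    class StubInferenceApi:
        async def chat_completion(self, model: str, messages: list) -> dict:
            return {"model": model, "completion": f"echo: {messages[0]}"}

    async def generate(inference_api, model: str, x1: list) -> list:
        generation_outputs = []
        print("generation start")
        for msg in x1[:5]:  # same 5-message cap as in the hunk
            print("generation for msg: ", msg)
            response = await inference_api.chat_completion(model=model, messages=[msg])
            generation_outputs.append(response)
        return generation_outputs

    if __name__ == "__main__":
        outs = asyncio.run(
            generate(StubInferenceApi(), "Llama3.2-1B-Instruct", ["q1", "q2", "q3"])
        )
        print(len(outs))  # -> 3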

File 3 of 3:

@@ -39,18 +39,18 @@ api_providers:
     config: {}
 routing_table:
   inference:
-    # - provider_type: meta-reference
-    #   config:
-    #     model: Llama3.2-1B-Instruct
-    #     quantization: null
-    #     torch_seed: null
-    #     max_seq_len: 4096
-    #     max_batch_size: 1
-    #   routing_key: Llama3.2-1B-Instruct
-    - provider_type: remote::tgi
+    - provider_type: meta-reference
       config:
-        url: http://127.0.0.1:5009
-      routing_key: Llama3.1-8B-Instruct
+        model: Llama3.2-1B-Instruct
+        quantization: null
+        torch_seed: null
+        max_seq_len: 4096
+        max_batch_size: 1
+      routing_key: Llama3.2-1B-Instruct
+    # - provider_type: remote::tgi
+    #   config:
+    #     url: http://127.0.0.1:5009
+    #   routing_key: Llama3.1-8B-Instruct
   safety:
     - provider_type: meta-reference
       config:
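
The config change flips the inference routing entry from a remote TGI endpoint to the local meta-reference provider serving Llama3.2-1B-Instruct. As a quick sanity check (a sketch, requires PyYAML) that the new fragment parses the way the diff intends; the keys and values are copied from the hunk:

    import yaml  # pip install pyyaml

    fragment = """
    routing_table:
      inference:
        - provider_type: meta-reference
          config:
            model: Llama3.2-1B-Instruct
            quantization: null
            torch_seed: null
            max_seq_len: 4096
            max_batch_size: 1
          routing_key: Llama3.2-1B-Instruct
    """

    entry = yaml.safe_load(fragment)["routing_table"]["inference"][0]
    assert entry["provider_type"] == "meta-reference"
    assert entry["config"]["max_seq_len"] == 4096
    print(entry["routing_key"])  # -> Llama3.2-1B-Instruct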