mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
address comments
This commit is contained in:
parent
6b0baa6d53
commit
52fe165db8
2 changed files with 50 additions and 26 deletions
|
@ -3,6 +3,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
from enum import Enum
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
|
@ -16,6 +17,13 @@ from llama_stack.apis.scoring import Scoring
|
||||||
from .config import MetaReferenceEvalConfig
|
from .config import MetaReferenceEvalConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnName(Enum):
|
||||||
|
expected_answer = "expected_answer"
|
||||||
|
chat_completion_input = "chat_completion_input"
|
||||||
|
completion_input = "completion_input"
|
||||||
|
generated_answer = "generated_answer"
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceEvalImpl(Eval):
|
class MetaReferenceEvalImpl(Eval):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -41,18 +49,16 @@ class MetaReferenceEvalImpl(Eval):
|
||||||
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
|
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
|
||||||
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
||||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||||
raise ValueError(
|
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_schemas = [
|
expected_schemas = [
|
||||||
{
|
{
|
||||||
"expected_answer": StringType(),
|
ColumnName.expected_answer.value: StringType(),
|
||||||
"chat_completion_input": ChatCompletionInputType(),
|
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"expected_answer": StringType(),
|
ColumnName.expected_answer.value: StringType(),
|
||||||
"chat_completion_input": CompletionInputType(),
|
ColumnName.completion_input.value: CompletionInputType(),
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -94,27 +100,43 @@ class MetaReferenceEvalImpl(Eval):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Evaluation with generation has not been implemented for agents"
|
"Evaluation with generation has not been implemented for agents"
|
||||||
)
|
)
|
||||||
|
assert (
|
||||||
|
candidate.sampling_params.max_tokens is not None
|
||||||
|
), "SamplingParams.max_tokens must be provided"
|
||||||
|
|
||||||
generations = []
|
generations = []
|
||||||
for x in input_rows:
|
for x in input_rows:
|
||||||
if "completion_input" in x:
|
if ColumnName.completion_input.value in x:
|
||||||
raise NotImplementedError(
|
input_content = eval(str(x[ColumnName.completion_input.value]))
|
||||||
"Evaluation with completion API has not been implemented"
|
response = await self.inference_api.completion(
|
||||||
|
model=candidate.model,
|
||||||
|
content=input_content,
|
||||||
|
sampling_params=candidate.sampling_params,
|
||||||
)
|
)
|
||||||
|
generations.append(
|
||||||
input_messages = eval(str(x["chat_completion_input"]))
|
{
|
||||||
input_messages = [UserMessage(**x) for x in input_messages]
|
ColumnName.generated_answer.value: response.completion_message.content
|
||||||
messages = []
|
}
|
||||||
if candidate.system_message:
|
)
|
||||||
messages.append(candidate.system_message)
|
elif ColumnName.chat_completion_input.value in x:
|
||||||
messages += input_messages
|
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
|
||||||
response = await self.inference_api.chat_completion(
|
input_messages = [UserMessage(**x) for x in input_messages]
|
||||||
model=candidate.model,
|
messages = []
|
||||||
messages=messages,
|
if candidate.system_message:
|
||||||
sampling_params=candidate.sampling_params,
|
messages.append(candidate.system_message)
|
||||||
)
|
messages += input_messages
|
||||||
generations.append(
|
response = await self.inference_api.chat_completion(
|
||||||
{"generated_answer": response.completion_message.content}
|
model=candidate.model,
|
||||||
)
|
messages=messages,
|
||||||
|
sampling_params=candidate.sampling_params,
|
||||||
|
)
|
||||||
|
generations.append(
|
||||||
|
{
|
||||||
|
ColumnName.generated_answer.value: response.completion_message.content
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid input row")
|
||||||
|
|
||||||
# scoring with generated_answer
|
# scoring with generated_answer
|
||||||
score_input_rows = [
|
score_input_rows = [
|
||||||
|
@ -132,6 +154,8 @@ class MetaReferenceEvalImpl(Eval):
|
||||||
if job_id in self.jobs:
|
if job_id in self.jobs:
|
||||||
return JobStatus.completed
|
return JobStatus.completed
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def job_cancel(self, job_id: str) -> None:
|
async def job_cancel(self, job_id: str) -> None:
|
||||||
raise NotImplementedError("Job cancel is not implemented yet")
|
raise NotImplementedError("Job cancel is not implemented yet")
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,7 @@ async def test_eval(eval_settings):
|
||||||
response = await eval_impl.evaluate_batch(
|
response = await eval_impl.evaluate_batch(
|
||||||
dataset_id=response[0].identifier,
|
dataset_id=response[0].identifier,
|
||||||
candidate=ModelCandidate(
|
candidate=ModelCandidate(
|
||||||
model="Llama3.1-8B-Instruct",
|
model="Llama3.2-1B-Instruct",
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
),
|
),
|
||||||
scoring_functions=["subset_of"],
|
scoring_functions=["subset_of"],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue