mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
chatcompletion & completion input type validation
This commit is contained in:
parent
e468e23249
commit
29e48cc5c1
5 changed files with 57 additions and 30 deletions
|
@ -46,6 +46,21 @@ class CustomType(BaseModel):
|
|||
validator_class: str
|
||||
|
||||
|
||||
class ChatCompletionInputType(BaseModel):
|
||||
# expects List[Message] for messages
|
||||
type: Literal["chat_completion_input"] = "chat_completion_input"
|
||||
|
||||
|
||||
class CompletionInputType(BaseModel):
|
||||
# expects InterleavedTextMedia for content
|
||||
type: Literal["completion_input"] = "completion_input"
|
||||
|
||||
|
||||
class AgentTurnInputType(BaseModel):
|
||||
# expects List[Message] for messages (may also include attachments?)
|
||||
type: Literal["agent_turn_input"] = "agent_turn_input"
|
||||
|
||||
|
||||
ParamType = Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
|
@ -56,6 +71,9 @@ ParamType = Annotated[
|
|||
JsonType,
|
||||
UnionType,
|
||||
CustomType,
|
||||
ChatCompletionInputType,
|
||||
CompletionInputType,
|
||||
AgentTurnInputType,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
||||
|
@ -44,17 +45,21 @@ class MetaReferenceEvalImpl(Eval):
|
|||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||
)
|
||||
|
||||
# TODO: we will require user defined message types for ToolResponseMessage or include message.context
|
||||
# for now uses basic schema where messages={type: "user", content: "input_query"}
|
||||
for required_column in ["expected_answer", "input_query"]:
|
||||
if required_column not in dataset_def.dataset_schema:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
||||
)
|
||||
if dataset_def.dataset_schema[required_column].type != "string":
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
expected_schemas = [
|
||||
{
|
||||
"expected_answer": StringType(),
|
||||
"chat_completion_input": ChatCompletionInputType(),
|
||||
},
|
||||
{
|
||||
"expected_answer": StringType(),
|
||||
"chat_completion_input": CompletionInputType(),
|
||||
},
|
||||
]
|
||||
|
||||
if dataset_def.dataset_schema not in expected_schemas:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||
)
|
||||
|
||||
async def evaluate_batch(
|
||||
self,
|
||||
|
@ -91,13 +96,12 @@ class MetaReferenceEvalImpl(Eval):
|
|||
)
|
||||
generations = []
|
||||
for x in input_rows:
|
||||
input_query = x["input_query"]
|
||||
input_messages = eval(str(x["chat_completion_input"]))
|
||||
input_messages = [UserMessage(**x) for x in input_messages]
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages.append(
|
||||
UserMessage(content=input_query),
|
||||
)
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
input_query,generated_answer,expected_answer
|
||||
What is the capital of France?,London,Paris
|
||||
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
|
||||
What is the largest planet in our solar system?,Jupiter,Jupiter
|
||||
What is the smallest country in the world?,China,Vatican City
|
||||
What is the currency of Japan?,Yen,Yen
|
||||
input_query,generated_answer,expected_answer,chat_completion_input
|
||||
What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
|
||||
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
|
||||
What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
|
||||
What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
|
||||
What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"
|
||||
|
|
|
|
@ -62,17 +62,22 @@ def data_url_from_file(file_path: str) -> str:
|
|||
|
||||
|
||||
async def register_dataset(
|
||||
datasets_impl: Datasets, include_generated_answer=True, dataset_id="test_dataset"
|
||||
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
|
||||
):
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
|
||||
dataset_schema = {
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
}
|
||||
if include_generated_answer:
|
||||
dataset_schema["generated_answer"] = StringType()
|
||||
if for_generation:
|
||||
dataset_schema = {
|
||||
"expected_answer": StringType(),
|
||||
"chat_completion_input": ChatCompletionInputType(),
|
||||
}
|
||||
else:
|
||||
dataset_schema = {
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
"generated_answer": StringType(),
|
||||
}
|
||||
|
||||
dataset = DatasetDefWithProvider(
|
||||
identifier=dataset_id,
|
||||
|
|
|
@ -51,7 +51,7 @@ async def test_eval(eval_settings):
|
|||
datasets_impl = eval_settings["datasets_impl"]
|
||||
await register_dataset(
|
||||
datasets_impl,
|
||||
include_generated_answer=False,
|
||||
for_generation=True,
|
||||
dataset_id="test_dataset_for_eval",
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue