chatcompletion & completion input type validation

This commit is contained in:
Xi Yan 2024-10-24 16:11:25 -07:00
parent e468e23249
commit 29e48cc5c1
5 changed files with 57 additions and 30 deletions

View file

@ -46,6 +46,21 @@ class CustomType(BaseModel):
validator_class: str
class ChatCompletionInputType(BaseModel):
# expects List[Message] for messages
type: Literal["chat_completion_input"] = "chat_completion_input"
class CompletionInputType(BaseModel):
# expects InterleavedTextMedia for content
type: Literal["completion_input"] = "completion_input"
class AgentTurnInputType(BaseModel):
# expects List[Message] for messages (may also include attachments?)
type: Literal["agent_turn_input"] = "agent_turn_input"
ParamType = Annotated[
Union[
StringType,
@ -56,6 +71,9 @@ ParamType = Annotated[
JsonType,
UnionType,
CustomType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
],
Field(discriminator="type"),
]

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
@ -44,17 +45,21 @@ class MetaReferenceEvalImpl(Eval):
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
# TODO: we will require user defined message types for ToolResponseMessage or include message.context
# for now uses basic schema where messages={type: "user", content: "input_query"}
for required_column in ["expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
expected_schemas = [
{
"expected_answer": StringType(),
"chat_completion_input": ChatCompletionInputType(),
},
{
"expected_answer": StringType(),
"chat_completion_input": CompletionInputType(),
},
]
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def evaluate_batch(
self,
@ -91,13 +96,12 @@ class MetaReferenceEvalImpl(Eval):
)
generations = []
for x in input_rows:
input_query = x["input_query"]
input_messages = eval(str(x["chat_completion_input"]))
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages.append(
UserMessage(content=input_query),
)
messages += input_messages
response = await self.inference_api.chat_completion(
model=candidate.model,
messages=messages,

View file

@ -1,6 +1,6 @@
input_query,generated_answer,expected_answer
What is the capital of France?,London,Paris
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,Jupiter,Jupiter
What is the smallest country in the world?,China,Vatican City
What is the currency of Japan?,Yen,Yen
input_query,generated_answer,expected_answer,chat_completion_input
What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"

1 input_query generated_answer expected_answer chat_completion_input
2 What is the capital of France? London Paris [{'role': 'user', 'content': 'What is the capital of France?'}]
3 Who is the CEO of Meta? Mark Zuckerberg Mark Zuckerberg [{'role': 'user', 'content': 'Who is the CEO of Meta?'}]
4 What is the largest planet in our solar system? Jupiter Jupiter [{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]
5 What is the smallest country in the world? China Vatican City [{'role': 'user', 'content': 'What is the smallest country in the world?'}]
6 What is the currency of Japan? Yen Yen [{'role': 'user', 'content': 'What is the currency of Japan?'}]

View file

@ -62,17 +62,22 @@ def data_url_from_file(file_path: str) -> str:
async def register_dataset(
datasets_impl: Datasets, include_generated_answer=True, dataset_id="test_dataset"
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
}
if include_generated_answer:
dataset_schema["generated_answer"] = StringType()
if for_generation:
dataset_schema = {
"expected_answer": StringType(),
"chat_completion_input": ChatCompletionInputType(),
}
else:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
}
dataset = DatasetDefWithProvider(
identifier=dataset_id,

View file

@ -51,7 +51,7 @@ async def test_eval(eval_settings):
datasets_impl = eval_settings["datasets_impl"]
await register_dataset(
datasets_impl,
include_generated_answer=False,
for_generation=True,
dataset_id="test_dataset_for_eval",
)