chatcompletion & completion input type validation

This commit is contained in:
Xi Yan 2024-10-24 16:11:25 -07:00
parent e468e23249
commit 29e48cc5c1
5 changed files with 57 additions and 30 deletions

View file

@ -46,6 +46,21 @@ class CustomType(BaseModel):
validator_class: str validator_class: str
class ChatCompletionInputType(BaseModel):
    # Type tag for a chat-completion input parameter. This model carries only
    # the `type` discriminator (it is one member of the `ParamType` union) and
    # does not itself hold or validate the value.
    # The tagged value is expected to be a List[Message] (the `messages`).
    type: Literal["chat_completion_input"] = "chat_completion_input"
class CompletionInputType(BaseModel):
    # Type tag for a (non-chat) completion input parameter. This model carries
    # only the `type` discriminator (it is one member of the `ParamType` union)
    # and does not itself hold or validate the value.
    # The tagged value is expected to be InterleavedTextMedia (the `content`).
    type: Literal["completion_input"] = "completion_input"
class AgentTurnInputType(BaseModel):
    # Type tag for an agent-turn input parameter. This model carries only the
    # `type` discriminator (it is one member of the `ParamType` union) and
    # does not itself hold or validate the value.
    # The tagged value is expected to be a List[Message] (the `messages`).
    # TODO: confirm whether the value may also include attachments.
    type: Literal["agent_turn_input"] = "agent_turn_input"
ParamType = Annotated[ ParamType = Annotated[
Union[ Union[
StringType, StringType,
@ -56,6 +71,9 @@ ParamType = Annotated[
JsonType, JsonType,
UnionType, UnionType,
CustomType, CustomType,
ChatCompletionInputType,
CompletionInputType,
AgentTurnInputType,
], ],
Field(discriminator="type"), Field(discriminator="type"),
] ]

View file

@ -5,10 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403 from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.common.job_types import Job from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
@ -44,17 +45,21 @@ class MetaReferenceEvalImpl(Eval):
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
) )
# TODO: we will require user defined message types for ToolResponseMessage or include message.context expected_schemas = [
# for now uses basic schema where messages={type: "user", content: "input_query"} {
for required_column in ["expected_answer", "input_query"]: "expected_answer": StringType(),
if required_column not in dataset_def.dataset_schema: "chat_completion_input": ChatCompletionInputType(),
raise ValueError( },
f"Dataset {dataset_id} does not have a '{required_column}' column." {
) "expected_answer": StringType(),
if dataset_def.dataset_schema[required_column].type != "string": "chat_completion_input": CompletionInputType(),
raise ValueError( },
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." ]
)
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def evaluate_batch( async def evaluate_batch(
self, self,
@ -91,13 +96,12 @@ class MetaReferenceEvalImpl(Eval):
) )
generations = [] generations = []
for x in input_rows: for x in input_rows:
input_query = x["input_query"] input_messages = eval(str(x["chat_completion_input"]))
input_messages = [UserMessage(**x) for x in input_messages]
messages = [] messages = []
if candidate.system_message: if candidate.system_message:
messages.append(candidate.system_message) messages.append(candidate.system_message)
messages.append( messages += input_messages
UserMessage(content=input_query),
)
response = await self.inference_api.chat_completion( response = await self.inference_api.chat_completion(
model=candidate.model, model=candidate.model,
messages=messages, messages=messages,

View file

@ -1,6 +1,6 @@
input_query,generated_answer,expected_answer input_query,generated_answer,expected_answer,chat_completion_input
What is the capital of France?,London,Paris What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
What is the largest planet in our solar system?,Jupiter,Jupiter What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
What is the smallest country in the world?,China,Vatican City What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
What is the currency of Japan?,Yen,Yen What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"

1 input_query generated_answer expected_answer chat_completion_input
2 What is the capital of France? London Paris [{'role': 'user', 'content': 'What is the capital of France?'}]
3 Who is the CEO of Meta? Mark Zuckerberg Mark Zuckerberg [{'role': 'user', 'content': 'Who is the CEO of Meta?'}]
4 What is the largest planet in our solar system? Jupiter Jupiter [{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]
5 What is the smallest country in the world? China Vatican City [{'role': 'user', 'content': 'What is the smallest country in the world?'}]
6 What is the currency of Japan? Yen Yen [{'role': 'user', 'content': 'What is the currency of Japan?'}]

View file

@ -62,17 +62,22 @@ def data_url_from_file(file_path: str) -> str:
async def register_dataset( async def register_dataset(
datasets_impl: Datasets, include_generated_answer=True, dataset_id="test_dataset" datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
): ):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file)) test_url = data_url_from_file(str(test_file))
dataset_schema = { if for_generation:
"expected_answer": StringType(), dataset_schema = {
"input_query": StringType(), "expected_answer": StringType(),
} "chat_completion_input": ChatCompletionInputType(),
if include_generated_answer: }
dataset_schema["generated_answer"] = StringType() else:
dataset_schema = {
"expected_answer": StringType(),
"input_query": StringType(),
"generated_answer": StringType(),
}
dataset = DatasetDefWithProvider( dataset = DatasetDefWithProvider(
identifier=dataset_id, identifier=dataset_id,

View file

@ -51,7 +51,7 @@ async def test_eval(eval_settings):
datasets_impl = eval_settings["datasets_impl"] datasets_impl = eval_settings["datasets_impl"]
await register_dataset( await register_dataset(
datasets_impl, datasets_impl,
include_generated_answer=False, for_generation=True,
dataset_id="test_dataset_for_eval", dataset_id="test_dataset_for_eval",
) )