mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
chatcompletion & completion input type validation
This commit is contained in:
parent
e468e23249
commit
29e48cc5c1
5 changed files with 57 additions and 30 deletions
|
@ -46,6 +46,21 @@ class CustomType(BaseModel):
|
||||||
validator_class: str
|
validator_class: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionInputType(BaseModel):
|
||||||
|
# expects List[Message] for messages
|
||||||
|
type: Literal["chat_completion_input"] = "chat_completion_input"
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionInputType(BaseModel):
|
||||||
|
# expects InterleavedTextMedia for content
|
||||||
|
type: Literal["completion_input"] = "completion_input"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTurnInputType(BaseModel):
|
||||||
|
# expects List[Message] for messages (may also include attachments?)
|
||||||
|
type: Literal["agent_turn_input"] = "agent_turn_input"
|
||||||
|
|
||||||
|
|
||||||
ParamType = Annotated[
|
ParamType = Annotated[
|
||||||
Union[
|
Union[
|
||||||
StringType,
|
StringType,
|
||||||
|
@ -56,6 +71,9 @@ ParamType = Annotated[
|
||||||
JsonType,
|
JsonType,
|
||||||
UnionType,
|
UnionType,
|
||||||
CustomType,
|
CustomType,
|
||||||
|
ChatCompletionInputType,
|
||||||
|
CompletionInputType,
|
||||||
|
AgentTurnInputType,
|
||||||
],
|
],
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,10 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.apis.eval import * # noqa: F403
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
from llama_stack.apis.common.job_types import Job
|
from llama_stack.apis.common.job_types import Job
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.scoring import Scoring
|
from llama_stack.apis.scoring import Scoring
|
||||||
|
|
||||||
|
@ -44,17 +45,21 @@ class MetaReferenceEvalImpl(Eval):
|
||||||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: we will require user defined message types for ToolResponseMessage or include message.context
|
expected_schemas = [
|
||||||
# for now uses basic schema where messages={type: "user", content: "input_query"}
|
{
|
||||||
for required_column in ["expected_answer", "input_query"]:
|
"expected_answer": StringType(),
|
||||||
if required_column not in dataset_def.dataset_schema:
|
"chat_completion_input": ChatCompletionInputType(),
|
||||||
raise ValueError(
|
},
|
||||||
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
{
|
||||||
)
|
"expected_answer": StringType(),
|
||||||
if dataset_def.dataset_schema[required_column].type != "string":
|
"chat_completion_input": CompletionInputType(),
|
||||||
raise ValueError(
|
},
|
||||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
]
|
||||||
)
|
|
||||||
|
if dataset_def.dataset_schema not in expected_schemas:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||||
|
)
|
||||||
|
|
||||||
async def evaluate_batch(
|
async def evaluate_batch(
|
||||||
self,
|
self,
|
||||||
|
@ -91,13 +96,12 @@ class MetaReferenceEvalImpl(Eval):
|
||||||
)
|
)
|
||||||
generations = []
|
generations = []
|
||||||
for x in input_rows:
|
for x in input_rows:
|
||||||
input_query = x["input_query"]
|
input_messages = eval(str(x["chat_completion_input"]))
|
||||||
|
input_messages = [UserMessage(**x) for x in input_messages]
|
||||||
messages = []
|
messages = []
|
||||||
if candidate.system_message:
|
if candidate.system_message:
|
||||||
messages.append(candidate.system_message)
|
messages.append(candidate.system_message)
|
||||||
messages.append(
|
messages += input_messages
|
||||||
UserMessage(content=input_query),
|
|
||||||
)
|
|
||||||
response = await self.inference_api.chat_completion(
|
response = await self.inference_api.chat_completion(
|
||||||
model=candidate.model,
|
model=candidate.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
input_query,generated_answer,expected_answer
|
input_query,generated_answer,expected_answer,chat_completion_input
|
||||||
What is the capital of France?,London,Paris
|
What is the capital of France?,London,Paris,"[{'role': 'user', 'content': 'What is the capital of France?'}]"
|
||||||
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
|
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg,"[{'role': 'user', 'content': 'Who is the CEO of Meta?'}]"
|
||||||
What is the largest planet in our solar system?,Jupiter,Jupiter
|
What is the largest planet in our solar system?,Jupiter,Jupiter,"[{'role': 'user', 'content': 'What is the largest planet in our solar system?'}]"
|
||||||
What is the smallest country in the world?,China,Vatican City
|
What is the smallest country in the world?,China,Vatican City,"[{'role': 'user', 'content': 'What is the smallest country in the world?'}]"
|
||||||
What is the currency of Japan?,Yen,Yen
|
What is the currency of Japan?,Yen,Yen,"[{'role': 'user', 'content': 'What is the currency of Japan?'}]"
|
||||||
|
|
|
|
@ -62,17 +62,22 @@ def data_url_from_file(file_path: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
async def register_dataset(
|
async def register_dataset(
|
||||||
datasets_impl: Datasets, include_generated_answer=True, dataset_id="test_dataset"
|
datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
|
||||||
):
|
):
|
||||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||||
test_url = data_url_from_file(str(test_file))
|
test_url = data_url_from_file(str(test_file))
|
||||||
|
|
||||||
dataset_schema = {
|
if for_generation:
|
||||||
"expected_answer": StringType(),
|
dataset_schema = {
|
||||||
"input_query": StringType(),
|
"expected_answer": StringType(),
|
||||||
}
|
"chat_completion_input": ChatCompletionInputType(),
|
||||||
if include_generated_answer:
|
}
|
||||||
dataset_schema["generated_answer"] = StringType()
|
else:
|
||||||
|
dataset_schema = {
|
||||||
|
"expected_answer": StringType(),
|
||||||
|
"input_query": StringType(),
|
||||||
|
"generated_answer": StringType(),
|
||||||
|
}
|
||||||
|
|
||||||
dataset = DatasetDefWithProvider(
|
dataset = DatasetDefWithProvider(
|
||||||
identifier=dataset_id,
|
identifier=dataset_id,
|
||||||
|
|
|
@ -51,7 +51,7 @@ async def test_eval(eval_settings):
|
||||||
datasets_impl = eval_settings["datasets_impl"]
|
datasets_impl = eval_settings["datasets_impl"]
|
||||||
await register_dataset(
|
await register_dataset(
|
||||||
datasets_impl,
|
datasets_impl,
|
||||||
include_generated_answer=False,
|
for_generation=True,
|
||||||
dataset_id="test_dataset_for_eval",
|
dataset_id="test_dataset_for_eval",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue