docs: api documentation for agents/eval/scoring/datasets (#1400)

# What does this PR do?

- add some docs to OpenAPI for agents/eval/scoring/datasetio

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
- read

[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-03-05 09:40:24 -08:00 committed by GitHub
parent 0d18274d34
commit 3d9331840e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 586 additions and 137 deletions

View file

@ -41,16 +41,36 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
mime_type: str
class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str
step_id: str
started_at: Optional[datetime] = None
@ -58,6 +78,14 @@ class StepCommon(BaseModel):
class StepType(Enum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
"""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
@ -66,6 +94,11 @@ class StepType(Enum):
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
@ -74,6 +107,12 @@ class InferenceStep(StepCommon):
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
@ -81,13 +120,25 @@ class ToolExecutionStep(StepCommon):
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
violation: Optional[SafetyViolation]
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
@ -335,7 +386,13 @@ class Agents(Protocol):
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
async def create_agent_turn(
@ -352,7 +409,19 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
"""
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@ -388,7 +457,15 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn: ...
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
@ -400,14 +477,30 @@ class Agents(Protocol):
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse: ...
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse: ...
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
async def get_agents_session(
@ -415,17 +508,35 @@ class Agents(Protocol):
session_id: str,
agent_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None: ...
) -> None:
"""Delete an agent session by its ID.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE")
async def delete_agent(
self,
agent_id: str,
) -> None: ...
) -> None:
"""Delete an agent by its ID.
:param agent_id: The ID of the agent to delete.
"""
...

View file

@ -14,6 +14,14 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class PaginatedRowsResult(BaseModel):
"""
A paginated list of rows from a dataset.
:param rows: The rows in the current page.
:param total_count: The total number of rows in the dataset.
:param next_page_token: The token to get the next page of rows.
"""
# the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]]
total_count: int
@ -36,7 +44,15 @@ class DatasetIO(Protocol):
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...
) -> PaginatedRowsResult:
"""Get a paginated list of rows from a dataset.
:param dataset_id: The ID of the dataset to get the rows from.
:param rows_in_page: The number of rows to get per page.
:param page_token: The token to get the next page of rows.
:param filter_condition: (Optional) A condition to filter the rows by.
"""
...
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...

View file

@ -19,6 +19,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
@ -27,6 +34,11 @@ class ModelCandidate(BaseModel):
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
@ -39,6 +51,13 @@ EvalCandidate = register_schema(
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
@ -53,18 +72,32 @@ class BenchmarkConfig(BaseModel):
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job: ...
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
@ -73,13 +106,40 @@ class Eval(Protocol):
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse: ...
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""

View file

@ -17,6 +17,13 @@ ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@ -30,6 +37,12 @@ class ScoreBatchResponse(BaseModel):
@json_schema_type
class ScoreResponse(BaseModel):
"""
The response from scoring.
:param results: A map of scoring function name to ScoringResult.
"""
# each key in the dict is a scoring function name
results: Dict[str, ScoringResult]
@ -55,4 +68,11 @@ class Scoring(Protocol):
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ...
) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
"""
...