diff --git a/source/api_definitions.py b/source/api_definitions.py
index 266bbcf3c..4dded2166 100644
--- a/source/api_definitions.py
+++ b/source/api_definitions.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Set, Union
+from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple
import yaml
from agentic_system_types import (
@@ -17,8 +17,8 @@ from model_types import (
Message,
PretrainedModel,
SamplingParams,
- StopReason,
ShieldConfig,
+ StopReason,
ToolCall,
ToolDefinition,
ToolResponse,
@@ -203,19 +203,37 @@ class AgenticSystem(Protocol):
) -> None: ...
+@dataclass
+class PromptGeneration:
+ prompt: Message
+ message_history: List[Message]
+ generation: Message
+
+
+@dataclass
+class ScoredPromptGeneration:
+ prompt_generation: PromptGeneration
+ score: float
+
+
@json_schema_type
@dataclass
class RewardScoringRequest:
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
- prompt_responses: List[Tuple(prompt:str, List[response:str])]
+ prompt_generations: List[PromptGeneration]
+
+ # TODO(ragho): create a RewardModel enum type
+ model: str
+
@json_schema_type
@dataclass
class RewardScoringResponse:
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
- prompt_responses: List[Tuple(prompt:str, List[Tuple(response:str, score:float)])]
+ scored_generations: List[ScoredPromptGeneration]
+
class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score")
@@ -224,6 +242,7 @@ class RewardScoring(Protocol):
request: RewardScoringRequest,
) -> Union[RewardScoringResponse]: ...
+
class FilteringFunction(Enum):
"""The type of filtering function."""
@@ -234,6 +253,7 @@ class FilteringFunction(Enum):
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"
+
@json_schema_type
@dataclass
class SyntheticDataGenerationRequest:
@@ -241,18 +261,23 @@ class SyntheticDataGenerationRequest:
prompts: List[str]
filtering_function: FilteringFunction = FilteringFunction.none
- reward_scoring: RewardScoring
+
+ # TODO(ragho): fix this
+ # reward_scoring: RewardScoring
+
@json_schema_type
@dataclass
class SyntheticDataGenerationResponse:
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
- synthetic_data: List[Tuple(prompt:str, response:str, score:float)]
+ synthetic_data: List[Tuple[str, str, float]]
+
"""The actual synthetic data"""
statistics: Dict[str, float]
"""Statistics on how many prompts were generated and how many were filtered out"""
+
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate")
def post_generate(
@@ -261,8 +286,9 @@ class SyntheticDataGeneration(Protocol):
) -> Union[SyntheticDataGenerationResponse]: ...
-
-class LlamaStackEndpoints(Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration): ...
+class LlamaStackEndpoints(
+ Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
+): ...
if __name__ == "__main__":
diff --git a/source/openapi.html b/source/openapi.html
index e52e21643..8f5f935d3 100644
--- a/source/openapi.html
+++ b/source/openapi.html
@@ -191,6 +191,66 @@
"required": true
}
}
+ },
+ "/synthetic_data_generation/generate": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/SyntheticDataGenerationResponse"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "SyntheticDataGeneration"
+ ],
+ "parameters": [],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/SyntheticDataGenerationRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
+ },
+ "/reward_scoring/score": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/RewardScoringResponse"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ "RewardScoring"
+ ],
+ "parameters": [],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/RewardScoringRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
}
},
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@@ -1451,6 +1511,161 @@
"text_delta"
],
"title": "streamed completion response."
+ },
+ "SyntheticDataGenerationRequest": {
+ "type": "object",
+ "properties": {
+ "prompts": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "filtering_function": {
+ "type": "string",
+ "enum": [
+ "none",
+ "random",
+ "top_k",
+ "top_p",
+ "top_k_top_p",
+ "sigmoid"
+ ],
+ "title": "The type of filtering function.",
+ "default": "none"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "prompts",
+ "filtering_function"
+ ],
+ "title": "Request to generate synthetic data. A small batch of prompts and a filtering function"
+ },
+ "SyntheticDataGenerationResponse": {
+ "type": "object",
+ "properties": {
+ "synthetic_data": {
+ "type": "array",
+ "items": {
+ "type": "array",
+ "minItems": 3,
+ "maxItems": 3,
+ "prefixItems": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "number"
+ }
+ ]
+ }
+ },
+ "statistics": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "number"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "synthetic_data",
+ "statistics"
+ ],
+ "title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
+ },
+ "RewardScoringRequest": {
+ "type": "object",
+ "properties": {
+ "prompt_generations": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "prompt": {
+ "$ref": "#/components/schemas/Message"
+ },
+ "message_history": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Message"
+ }
+ },
+ "generation": {
+ "$ref": "#/components/schemas/Message"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "prompt",
+ "message_history",
+ "generation"
+ ]
+ }
+ },
+ "model": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "prompt_generations",
+ "model"
+ ],
+ "title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
+ },
+ "RewardScoringResponse": {
+ "type": "object",
+ "properties": {
+ "scored_generations": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "prompt_generation": {
+ "type": "object",
+ "properties": {
+ "prompt": {
+ "$ref": "#/components/schemas/Message"
+ },
+ "message_history": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Message"
+ }
+ },
+ "generation": {
+ "$ref": "#/components/schemas/Message"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "prompt",
+ "message_history",
+ "generation"
+ ]
+ },
+ "score": {
+ "type": "number"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "prompt_generation",
+ "score"
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "scored_generations"
+ ],
+ "title": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."
}
},
"responses": {}
@@ -1462,11 +1677,17 @@
],
"tags": [
{
- "name": "AgenticSystem"
+ "name": "RewardScoring"
},
{
"name": "Inference"
},
+ {
+ "name": "SyntheticDataGeneration"
+ },
+ {
+ "name": "AgenticSystem"
+ },
{
"name": "ShieldConfig",
"description": ""
@@ -1530,6 +1751,22 @@
{
"name": "CompletionResponseStreamChunk",
"description": "streamed completion response.\n\n"
+ },
+ {
+ "name": "SyntheticDataGenerationRequest",
+ "description": "Request to generate synthetic data. A small batch of prompts and a filtering function\n\n"
+ },
+ {
+ "name": "SyntheticDataGenerationResponse",
+ "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n"
+ },
+ {
+ "name": "RewardScoringRequest",
+ "description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n"
+ },
+ {
+ "name": "RewardScoringResponse",
+ "description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n"
}
],
"x-tagGroups": [
@@ -1537,7 +1774,9 @@
"name": "Operations",
"tags": [
"AgenticSystem",
- "Inference"
+ "Inference",
+ "RewardScoring",
+ "SyntheticDataGeneration"
]
},
{
@@ -1557,7 +1796,11 @@
"CompletionResponse",
"CompletionResponseStreamChunk",
"Message",
+ "RewardScoringRequest",
+ "RewardScoringResponse",
"ShieldConfig",
+ "SyntheticDataGenerationRequest",
+ "SyntheticDataGenerationResponse",
"URL"
]
}
diff --git a/source/openapi.yaml b/source/openapi.yaml
index f43164410..02e85beed 100644
--- a/source/openapi.yaml
+++ b/source/openapi.yaml
@@ -750,6 +750,70 @@ components:
- tool_calls
- tool_responses
type: object
+ RewardScoringRequest:
+ additionalProperties: false
+ properties:
+ model:
+ type: string
+ prompt_generations:
+ items:
+ additionalProperties: false
+ properties:
+ generation:
+ $ref: '#/components/schemas/Message'
+ message_history:
+ items:
+ $ref: '#/components/schemas/Message'
+ type: array
+ prompt:
+ $ref: '#/components/schemas/Message'
+ required:
+ - prompt
+ - message_history
+ - generation
+ type: object
+ type: array
+ required:
+ - prompt_generations
+ - model
+ title: Request to score a reward function. A list of prompts and a list of responses
+ per prompt.
+ type: object
+ RewardScoringResponse:
+ additionalProperties: false
+ properties:
+ scored_generations:
+ items:
+ additionalProperties: false
+ properties:
+ prompt_generation:
+ additionalProperties: false
+ properties:
+ generation:
+ $ref: '#/components/schemas/Message'
+ message_history:
+ items:
+ $ref: '#/components/schemas/Message'
+ type: array
+ prompt:
+ $ref: '#/components/schemas/Message'
+ required:
+ - prompt
+ - message_history
+ - generation
+ type: object
+ score:
+ type: number
+ required:
+ - prompt_generation
+ - score
+ type: object
+ type: array
+ required:
+ - scored_generations
+ title: Response from the reward scoring. Batch of (prompt, response, score)
+ tuples that pass the threshold.
+ type: object
ShieldConfig:
additionalProperties: false
properties:
@@ -774,6 +838,53 @@ components:
- shield_type
- params
type: object
+ SyntheticDataGenerationRequest:
+ additionalProperties: false
+ properties:
+ filtering_function:
+ default: none
+ enum:
+ - none
+ - random
+ - top_k
+ - top_p
+ - top_k_top_p
+ - sigmoid
+ title: The type of filtering function.
+ type: string
+ prompts:
+ items:
+ type: string
+ type: array
+ required:
+ - prompts
+ - filtering_function
+ title: Request to generate synthetic data. A small batch of prompts and a filtering
+ function
+ type: object
+ SyntheticDataGenerationResponse:
+ additionalProperties: false
+ properties:
+ statistics:
+ additionalProperties:
+ type: number
+ type: object
+ synthetic_data:
+ items:
+ maxItems: 3
+ minItems: 3
+ prefixItems:
+ - type: string
+ - type: string
+ - type: number
+ type: array
+ type: array
+ required:
+ - synthetic_data
+ - statistics
+ title: Response from the synthetic data generation. Batch of (prompt, response,
+ score) tuples that pass the threshold.
+ type: object
URL:
format: uri
pattern: ^(https?://|file://|data:)
@@ -878,13 +989,51 @@ paths:
description: Normal completion response. **OR** streamed completion response.
tags:
- Inference
+ /reward_scoring/score:
+ post:
+ parameters: []
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RewardScoringRequest'
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/RewardScoringResponse'
+ description: OK
+ tags:
+ - RewardScoring
+ /synthetic_data_generation/generate:
+ post:
+ parameters: []
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SyntheticDataGenerationRequest'
+ required: true
+ responses:
+ '200':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SyntheticDataGenerationResponse'
+ description: OK
+ tags:
+ - SyntheticDataGeneration
security:
- Default: []
servers:
- url: http://llama.meta.com
tags:
-- name: AgenticSystem
+- name: RewardScoring
- name: Inference
+- name: SyntheticDataGeneration
+- name: AgenticSystem
- description:
name: ShieldConfig
- description: '
name: CompletionResponseStreamChunk
+- description: 'Request to generate synthetic data. A small batch of prompts and a
+ filtering function
+
+
+ '
+ name: SyntheticDataGenerationRequest
+- description: 'Response from the synthetic data generation. Batch of (prompt, response,
+ score) tuples that pass the threshold.
+
+
+ '
+ name: SyntheticDataGenerationResponse
+- description: 'Request to score a reward function. A list of prompts and a list of
+ responses per prompt.
+
+
+ '
+ name: RewardScoringRequest
+- description: 'Response from the reward scoring. Batch of (prompt, response, score)
+ tuples that pass the threshold.
+
+
+ '
+ name: RewardScoringResponse
x-tagGroups:
- name: Operations
tags:
- AgenticSystem
- Inference
+ - RewardScoring
+ - SyntheticDataGeneration
- name: Types
tags:
- AgenticSystemCreateRequest
@@ -973,5 +1150,9 @@ x-tagGroups:
- CompletionResponse
- CompletionResponseStreamChunk
- Message
+ - RewardScoringRequest
+ - RewardScoringResponse
- ShieldConfig
+ - SyntheticDataGenerationRequest
+ - SyntheticDataGenerationResponse
- URL