diff --git a/source/api_definitions.py b/source/api_definitions.py
index d8c144415..a1ea720c8 100644
--- a/source/api_definitions.py
+++ b/source/api_definitions.py
@@ -218,25 +218,33 @@ class AgenticSystem(Protocol):
@dataclass
-class PromptGeneration:
- # TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion
+class KPromptGenerations:
prompt: Message
message_history: List[Message]
- generation: Message
+ k_generations: List[Message]
+@json_schema_type
@dataclass
-class ScoredPromptGeneration:
- prompt_generation: PromptGeneration
+class MessageScore:
+ """A single message and its score."""
+
+ message: Message
score: float
+@json_schema_type
+@dataclass
+class KScoredPromptGenerations:
+ prompt: Message
+ k_scored_generations: List[MessageScore]
+
@json_schema_type
@dataclass
class RewardScoringRequest:
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
- prompt_generations: List[PromptGeneration]
+ prompt_generations: List[KPromptGenerations]
# TODO(ragho): create a RewardModel enum type
model: str
@@ -247,7 +255,7 @@ class RewardScoringRequest:
class RewardScoringResponse:
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
- scored_generations: List[ScoredPromptGeneration]
+ scored_generations: List[KScoredPromptGenerations]
class RewardScoring(Protocol):
diff --git a/source/openapi.html b/source/openapi.html
index 987897354..63994dd77 100644
--- a/source/openapi.html
+++ b/source/openapi.html
@@ -1960,15 +1960,18 @@
"$ref": "#/components/schemas/Message"
}
},
- "generation": {
- "$ref": "#/components/schemas/Message"
+ "k_generations": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/Message"
+ }
}
},
"additionalProperties": false,
"required": [
"prompt",
"message_history",
- "generation"
+ "k_generations"
]
}
},
@@ -1983,46 +1986,49 @@
],
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
},
+ "KScoredPromptGenerations": {
+ "type": "object",
+ "properties": {
+ "prompt": {
+ "$ref": "#/components/schemas/Message"
+ },
+ "k_scored_generations": {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/MessageScore"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "prompt",
+ "k_scored_generations"
+ ]
+ },
+ "MessageScore": {
+ "type": "object",
+ "properties": {
+ "message": {
+ "$ref": "#/components/schemas/Message"
+ },
+ "score": {
+ "type": "number"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "message",
+ "score"
+ ],
+ "title": "A single message and its score."
+ },
"RewardScoringResponse": {
"type": "object",
"properties": {
"scored_generations": {
"type": "array",
"items": {
- "type": "object",
- "properties": {
- "prompt_generation": {
- "type": "object",
- "properties": {
- "prompt": {
- "$ref": "#/components/schemas/Message"
- },
- "message_history": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/Message"
- }
- },
- "generation": {
- "$ref": "#/components/schemas/Message"
- }
- },
- "additionalProperties": false,
- "required": [
- "prompt",
- "message_history",
- "generation"
- ]
- },
- "score": {
- "type": "number"
- }
- },
- "additionalProperties": false,
- "required": [
- "prompt_generation",
- "score"
- ]
+ "$ref": "#/components/schemas/KScoredPromptGenerations"
}
}
},
@@ -2306,11 +2312,14 @@
}
],
"tags": [
+ {
+ "name": "Inference"
+ },
{
"name": "SyntheticDataGeneration"
},
{
- "name": "RewardScoring"
+ "name": "Datasets"
},
{
"name": "AgenticSystem"
@@ -2319,10 +2328,7 @@
"name": "Finetuning"
},
{
- "name": "Inference"
- },
- {
- "name": "Datasets"
+ "name": "RewardScoring"
},
{
"name": "ShieldConfig",
@@ -2416,6 +2422,14 @@
"name": "RewardScoringRequest",
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n"
},
+ {
+ "name": "KScoredPromptGenerations",
+ "description": ""
+ },
+ {
+ "name": "MessageScore",
+ "description": "A single message and its score.\n\n"
+ },
{
"name": "RewardScoringResponse",
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n"
@@ -2470,8 +2484,10 @@
"FinetuningJobLogStream",
"FinetuningJobStatusResponse",
"FinetuningTrainRequest",
+ "KScoredPromptGenerations",
"LoraFinetuningConfig",
"Message",
+ "MessageScore",
"OptimizerConfig",
"RewardScoringRequest",
"RewardScoringResponse",
diff --git a/source/openapi.yaml b/source/openapi.yaml
index f8cae85b0..472b37196 100644
--- a/source/openapi.yaml
+++ b/source/openapi.yaml
@@ -900,6 +900,19 @@ components:
- logger_config
title: Request to finetune a model.
type: object
+ KScoredPromptGenerations:
+ additionalProperties: false
+ properties:
+ k_scored_generations:
+ items:
+ $ref: '#/components/schemas/MessageScore'
+ type: array
+ prompt:
+ $ref: '#/components/schemas/Message'
+ required:
+ - prompt
+ - k_scored_generations
+ type: object
LoraFinetuningConfig:
additionalProperties: false
properties:
@@ -989,6 +1002,18 @@ components:
- tool_calls
- tool_responses
type: object
+ MessageScore:
+ additionalProperties: false
+ properties:
+ message:
+ $ref: '#/components/schemas/Message'
+ score:
+ type: number
+ required:
+ - message
+ - score
+ title: A single message and its score.
+ type: object
OptimizerConfig:
additionalProperties: false
properties:
@@ -1019,8 +1044,10 @@ components:
items:
additionalProperties: false
properties:
- generation:
- $ref: '#/components/schemas/Message'
+ k_generations:
+ items:
+ $ref: '#/components/schemas/Message'
+ type: array
message_history:
items:
$ref: '#/components/schemas/Message'
@@ -1030,7 +1057,7 @@ components:
required:
- prompt
- message_history
- - generation
+ - k_generations
type: object
type: array
required:
@@ -1044,30 +1071,7 @@ components:
properties:
scored_generations:
items:
- additionalProperties: false
- properties:
- prompt_generation:
- additionalProperties: false
- properties:
- generation:
- $ref: '#/components/schemas/Message'
- message_history:
- items:
- $ref: '#/components/schemas/Message'
- type: array
- prompt:
- $ref: '#/components/schemas/Message'
- required:
- - prompt
- - message_history
- - generation
- type: object
- score:
- type: number
- required:
- - prompt_generation
- - score
- type: object
+ $ref: '#/components/schemas/KScoredPromptGenerations'
type: array
required:
- scored_generations
@@ -1408,12 +1412,12 @@ security:
servers:
- url: http://llama.meta.com
tags:
+- name: Inference
- name: SyntheticDataGeneration
-- name: RewardScoring
+- name: Datasets
- name: AgenticSystem
- name: Finetuning
-- name: Inference
-- name: Datasets
+- name: RewardScoring
- description:
name: ShieldConfig
- description: '
name: RewardScoringRequest
+- description:
+ name: KScoredPromptGenerations
+- description: 'A single message and its score.
+
+
+ '
+ name: MessageScore
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
@@ -1570,8 +1582,10 @@ x-tagGroups:
- FinetuningJobLogStream
- FinetuningJobStatusResponse
- FinetuningTrainRequest
+ - KScoredPromptGenerations
- LoraFinetuningConfig
- Message
+ - MessageScore
- OptimizerConfig
- RewardScoringRequest
- RewardScoringResponse