diff --git a/source/api_definitions.py b/source/api_definitions.py index d8c144415..a1ea720c8 100644 --- a/source/api_definitions.py +++ b/source/api_definitions.py @@ -218,25 +218,33 @@ class AgenticSystem(Protocol): @dataclass -class PromptGeneration: - # TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion +class KPromptGenerations: prompt: Message message_history: List[Message] - generation: Message + k_generations: List[Message] +@json_schema_type @dataclass -class ScoredPromptGeneration: - prompt_generation: PromptGeneration +class MessageScore: + """A single message and its score.""" + + message: Message score: float +@json_schema_type +@dataclass +class KScoredPromptGenerations: + prompt: Message + k_scored_generations: List[MessageScore] + @json_schema_type @dataclass class RewardScoringRequest: """Request to score a reward function. A list of prompts and a list of responses per prompt.""" - prompt_generations: List[PromptGeneration] + prompt_generations: List[KPromptGenerations] # TODO(ragho): create a RewardModel enum tye model: str @@ -247,7 +255,7 @@ class RewardScoringRequest: class RewardScoringResponse: """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.""" - scored_generations: List[ScoredPromptGeneration] + scored_generations: List[KScoredPromptGenerations] class RewardScoring(Protocol): diff --git a/source/openapi.html b/source/openapi.html index 987897354..63994dd77 100644 --- a/source/openapi.html +++ b/source/openapi.html @@ -1960,15 +1960,18 @@ "$ref": "#/components/schemas/Message" } }, - "generation": { - "$ref": "#/components/schemas/Message" + "k_generations": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Message" + } } }, "additionalProperties": false, "required": [ "prompt", "message_history", - "generation" + "k_generations" ] } }, @@ -1983,46 +1986,49 @@ ], "title": "Request to score a reward function. A list of prompts and a list of responses per prompt." }, + "KScoredPromptGenerations": { + "type": "object", + "properties": { + "prompt": { + "$ref": "#/components/schemas/Message" + }, + "k_scored_generations": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MessageScore" + } + } + }, + "additionalProperties": false, + "required": [ + "prompt", + "k_scored_generations" + ] + }, + "MessageScore": { + "type": "object", + "properties": { + "message": { + "$ref": "#/components/schemas/Message" + }, + "score": { + "type": "number" + } + }, + "additionalProperties": false, + "required": [ + "message", + "score" + ], + "title": "A single message and its score." + }, "RewardScoringResponse": { "type": "object", "properties": { "scored_generations": { "type": "array", "items": { - "type": "object", - "properties": { - "prompt_generation": { - "type": "object", - "properties": { - "prompt": { - "$ref": "#/components/schemas/Message" - }, - "message_history": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Message" - } - }, - "generation": { - "$ref": "#/components/schemas/Message" - } - }, - "additionalProperties": false, - "required": [ - "prompt", - "message_history", - "generation" - ] - }, - "score": { - "type": "number" - } - }, - "additionalProperties": false, - "required": [ - "prompt_generation", - "score" - ] + "$ref": "#/components/schemas/KScoredPromptGenerations" } } }, @@ -2306,11 +2312,14 @@ } ], "tags": [ + { + "name": "Inference" + }, { "name": "SyntheticDataGeneration" }, { - "name": "RewardScoring" + "name": "Datasets" }, { "name": "AgenticSystem" @@ -2319,10 +2328,7 @@ "name": "Finetuning" }, { - "name": "Inference" - }, - { - "name": "Datasets" + "name": "RewardScoring" }, { "name": "ShieldConfig", @@ -2416,6 +2422,14 @@ "name": "RewardScoringRequest", "description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n" }, + { + "name": "KScoredPromptGenerations", + "description": "" + }, + { + "name": "MessageScore", + "description": "A single message and its score.\n\n" + }, { "name": "RewardScoringResponse", "description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" @@ -2470,8 +2484,10 @@ "FinetuningJobLogStream", "FinetuningJobStatusResponse", "FinetuningTrainRequest", + "KScoredPromptGenerations", "LoraFinetuningConfig", "Message", + "MessageScore", "OptimizerConfig", "RewardScoringRequest", "RewardScoringResponse", diff --git a/source/openapi.yaml b/source/openapi.yaml index f8cae85b0..472b37196 100644 --- a/source/openapi.yaml +++ b/source/openapi.yaml @@ -900,6 +900,19 @@ components: - logger_config title: Request to finetune a model. type: object + KScoredPromptGenerations: + additionalProperties: false + properties: + k_scored_generations: + items: + $ref: '#/components/schemas/MessageScore' + type: array + prompt: + $ref: '#/components/schemas/Message' + required: + - prompt + - k_scored_generations + type: object LoraFinetuningConfig: additionalProperties: false properties: @@ -989,6 +1002,18 @@ components: - tool_calls - tool_responses type: object + MessageScore: + additionalProperties: false + properties: + message: + $ref: '#/components/schemas/Message' + score: + type: number + required: + - message + - score + title: A single message and its score. + type: object OptimizerConfig: additionalProperties: false properties: @@ -1019,8 +1044,10 @@ components: items: additionalProperties: false properties: - generation: - $ref: '#/components/schemas/Message' + k_generations: + items: + $ref: '#/components/schemas/Message' + type: array message_history: items: $ref: '#/components/schemas/Message' @@ -1030,7 +1057,7 @@ components: required: - prompt - message_history - - generation + - k_generations type: object type: array required: @@ -1044,30 +1071,7 @@ components: properties: scored_generations: items: - additionalProperties: false - properties: - prompt_generation: - additionalProperties: false - properties: - generation: - $ref: '#/components/schemas/Message' - message_history: - items: - $ref: '#/components/schemas/Message' - type: array - prompt: - $ref: '#/components/schemas/Message' - required: - - prompt - - message_history - - generation - type: object - score: - type: number - required: - - prompt_generation - - score - type: object + $ref: '#/components/schemas/KScoredPromptGenerations' type: array required: - scored_generations @@ -1408,12 +1412,12 @@ security: servers: - url: http://llama.meta.com tags: +- name: Inference - name: SyntheticDataGeneration -- name: RewardScoring +- name: Datasets - name: AgenticSystem - name: Finetuning -- name: Inference -- name: Datasets +- name: RewardScoring - description: name: ShieldConfig - description: ' name: RewardScoringRequest +- description: + name: KScoredPromptGenerations +- description: 'A single message and its score. + + + ' + name: MessageScore - description: 'Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold. @@ -1570,8 +1582,10 @@ x-tagGroups: - FinetuningJobLogStream - FinetuningJobStatusResponse - FinetuningTrainRequest + - KScoredPromptGenerations - LoraFinetuningConfig - Message + - MessageScore - OptimizerConfig - RewardScoringRequest - RewardScoringResponse