reward scoring

This commit is contained in:
Raghotham Murthy 2024-07-10 21:56:16 -07:00
parent 69ecf55de2
commit ebb59aa35f
3 changed files with 117 additions and 79 deletions

View file

@ -218,25 +218,33 @@ class AgenticSystem(Protocol):
@dataclass
class PromptGeneration:
# TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion
class KPromptGenerations:
prompt: Message
message_history: List[Message]
generation: Message
k_generations: List[Message]
@json_schema_type
@dataclass
class ScoredPromptGeneration:
prompt_generation: PromptGeneration
class MessageScore:
"""A single message and its score."""
message: Message
score: float
@json_schema_type
@dataclass
class KScoredPromptGenerations:
prompt: Message
k_scored_generations: List[MessageScore]
@json_schema_type
@dataclass
class RewardScoringRequest:
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
prompt_generations: List[PromptGeneration]
prompt_generations: List[KPromptGenerations]
# TODO(ragho): create a RewardModel enum tye
model: str
@ -247,7 +255,7 @@ class RewardScoringRequest:
class RewardScoringResponse:
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
scored_generations: List[ScoredPromptGeneration]
scored_generations: List[KScoredPromptGenerations]
class RewardScoring(Protocol):

View file

@ -1960,15 +1960,18 @@
"$ref": "#/components/schemas/Message"
}
},
"generation": {
"k_generations": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
}
}
},
"additionalProperties": false,
"required": [
"prompt",
"message_history",
"generation"
"k_generations"
]
}
},
@ -1983,46 +1986,49 @@
],
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
},
"RewardScoringResponse": {
"type": "object",
"properties": {
"scored_generations": {
"type": "array",
"items": {
"type": "object",
"properties": {
"prompt_generation": {
"KScoredPromptGenerations": {
"type": "object",
"properties": {
"prompt": {
"$ref": "#/components/schemas/Message"
},
"message_history": {
"k_scored_generations": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
"$ref": "#/components/schemas/MessageScore"
}
},
"generation": {
"$ref": "#/components/schemas/Message"
}
},
"additionalProperties": false,
"required": [
"prompt",
"message_history",
"generation"
"k_scored_generations"
]
},
"MessageScore": {
"type": "object",
"properties": {
"message": {
"$ref": "#/components/schemas/Message"
},
"score": {
"type": "number"
}
},
"additionalProperties": false,
"required": [
"prompt_generation",
"message",
"score"
]
],
"title": "A single message and its score."
},
"RewardScoringResponse": {
"type": "object",
"properties": {
"scored_generations": {
"type": "array",
"items": {
"$ref": "#/components/schemas/KScoredPromptGenerations"
}
}
},
@ -2306,11 +2312,14 @@
}
],
"tags": [
{
"name": "Inference"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "RewardScoring"
"name": "Datasets"
},
{
"name": "AgenticSystem"
@ -2319,10 +2328,7 @@
"name": "Finetuning"
},
{
"name": "Inference"
},
{
"name": "Datasets"
"name": "RewardScoring"
},
{
"name": "ShieldConfig",
@ -2416,6 +2422,14 @@
"name": "RewardScoringRequest",
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />"
},
{
"name": "KScoredPromptGenerations",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/KScoredPromptGenerations\" />"
},
{
"name": "MessageScore",
"description": "A single message and its score.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/MessageScore\" />"
},
{
"name": "RewardScoringResponse",
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />"
@ -2470,8 +2484,10 @@
"FinetuningJobLogStream",
"FinetuningJobStatusResponse",
"FinetuningTrainRequest",
"KScoredPromptGenerations",
"LoraFinetuningConfig",
"Message",
"MessageScore",
"OptimizerConfig",
"RewardScoringRequest",
"RewardScoringResponse",

View file

@ -900,6 +900,19 @@ components:
- logger_config
title: Request to finetune a model.
type: object
KScoredPromptGenerations:
additionalProperties: false
properties:
k_scored_generations:
items:
$ref: '#/components/schemas/MessageScore'
type: array
prompt:
$ref: '#/components/schemas/Message'
required:
- prompt
- k_scored_generations
type: object
LoraFinetuningConfig:
additionalProperties: false
properties:
@ -989,6 +1002,18 @@ components:
- tool_calls
- tool_responses
type: object
MessageScore:
additionalProperties: false
properties:
message:
$ref: '#/components/schemas/Message'
score:
type: number
required:
- message
- score
title: A single message and its score.
type: object
OptimizerConfig:
additionalProperties: false
properties:
@ -1019,8 +1044,10 @@ components:
items:
additionalProperties: false
properties:
generation:
k_generations:
items:
$ref: '#/components/schemas/Message'
type: array
message_history:
items:
$ref: '#/components/schemas/Message'
@ -1030,7 +1057,7 @@ components:
required:
- prompt
- message_history
- generation
- k_generations
type: object
type: array
required:
@ -1044,30 +1071,7 @@ components:
properties:
scored_generations:
items:
additionalProperties: false
properties:
prompt_generation:
additionalProperties: false
properties:
generation:
$ref: '#/components/schemas/Message'
message_history:
items:
$ref: '#/components/schemas/Message'
type: array
prompt:
$ref: '#/components/schemas/Message'
required:
- prompt
- message_history
- generation
type: object
score:
type: number
required:
- prompt_generation
- score
type: object
$ref: '#/components/schemas/KScoredPromptGenerations'
type: array
required:
- scored_generations
@ -1408,12 +1412,12 @@ security:
servers:
- url: http://llama.meta.com
tags:
- name: Inference
- name: SyntheticDataGeneration
- name: RewardScoring
- name: Datasets
- name: AgenticSystem
- name: Finetuning
- name: Inference
- name: Datasets
- name: RewardScoring
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
name: ShieldConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
@ -1522,6 +1526,14 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />'
name: RewardScoringRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/KScoredPromptGenerations"
/>
name: KScoredPromptGenerations
- description: 'A single message and its score.
<SchemaDefinition schemaRef="#/components/schemas/MessageScore" />'
name: MessageScore
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
@ -1570,8 +1582,10 @@ x-tagGroups:
- FinetuningJobLogStream
- FinetuningJobStatusResponse
- FinetuningTrainRequest
- KScoredPromptGenerations
- LoraFinetuningConfig
- Message
- MessageScore
- OptimizerConfig
- RewardScoringRequest
- RewardScoringResponse