reward scoring

This commit is contained in:
Raghotham Murthy 2024-07-10 21:56:16 -07:00
parent 69ecf55de2
commit ebb59aa35f
3 changed files with 117 additions and 79 deletions

View file

@ -218,25 +218,33 @@ class AgenticSystem(Protocol):
@dataclass @dataclass
class PromptGeneration: class KPromptGenerations:
# TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion
prompt: Message prompt: Message
message_history: List[Message] message_history: List[Message]
generation: Message k_generations: List[Message]
@json_schema_type
@dataclass @dataclass
class ScoredPromptGeneration: class MessageScore:
prompt_generation: PromptGeneration """A single message and its score."""
message: Message
score: float score: float
@json_schema_type
@dataclass
class KScoredPromptGenerations:
prompt: Message
k_scored_generations: List[MessageScore]
@json_schema_type @json_schema_type
@dataclass @dataclass
class RewardScoringRequest: class RewardScoringRequest:
"""Request to score a reward function. A list of prompts and a list of responses per prompt.""" """Request to score a reward function. A list of prompts and a list of responses per prompt."""
prompt_generations: List[PromptGeneration] prompt_generations: List[KPromptGenerations]
# TODO(ragho): create a RewardModel enum type # TODO(ragho): create a RewardModel enum type
model: str model: str
@ -247,7 +255,7 @@ class RewardScoringRequest:
class RewardScoringResponse: class RewardScoringResponse:
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.""" """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
scored_generations: List[ScoredPromptGeneration] scored_generations: List[KScoredPromptGenerations]
class RewardScoring(Protocol): class RewardScoring(Protocol):

View file

@ -1960,15 +1960,18 @@
"$ref": "#/components/schemas/Message" "$ref": "#/components/schemas/Message"
} }
}, },
"generation": { "k_generations": {
"$ref": "#/components/schemas/Message" "type": "array",
"items": {
"$ref": "#/components/schemas/Message"
}
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"prompt", "prompt",
"message_history", "message_history",
"generation" "k_generations"
] ]
} }
}, },
@ -1983,46 +1986,49 @@
], ],
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt." "title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
}, },
"KScoredPromptGenerations": {
"type": "object",
"properties": {
"prompt": {
"$ref": "#/components/schemas/Message"
},
"k_scored_generations": {
"type": "array",
"items": {
"$ref": "#/components/schemas/MessageScore"
}
}
},
"additionalProperties": false,
"required": [
"prompt",
"k_scored_generations"
]
},
"MessageScore": {
"type": "object",
"properties": {
"message": {
"$ref": "#/components/schemas/Message"
},
"score": {
"type": "number"
}
},
"additionalProperties": false,
"required": [
"message",
"score"
],
"title": "A single message and its score."
},
"RewardScoringResponse": { "RewardScoringResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
"scored_generations": { "scored_generations": {
"type": "array", "type": "array",
"items": { "items": {
"type": "object", "$ref": "#/components/schemas/KScoredPromptGenerations"
"properties": {
"prompt_generation": {
"type": "object",
"properties": {
"prompt": {
"$ref": "#/components/schemas/Message"
},
"message_history": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
}
},
"generation": {
"$ref": "#/components/schemas/Message"
}
},
"additionalProperties": false,
"required": [
"prompt",
"message_history",
"generation"
]
},
"score": {
"type": "number"
}
},
"additionalProperties": false,
"required": [
"prompt_generation",
"score"
]
} }
} }
}, },
@ -2306,11 +2312,14 @@
} }
], ],
"tags": [ "tags": [
{
"name": "Inference"
},
{ {
"name": "SyntheticDataGeneration" "name": "SyntheticDataGeneration"
}, },
{ {
"name": "RewardScoring" "name": "Datasets"
}, },
{ {
"name": "AgenticSystem" "name": "AgenticSystem"
@ -2319,10 +2328,7 @@
"name": "Finetuning" "name": "Finetuning"
}, },
{ {
"name": "Inference" "name": "RewardScoring"
},
{
"name": "Datasets"
}, },
{ {
"name": "ShieldConfig", "name": "ShieldConfig",
@ -2416,6 +2422,14 @@
"name": "RewardScoringRequest", "name": "RewardScoringRequest",
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />" "description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />"
}, },
{
"name": "KScoredPromptGenerations",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/KScoredPromptGenerations\" />"
},
{
"name": "MessageScore",
"description": "A single message and its score.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/MessageScore\" />"
},
{ {
"name": "RewardScoringResponse", "name": "RewardScoringResponse",
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />" "description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />"
@ -2470,8 +2484,10 @@
"FinetuningJobLogStream", "FinetuningJobLogStream",
"FinetuningJobStatusResponse", "FinetuningJobStatusResponse",
"FinetuningTrainRequest", "FinetuningTrainRequest",
"KScoredPromptGenerations",
"LoraFinetuningConfig", "LoraFinetuningConfig",
"Message", "Message",
"MessageScore",
"OptimizerConfig", "OptimizerConfig",
"RewardScoringRequest", "RewardScoringRequest",
"RewardScoringResponse", "RewardScoringResponse",

View file

@ -900,6 +900,19 @@ components:
- logger_config - logger_config
title: Request to finetune a model. title: Request to finetune a model.
type: object type: object
KScoredPromptGenerations:
additionalProperties: false
properties:
k_scored_generations:
items:
$ref: '#/components/schemas/MessageScore'
type: array
prompt:
$ref: '#/components/schemas/Message'
required:
- prompt
- k_scored_generations
type: object
LoraFinetuningConfig: LoraFinetuningConfig:
additionalProperties: false additionalProperties: false
properties: properties:
@ -989,6 +1002,18 @@ components:
- tool_calls - tool_calls
- tool_responses - tool_responses
type: object type: object
MessageScore:
additionalProperties: false
properties:
message:
$ref: '#/components/schemas/Message'
score:
type: number
required:
- message
- score
title: A single message and its score.
type: object
OptimizerConfig: OptimizerConfig:
additionalProperties: false additionalProperties: false
properties: properties:
@ -1019,8 +1044,10 @@ components:
items: items:
additionalProperties: false additionalProperties: false
properties: properties:
generation: k_generations:
$ref: '#/components/schemas/Message' items:
$ref: '#/components/schemas/Message'
type: array
message_history: message_history:
items: items:
$ref: '#/components/schemas/Message' $ref: '#/components/schemas/Message'
@ -1030,7 +1057,7 @@ components:
required: required:
- prompt - prompt
- message_history - message_history
- generation - k_generations
type: object type: object
type: array type: array
required: required:
@ -1044,30 +1071,7 @@ components:
properties: properties:
scored_generations: scored_generations:
items: items:
additionalProperties: false $ref: '#/components/schemas/KScoredPromptGenerations'
properties:
prompt_generation:
additionalProperties: false
properties:
generation:
$ref: '#/components/schemas/Message'
message_history:
items:
$ref: '#/components/schemas/Message'
type: array
prompt:
$ref: '#/components/schemas/Message'
required:
- prompt
- message_history
- generation
type: object
score:
type: number
required:
- prompt_generation
- score
type: object
type: array type: array
required: required:
- scored_generations - scored_generations
@ -1408,12 +1412,12 @@ security:
servers: servers:
- url: http://llama.meta.com - url: http://llama.meta.com
tags: tags:
- name: Inference
- name: SyntheticDataGeneration - name: SyntheticDataGeneration
- name: RewardScoring - name: Datasets
- name: AgenticSystem - name: AgenticSystem
- name: Finetuning - name: Finetuning
- name: Inference - name: RewardScoring
- name: Datasets
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" /> - description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
name: ShieldConfig name: ShieldConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest" - description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
@ -1522,6 +1526,14 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />' <SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />'
name: RewardScoringRequest name: RewardScoringRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/KScoredPromptGenerations"
/>
name: KScoredPromptGenerations
- description: 'A single message and its score.
<SchemaDefinition schemaRef="#/components/schemas/MessageScore" />'
name: MessageScore
- description: 'Response from the reward scoring. Batch of (prompt, response, score) - description: 'Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold. tuples that pass the threshold.
@ -1570,8 +1582,10 @@ x-tagGroups:
- FinetuningJobLogStream - FinetuningJobLogStream
- FinetuningJobStatusResponse - FinetuningJobStatusResponse
- FinetuningTrainRequest - FinetuningTrainRequest
- KScoredPromptGenerations
- LoraFinetuningConfig - LoraFinetuningConfig
- Message - Message
- MessageScore
- OptimizerConfig - OptimizerConfig
- RewardScoringRequest - RewardScoringRequest
- RewardScoringResponse - RewardScoringResponse