mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
reward scoring
This commit is contained in:
parent
69ecf55de2
commit
ebb59aa35f
3 changed files with 117 additions and 79 deletions
|
@ -218,25 +218,33 @@ class AgenticSystem(Protocol):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptGeneration:
|
class KPromptGenerations:
|
||||||
# TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion
|
|
||||||
prompt: Message
|
prompt: Message
|
||||||
message_history: List[Message]
|
message_history: List[Message]
|
||||||
generation: Message
|
k_generations: List[Message]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScoredPromptGeneration:
|
class MessageScore:
|
||||||
prompt_generation: PromptGeneration
|
"""A single message and its score."""
|
||||||
|
|
||||||
|
message: Message
|
||||||
score: float
|
score: float
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
@dataclass
|
||||||
|
class KScoredPromptGenerations:
|
||||||
|
prompt: Message
|
||||||
|
k_scored_generations: List[MessageScore]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
@dataclass
|
@dataclass
|
||||||
class RewardScoringRequest:
|
class RewardScoringRequest:
|
||||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||||
|
|
||||||
prompt_generations: List[PromptGeneration]
|
prompt_generations: List[KPromptGenerations]
|
||||||
|
|
||||||
# TODO(ragho): create a RewardModel enum tye
|
# TODO(ragho): create a RewardModel enum tye
|
||||||
model: str
|
model: str
|
||||||
|
@ -247,7 +255,7 @@ class RewardScoringRequest:
|
||||||
class RewardScoringResponse:
|
class RewardScoringResponse:
|
||||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||||
|
|
||||||
scored_generations: List[ScoredPromptGeneration]
|
scored_generations: List[KScoredPromptGenerations]
|
||||||
|
|
||||||
|
|
||||||
class RewardScoring(Protocol):
|
class RewardScoring(Protocol):
|
||||||
|
|
|
@ -1960,15 +1960,18 @@
|
||||||
"$ref": "#/components/schemas/Message"
|
"$ref": "#/components/schemas/Message"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"generation": {
|
"k_generations": {
|
||||||
"$ref": "#/components/schemas/Message"
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/components/schemas/Message"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"prompt",
|
"prompt",
|
||||||
"message_history",
|
"message_history",
|
||||||
"generation"
|
"k_generations"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1983,46 +1986,49 @@
|
||||||
],
|
],
|
||||||
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
|
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
|
||||||
},
|
},
|
||||||
|
"KScoredPromptGenerations": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {
|
||||||
|
"$ref": "#/components/schemas/Message"
|
||||||
|
},
|
||||||
|
"k_scored_generations": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/components/schemas/MessageScore"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"prompt",
|
||||||
|
"k_scored_generations"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"MessageScore": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"message": {
|
||||||
|
"$ref": "#/components/schemas/Message"
|
||||||
|
},
|
||||||
|
"score": {
|
||||||
|
"type": "number"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"message",
|
||||||
|
"score"
|
||||||
|
],
|
||||||
|
"title": "A single message and its score."
|
||||||
|
},
|
||||||
"RewardScoringResponse": {
|
"RewardScoringResponse": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"scored_generations": {
|
"scored_generations": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"$ref": "#/components/schemas/KScoredPromptGenerations"
|
||||||
"properties": {
|
|
||||||
"prompt_generation": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"prompt": {
|
|
||||||
"$ref": "#/components/schemas/Message"
|
|
||||||
},
|
|
||||||
"message_history": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"$ref": "#/components/schemas/Message"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"generation": {
|
|
||||||
"$ref": "#/components/schemas/Message"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"prompt",
|
|
||||||
"message_history",
|
|
||||||
"generation"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"score": {
|
|
||||||
"type": "number"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"prompt_generation",
|
|
||||||
"score"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -2306,11 +2312,14 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"tags": [
|
"tags": [
|
||||||
|
{
|
||||||
|
"name": "Inference"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "SyntheticDataGeneration"
|
"name": "SyntheticDataGeneration"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "RewardScoring"
|
"name": "Datasets"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "AgenticSystem"
|
"name": "AgenticSystem"
|
||||||
|
@ -2319,10 +2328,7 @@
|
||||||
"name": "Finetuning"
|
"name": "Finetuning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Inference"
|
"name": "RewardScoring"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Datasets"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "ShieldConfig",
|
"name": "ShieldConfig",
|
||||||
|
@ -2416,6 +2422,14 @@
|
||||||
"name": "RewardScoringRequest",
|
"name": "RewardScoringRequest",
|
||||||
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />"
|
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "KScoredPromptGenerations",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/KScoredPromptGenerations\" />"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "MessageScore",
|
||||||
|
"description": "A single message and its score.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/MessageScore\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "RewardScoringResponse",
|
"name": "RewardScoringResponse",
|
||||||
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />"
|
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />"
|
||||||
|
@ -2470,8 +2484,10 @@
|
||||||
"FinetuningJobLogStream",
|
"FinetuningJobLogStream",
|
||||||
"FinetuningJobStatusResponse",
|
"FinetuningJobStatusResponse",
|
||||||
"FinetuningTrainRequest",
|
"FinetuningTrainRequest",
|
||||||
|
"KScoredPromptGenerations",
|
||||||
"LoraFinetuningConfig",
|
"LoraFinetuningConfig",
|
||||||
"Message",
|
"Message",
|
||||||
|
"MessageScore",
|
||||||
"OptimizerConfig",
|
"OptimizerConfig",
|
||||||
"RewardScoringRequest",
|
"RewardScoringRequest",
|
||||||
"RewardScoringResponse",
|
"RewardScoringResponse",
|
||||||
|
|
|
@ -900,6 +900,19 @@ components:
|
||||||
- logger_config
|
- logger_config
|
||||||
title: Request to finetune a model.
|
title: Request to finetune a model.
|
||||||
type: object
|
type: object
|
||||||
|
KScoredPromptGenerations:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
k_scored_generations:
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/MessageScore'
|
||||||
|
type: array
|
||||||
|
prompt:
|
||||||
|
$ref: '#/components/schemas/Message'
|
||||||
|
required:
|
||||||
|
- prompt
|
||||||
|
- k_scored_generations
|
||||||
|
type: object
|
||||||
LoraFinetuningConfig:
|
LoraFinetuningConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -989,6 +1002,18 @@ components:
|
||||||
- tool_calls
|
- tool_calls
|
||||||
- tool_responses
|
- tool_responses
|
||||||
type: object
|
type: object
|
||||||
|
MessageScore:
|
||||||
|
additionalProperties: false
|
||||||
|
properties:
|
||||||
|
message:
|
||||||
|
$ref: '#/components/schemas/Message'
|
||||||
|
score:
|
||||||
|
type: number
|
||||||
|
required:
|
||||||
|
- message
|
||||||
|
- score
|
||||||
|
title: A single message and its score.
|
||||||
|
type: object
|
||||||
OptimizerConfig:
|
OptimizerConfig:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -1019,8 +1044,10 @@ components:
|
||||||
items:
|
items:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
generation:
|
k_generations:
|
||||||
$ref: '#/components/schemas/Message'
|
items:
|
||||||
|
$ref: '#/components/schemas/Message'
|
||||||
|
type: array
|
||||||
message_history:
|
message_history:
|
||||||
items:
|
items:
|
||||||
$ref: '#/components/schemas/Message'
|
$ref: '#/components/schemas/Message'
|
||||||
|
@ -1030,7 +1057,7 @@ components:
|
||||||
required:
|
required:
|
||||||
- prompt
|
- prompt
|
||||||
- message_history
|
- message_history
|
||||||
- generation
|
- k_generations
|
||||||
type: object
|
type: object
|
||||||
type: array
|
type: array
|
||||||
required:
|
required:
|
||||||
|
@ -1044,30 +1071,7 @@ components:
|
||||||
properties:
|
properties:
|
||||||
scored_generations:
|
scored_generations:
|
||||||
items:
|
items:
|
||||||
additionalProperties: false
|
$ref: '#/components/schemas/KScoredPromptGenerations'
|
||||||
properties:
|
|
||||||
prompt_generation:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
generation:
|
|
||||||
$ref: '#/components/schemas/Message'
|
|
||||||
message_history:
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/Message'
|
|
||||||
type: array
|
|
||||||
prompt:
|
|
||||||
$ref: '#/components/schemas/Message'
|
|
||||||
required:
|
|
||||||
- prompt
|
|
||||||
- message_history
|
|
||||||
- generation
|
|
||||||
type: object
|
|
||||||
score:
|
|
||||||
type: number
|
|
||||||
required:
|
|
||||||
- prompt_generation
|
|
||||||
- score
|
|
||||||
type: object
|
|
||||||
type: array
|
type: array
|
||||||
required:
|
required:
|
||||||
- scored_generations
|
- scored_generations
|
||||||
|
@ -1408,12 +1412,12 @@ security:
|
||||||
servers:
|
servers:
|
||||||
- url: http://llama.meta.com
|
- url: http://llama.meta.com
|
||||||
tags:
|
tags:
|
||||||
|
- name: Inference
|
||||||
- name: SyntheticDataGeneration
|
- name: SyntheticDataGeneration
|
||||||
- name: RewardScoring
|
- name: Datasets
|
||||||
- name: AgenticSystem
|
- name: AgenticSystem
|
||||||
- name: Finetuning
|
- name: Finetuning
|
||||||
- name: Inference
|
- name: RewardScoring
|
||||||
- name: Datasets
|
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
|
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
|
||||||
name: ShieldConfig
|
name: ShieldConfig
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
|
||||||
|
@ -1522,6 +1526,14 @@ tags:
|
||||||
|
|
||||||
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />'
|
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />'
|
||||||
name: RewardScoringRequest
|
name: RewardScoringRequest
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/KScoredPromptGenerations"
|
||||||
|
/>
|
||||||
|
name: KScoredPromptGenerations
|
||||||
|
- description: 'A single message and its score.
|
||||||
|
|
||||||
|
|
||||||
|
<SchemaDefinition schemaRef="#/components/schemas/MessageScore" />'
|
||||||
|
name: MessageScore
|
||||||
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
|
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
|
||||||
tuples that pass the threshold.
|
tuples that pass the threshold.
|
||||||
|
|
||||||
|
@ -1570,8 +1582,10 @@ x-tagGroups:
|
||||||
- FinetuningJobLogStream
|
- FinetuningJobLogStream
|
||||||
- FinetuningJobStatusResponse
|
- FinetuningJobStatusResponse
|
||||||
- FinetuningTrainRequest
|
- FinetuningTrainRequest
|
||||||
|
- KScoredPromptGenerations
|
||||||
- LoraFinetuningConfig
|
- LoraFinetuningConfig
|
||||||
- Message
|
- Message
|
||||||
|
- MessageScore
|
||||||
- OptimizerConfig
|
- OptimizerConfig
|
||||||
- RewardScoringRequest
|
- RewardScoringRequest
|
||||||
- RewardScoringResponse
|
- RewardScoringResponse
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue