mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
reward scoring
This commit is contained in:
parent
69ecf55de2
commit
ebb59aa35f
3 changed files with 117 additions and 79 deletions
|
@ -218,25 +218,33 @@ class AgenticSystem(Protocol):
|
|||
|
||||
|
||||
@dataclass
|
||||
class PromptGeneration:
|
||||
# TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion
|
||||
class KPromptGenerations:
|
||||
prompt: Message
|
||||
message_history: List[Message]
|
||||
generation: Message
|
||||
k_generations: List[Message]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class ScoredPromptGeneration:
|
||||
prompt_generation: PromptGeneration
|
||||
class MessageScore:
|
||||
"""A single message and its score."""
|
||||
|
||||
message: Message
|
||||
score: float
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class KScoredPromptGenerations:
|
||||
prompt: Message
|
||||
k_scored_generations: List[MessageScore]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class RewardScoringRequest:
|
||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||
|
||||
prompt_generations: List[PromptGeneration]
|
||||
prompt_generations: List[KPromptGenerations]
|
||||
|
||||
# TODO(ragho): create a RewardModel enum type
|
||||
model: str
|
||||
|
@ -247,7 +255,7 @@ class RewardScoringRequest:
|
|||
class RewardScoringResponse:
|
||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
scored_generations: List[ScoredPromptGeneration]
|
||||
scored_generations: List[KScoredPromptGenerations]
|
||||
|
||||
|
||||
class RewardScoring(Protocol):
|
||||
|
|
|
@ -1960,15 +1960,18 @@
|
|||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
},
|
||||
"generation": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
"k_generations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt",
|
||||
"message_history",
|
||||
"generation"
|
||||
"k_generations"
|
||||
]
|
||||
}
|
||||
},
|
||||
|
@ -1983,46 +1986,49 @@
|
|||
],
|
||||
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
|
||||
},
|
||||
"KScoredPromptGenerations": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
},
|
||||
"k_scored_generations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/MessageScore"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt",
|
||||
"k_scored_generations"
|
||||
]
|
||||
},
|
||||
"MessageScore": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
},
|
||||
"score": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"message",
|
||||
"score"
|
||||
],
|
||||
"title": "A single message and its score."
|
||||
},
|
||||
"RewardScoringResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"scored_generations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt_generation": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
},
|
||||
"message_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
},
|
||||
"generation": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt",
|
||||
"message_history",
|
||||
"generation"
|
||||
]
|
||||
},
|
||||
"score": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt_generation",
|
||||
"score"
|
||||
]
|
||||
"$ref": "#/components/schemas/KScoredPromptGenerations"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -2306,11 +2312,14 @@
|
|||
}
|
||||
],
|
||||
"tags": [
|
||||
{
|
||||
"name": "Inference"
|
||||
},
|
||||
{
|
||||
"name": "SyntheticDataGeneration"
|
||||
},
|
||||
{
|
||||
"name": "RewardScoring"
|
||||
"name": "Datasets"
|
||||
},
|
||||
{
|
||||
"name": "AgenticSystem"
|
||||
|
@ -2319,10 +2328,7 @@
|
|||
"name": "Finetuning"
|
||||
},
|
||||
{
|
||||
"name": "Inference"
|
||||
},
|
||||
{
|
||||
"name": "Datasets"
|
||||
"name": "RewardScoring"
|
||||
},
|
||||
{
|
||||
"name": "ShieldConfig",
|
||||
|
@ -2416,6 +2422,14 @@
|
|||
"name": "RewardScoringRequest",
|
||||
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />"
|
||||
},
|
||||
{
|
||||
"name": "KScoredPromptGenerations",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/KScoredPromptGenerations\" />"
|
||||
},
|
||||
{
|
||||
"name": "MessageScore",
|
||||
"description": "A single message and its score.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/MessageScore\" />"
|
||||
},
|
||||
{
|
||||
"name": "RewardScoringResponse",
|
||||
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />"
|
||||
|
@ -2470,8 +2484,10 @@
|
|||
"FinetuningJobLogStream",
|
||||
"FinetuningJobStatusResponse",
|
||||
"FinetuningTrainRequest",
|
||||
"KScoredPromptGenerations",
|
||||
"LoraFinetuningConfig",
|
||||
"Message",
|
||||
"MessageScore",
|
||||
"OptimizerConfig",
|
||||
"RewardScoringRequest",
|
||||
"RewardScoringResponse",
|
||||
|
|
|
@ -900,6 +900,19 @@ components:
|
|||
- logger_config
|
||||
title: Request to finetune a model.
|
||||
type: object
|
||||
KScoredPromptGenerations:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
k_scored_generations:
|
||||
items:
|
||||
$ref: '#/components/schemas/MessageScore'
|
||||
type: array
|
||||
prompt:
|
||||
$ref: '#/components/schemas/Message'
|
||||
required:
|
||||
- prompt
|
||||
- k_scored_generations
|
||||
type: object
|
||||
LoraFinetuningConfig:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -989,6 +1002,18 @@ components:
|
|||
- tool_calls
|
||||
- tool_responses
|
||||
type: object
|
||||
MessageScore:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
message:
|
||||
$ref: '#/components/schemas/Message'
|
||||
score:
|
||||
type: number
|
||||
required:
|
||||
- message
|
||||
- score
|
||||
title: A single message and its score.
|
||||
type: object
|
||||
OptimizerConfig:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -1019,8 +1044,10 @@ components:
|
|||
items:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
generation:
|
||||
$ref: '#/components/schemas/Message'
|
||||
k_generations:
|
||||
items:
|
||||
$ref: '#/components/schemas/Message'
|
||||
type: array
|
||||
message_history:
|
||||
items:
|
||||
$ref: '#/components/schemas/Message'
|
||||
|
@ -1030,7 +1057,7 @@ components:
|
|||
required:
|
||||
- prompt
|
||||
- message_history
|
||||
- generation
|
||||
- k_generations
|
||||
type: object
|
||||
type: array
|
||||
required:
|
||||
|
@ -1044,30 +1071,7 @@ components:
|
|||
properties:
|
||||
scored_generations:
|
||||
items:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
prompt_generation:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
generation:
|
||||
$ref: '#/components/schemas/Message'
|
||||
message_history:
|
||||
items:
|
||||
$ref: '#/components/schemas/Message'
|
||||
type: array
|
||||
prompt:
|
||||
$ref: '#/components/schemas/Message'
|
||||
required:
|
||||
- prompt
|
||||
- message_history
|
||||
- generation
|
||||
type: object
|
||||
score:
|
||||
type: number
|
||||
required:
|
||||
- prompt_generation
|
||||
- score
|
||||
type: object
|
||||
$ref: '#/components/schemas/KScoredPromptGenerations'
|
||||
type: array
|
||||
required:
|
||||
- scored_generations
|
||||
|
@ -1408,12 +1412,12 @@ security:
|
|||
servers:
|
||||
- url: http://llama.meta.com
|
||||
tags:
|
||||
- name: Inference
|
||||
- name: SyntheticDataGeneration
|
||||
- name: RewardScoring
|
||||
- name: Datasets
|
||||
- name: AgenticSystem
|
||||
- name: Finetuning
|
||||
- name: Inference
|
||||
- name: Datasets
|
||||
- name: RewardScoring
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
|
||||
name: ShieldConfig
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
|
||||
|
@ -1522,6 +1526,14 @@ tags:
|
|||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />'
|
||||
name: RewardScoringRequest
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/KScoredPromptGenerations"
|
||||
/>
|
||||
name: KScoredPromptGenerations
|
||||
- description: 'A single message and its score.
|
||||
|
||||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/MessageScore" />'
|
||||
name: MessageScore
|
||||
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
|
||||
tuples that pass the threshold.
|
||||
|
||||
|
@ -1570,8 +1582,10 @@ x-tagGroups:
|
|||
- FinetuningJobLogStream
|
||||
- FinetuningJobStatusResponse
|
||||
- FinetuningTrainRequest
|
||||
- KScoredPromptGenerations
|
||||
- LoraFinetuningConfig
|
||||
- Message
|
||||
- MessageScore
|
||||
- OptimizerConfig
|
||||
- RewardScoringRequest
|
||||
- RewardScoringResponse
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue