mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
fixes to reward stuff
This commit is contained in:
parent
eb12bfbef0
commit
956f07b04c
3 changed files with 461 additions and 11 deletions
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple
|
||||
|
||||
import yaml
|
||||
from agentic_system_types import (
|
||||
|
@ -17,8 +17,8 @@ from model_types import (
|
|||
Message,
|
||||
PretrainedModel,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ShieldConfig,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolResponse,
|
||||
|
@ -203,19 +203,37 @@ class AgenticSystem(Protocol):
|
|||
) -> None: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptGeneration:
|
||||
prompt: Message
|
||||
message_history: List[Message]
|
||||
generation: Message
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredPromptGeneration:
|
||||
prompt_generation: PromptGeneration
|
||||
score: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class RewardScoringRequest:
|
||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||
|
||||
prompt_responses: List[Tuple(prompt:str, List[response:str])]
|
||||
prompt_generations: List[PromptGeneration]
|
||||
|
||||
# TODO(ragho): create a RewardModel enum tye
|
||||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class RewardScoringResponse:
|
||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
prompt_responses: List[Tuple(prompt:str, List[Tuple(response:str, score:float)])]
|
||||
scored_generations: List[ScoredPromptGeneration]
|
||||
|
||||
|
||||
class RewardScoring(Protocol):
|
||||
@webmethod(route="/reward_scoring/score")
|
||||
|
@ -224,6 +242,7 @@ class RewardScoring(Protocol):
|
|||
request: RewardScoringRequest,
|
||||
) -> Union[RewardScoringResponse]: ...
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
"""The type of filtering function."""
|
||||
|
||||
|
@ -234,6 +253,7 @@ class FilteringFunction(Enum):
|
|||
top_k_top_p = "top_k_top_p"
|
||||
sigmoid = "sigmoid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class SyntheticDataGenerationRequest:
|
||||
|
@ -241,18 +261,23 @@ class SyntheticDataGenerationRequest:
|
|||
|
||||
prompts: List[str]
|
||||
filtering_function: FilteringFunction = FilteringFunction.none
|
||||
reward_scoring: RewardScoring
|
||||
|
||||
# TODO(ragho): fix this
|
||||
# reward_scoring: RewardScoring
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class SyntheticDataGenerationResponse:
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
synthetic_data: List[Tuple(prompt:str, response:str, score:float)]
|
||||
synthetic_data: List[Tuple[str, str, float]]
|
||||
|
||||
"""The actual synthetic data"""
|
||||
statistics: Dict[str, float]
|
||||
"""Statistics on how many prompts were generated and how many were filtered out"""
|
||||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic_data_generation/generate")
|
||||
def post_generate(
|
||||
|
@ -261,8 +286,9 @@ class SyntheticDataGeneration(Protocol):
|
|||
) -> Union[SyntheticDataGenerationResponse]: ...
|
||||
|
||||
|
||||
|
||||
class LlamaStackEndpoints(Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration): ...
|
||||
class LlamaStackEndpoints(
|
||||
Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
|
||||
): ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -191,6 +191,66 @@
|
|||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/synthetic_data_generation/generate": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SyntheticDataGenerationResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"SyntheticDataGeneration"
|
||||
],
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SyntheticDataGenerationRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/reward_scoring/score": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RewardScoringResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"RewardScoring"
|
||||
],
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RewardScoringRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
|
||||
|
@ -1451,6 +1511,161 @@
|
|||
"text_delta"
|
||||
],
|
||||
"title": "streamed completion response."
|
||||
},
|
||||
"SyntheticDataGenerationRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompts": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"filtering_function": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"none",
|
||||
"random",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"top_k_top_p",
|
||||
"sigmoid"
|
||||
],
|
||||
"title": "The type of filtering function.",
|
||||
"default": "none"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompts",
|
||||
"filtering_function"
|
||||
],
|
||||
"title": "Request to generate synthetic data. A small batch of prompts and a filtering function"
|
||||
},
|
||||
"SyntheticDataGenerationResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"synthetic_data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "array",
|
||||
"minItems": 3,
|
||||
"maxItems": 3,
|
||||
"prefixItems": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "number"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"statistics": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
"type": "number"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"synthetic_data",
|
||||
"statistics"
|
||||
],
|
||||
"title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
|
||||
},
|
||||
"RewardScoringRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt_generations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
},
|
||||
"message_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
},
|
||||
"generation": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt",
|
||||
"message_history",
|
||||
"generation"
|
||||
]
|
||||
}
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt_generations",
|
||||
"model"
|
||||
],
|
||||
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
|
||||
},
|
||||
"RewardScoringResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"scored_generations": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt_generation": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
},
|
||||
"message_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
},
|
||||
"generation": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt",
|
||||
"message_history",
|
||||
"generation"
|
||||
]
|
||||
},
|
||||
"score": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"prompt_generation",
|
||||
"score"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"scored_generations"
|
||||
],
|
||||
"title": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."
|
||||
}
|
||||
},
|
||||
"responses": {}
|
||||
|
@ -1462,11 +1677,17 @@
|
|||
],
|
||||
"tags": [
|
||||
{
|
||||
"name": "AgenticSystem"
|
||||
"name": "RewardScoring"
|
||||
},
|
||||
{
|
||||
"name": "Inference"
|
||||
},
|
||||
{
|
||||
"name": "SyntheticDataGeneration"
|
||||
},
|
||||
{
|
||||
"name": "AgenticSystem"
|
||||
},
|
||||
{
|
||||
"name": "ShieldConfig",
|
||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldConfig\" />"
|
||||
|
@ -1530,6 +1751,22 @@
|
|||
{
|
||||
"name": "CompletionResponseStreamChunk",
|
||||
"description": "streamed completion response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/CompletionResponseStreamChunk\" />"
|
||||
},
|
||||
{
|
||||
"name": "SyntheticDataGenerationRequest",
|
||||
"description": "Request to generate synthetic data. A small batch of prompts and a filtering function\n\n<SchemaDefinition schemaRef=\"#/components/schemas/SyntheticDataGenerationRequest\" />"
|
||||
},
|
||||
{
|
||||
"name": "SyntheticDataGenerationResponse",
|
||||
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/SyntheticDataGenerationResponse\" />"
|
||||
},
|
||||
{
|
||||
"name": "RewardScoringRequest",
|
||||
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />"
|
||||
},
|
||||
{
|
||||
"name": "RewardScoringResponse",
|
||||
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />"
|
||||
}
|
||||
],
|
||||
"x-tagGroups": [
|
||||
|
@ -1537,7 +1774,9 @@
|
|||
"name": "Operations",
|
||||
"tags": [
|
||||
"AgenticSystem",
|
||||
"Inference"
|
||||
"Inference",
|
||||
"RewardScoring",
|
||||
"SyntheticDataGeneration"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1557,7 +1796,11 @@
|
|||
"CompletionResponse",
|
||||
"CompletionResponseStreamChunk",
|
||||
"Message",
|
||||
"RewardScoringRequest",
|
||||
"RewardScoringResponse",
|
||||
"ShieldConfig",
|
||||
"SyntheticDataGenerationRequest",
|
||||
"SyntheticDataGenerationResponse",
|
||||
"URL"
|
||||
]
|
||||
}
|
||||
|
|
|
@ -750,6 +750,70 @@ components:
|
|||
- tool_calls
|
||||
- tool_responses
|
||||
type: object
|
||||
RewardScoringRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
model:
|
||||
type: string
|
||||
prompt_generations:
|
||||
items:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
generation:
|
||||
$ref: '#/components/schemas/Message'
|
||||
message_history:
|
||||
items:
|
||||
$ref: '#/components/schemas/Message'
|
||||
type: array
|
||||
prompt:
|
||||
$ref: '#/components/schemas/Message'
|
||||
required:
|
||||
- prompt
|
||||
- message_history
|
||||
- generation
|
||||
type: object
|
||||
type: array
|
||||
required:
|
||||
- prompt_generations
|
||||
- model
|
||||
title: Request to score a reward function. A list of prompts and a list of responses
|
||||
per prompt.
|
||||
type: object
|
||||
RewardScoringResponse:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
scored_generations:
|
||||
items:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
prompt_generation:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
generation:
|
||||
$ref: '#/components/schemas/Message'
|
||||
message_history:
|
||||
items:
|
||||
$ref: '#/components/schemas/Message'
|
||||
type: array
|
||||
prompt:
|
||||
$ref: '#/components/schemas/Message'
|
||||
required:
|
||||
- prompt
|
||||
- message_history
|
||||
- generation
|
||||
type: object
|
||||
score:
|
||||
type: number
|
||||
required:
|
||||
- prompt_generation
|
||||
- score
|
||||
type: object
|
||||
type: array
|
||||
required:
|
||||
- scored_generations
|
||||
title: Response from the reward scoring. Batch of (prompt, response, score)
|
||||
tuples that pass the threshold.
|
||||
type: object
|
||||
ShieldConfig:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
|
@ -774,6 +838,53 @@ components:
|
|||
- shield_type
|
||||
- params
|
||||
type: object
|
||||
SyntheticDataGenerationRequest:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
filtering_function:
|
||||
default: none
|
||||
enum:
|
||||
- none
|
||||
- random
|
||||
- top_k
|
||||
- top_p
|
||||
- top_k_top_p
|
||||
- sigmoid
|
||||
title: The type of filtering function.
|
||||
type: string
|
||||
prompts:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
required:
|
||||
- prompts
|
||||
- filtering_function
|
||||
title: Request to generate synthetic data. A small batch of prompts and a filtering
|
||||
function
|
||||
type: object
|
||||
SyntheticDataGenerationResponse:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
statistics:
|
||||
additionalProperties:
|
||||
type: number
|
||||
type: object
|
||||
synthetic_data:
|
||||
items:
|
||||
maxItems: 3
|
||||
minItems: 3
|
||||
prefixItems:
|
||||
- type: string
|
||||
- type: string
|
||||
- type: number
|
||||
type: array
|
||||
type: array
|
||||
required:
|
||||
- synthetic_data
|
||||
- statistics
|
||||
title: Response from the synthetic data generation. Batch of (prompt, response,
|
||||
score) tuples that pass the threshold.
|
||||
type: object
|
||||
URL:
|
||||
format: uri
|
||||
pattern: ^(https?://|file://|data:)
|
||||
|
@ -878,13 +989,51 @@ paths:
|
|||
description: Normal completion response. **OR** streamed completion response.
|
||||
tags:
|
||||
- Inference
|
||||
/reward_scoring/score:
|
||||
post:
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RewardScoringRequest'
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/RewardScoringResponse'
|
||||
description: OK
|
||||
tags:
|
||||
- RewardScoring
|
||||
/synthetic_data_generation/generate:
|
||||
post:
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SyntheticDataGenerationRequest'
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SyntheticDataGenerationResponse'
|
||||
description: OK
|
||||
tags:
|
||||
- SyntheticDataGeneration
|
||||
security:
|
||||
- Default: []
|
||||
servers:
|
||||
- url: http://llama.meta.com
|
||||
tags:
|
||||
- name: AgenticSystem
|
||||
- name: RewardScoring
|
||||
- name: Inference
|
||||
- name: SyntheticDataGeneration
|
||||
- name: AgenticSystem
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
|
||||
name: ShieldConfig
|
||||
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
|
||||
|
@ -952,11 +1101,39 @@ tags:
|
|||
<SchemaDefinition schemaRef="#/components/schemas/CompletionResponseStreamChunk"
|
||||
/>'
|
||||
name: CompletionResponseStreamChunk
|
||||
- description: 'Request to generate synthetic data. A small batch of prompts and a
|
||||
filtering function
|
||||
|
||||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationRequest"
|
||||
/>'
|
||||
name: SyntheticDataGenerationRequest
|
||||
- description: 'Response from the synthetic data generation. Batch of (prompt, response,
|
||||
score) tuples that pass the threshold.
|
||||
|
||||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationResponse"
|
||||
/>'
|
||||
name: SyntheticDataGenerationResponse
|
||||
- description: 'Request to score a reward function. A list of prompts and a list of
|
||||
responses per prompt.
|
||||
|
||||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />'
|
||||
name: RewardScoringRequest
|
||||
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
|
||||
tuples that pass the threshold.
|
||||
|
||||
|
||||
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringResponse" />'
|
||||
name: RewardScoringResponse
|
||||
x-tagGroups:
|
||||
- name: Operations
|
||||
tags:
|
||||
- AgenticSystem
|
||||
- Inference
|
||||
- RewardScoring
|
||||
- SyntheticDataGeneration
|
||||
- name: Types
|
||||
tags:
|
||||
- AgenticSystemCreateRequest
|
||||
|
@ -973,5 +1150,9 @@ x-tagGroups:
|
|||
- CompletionResponse
|
||||
- CompletionResponseStreamChunk
|
||||
- Message
|
||||
- RewardScoringRequest
|
||||
- RewardScoringResponse
|
||||
- ShieldConfig
|
||||
- SyntheticDataGenerationRequest
|
||||
- SyntheticDataGenerationResponse
|
||||
- URL
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue