fixes to reward stuff

This commit is contained in:
Ashwin Bharambe 2024-07-10 19:22:33 -07:00
parent eb12bfbef0
commit 956f07b04c
3 changed files with 461 additions and 11 deletions

View file

@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Set, Union
from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple
import yaml
from agentic_system_types import (
@ -17,8 +17,8 @@ from model_types import (
Message,
PretrainedModel,
SamplingParams,
StopReason,
ShieldConfig,
StopReason,
ToolCall,
ToolDefinition,
ToolResponse,
@ -203,19 +203,37 @@ class AgenticSystem(Protocol):
) -> None: ...
@dataclass
class PromptGeneration:
prompt: Message
message_history: List[Message]
generation: Message
@dataclass
class ScoredPromptGeneration:
prompt_generation: PromptGeneration
score: float
@json_schema_type
@dataclass
class RewardScoringRequest:
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
prompt_responses: List[Tuple(prompt:str, List[response:str])]
prompt_generations: List[PromptGeneration]
# TODO(ragho): create a RewardModel enum tye
model: str
@json_schema_type
@dataclass
class RewardScoringResponse:
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
prompt_responses: List[Tuple(prompt:str, List[Tuple(response:str, score:float)])]
scored_generations: List[ScoredPromptGeneration]
class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score")
@ -224,6 +242,7 @@ class RewardScoring(Protocol):
request: RewardScoringRequest,
) -> Union[RewardScoringResponse]: ...
class FilteringFunction(Enum):
"""The type of filtering function."""
@ -234,6 +253,7 @@ class FilteringFunction(Enum):
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"
@json_schema_type
@dataclass
class SyntheticDataGenerationRequest:
@ -241,18 +261,23 @@ class SyntheticDataGenerationRequest:
prompts: List[str]
filtering_function: FilteringFunction = FilteringFunction.none
reward_scoring: RewardScoring
# TODO(ragho): fix this
# reward_scoring: RewardScoring
@json_schema_type
@dataclass
class SyntheticDataGenerationResponse:
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
synthetic_data: List[Tuple(prompt:str, response:str, score:float)]
synthetic_data: List[Tuple[str, str, float]]
"""The actual synthetic data"""
statistics: Dict[str, float]
"""Statistics on how many prompts were generated and how many were filtered out"""
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate")
def post_generate(
@ -261,8 +286,9 @@ class SyntheticDataGeneration(Protocol):
) -> Union[SyntheticDataGenerationResponse]: ...
class LlamaStackEndpoints(Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration): ...
class LlamaStackEndpoints(
Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
): ...
if __name__ == "__main__":

View file

@ -191,6 +191,66 @@
"required": true
}
}
},
"/synthetic_data_generation/generate": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SyntheticDataGenerationResponse"
}
}
}
}
},
"tags": [
"SyntheticDataGeneration"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SyntheticDataGenerationRequest"
}
}
},
"required": true
}
}
},
"/reward_scoring/score": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RewardScoringResponse"
}
}
}
}
},
"tags": [
"RewardScoring"
],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RewardScoringRequest"
}
}
},
"required": true
}
}
}
},
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@ -1451,6 +1511,161 @@
"text_delta"
],
"title": "streamed completion response."
},
"SyntheticDataGenerationRequest": {
"type": "object",
"properties": {
"prompts": {
"type": "array",
"items": {
"type": "string"
}
},
"filtering_function": {
"type": "string",
"enum": [
"none",
"random",
"top_k",
"top_p",
"top_k_top_p",
"sigmoid"
],
"title": "The type of filtering function.",
"default": "none"
}
},
"additionalProperties": false,
"required": [
"prompts",
"filtering_function"
],
"title": "Request to generate synthetic data. A small batch of prompts and a filtering function"
},
"SyntheticDataGenerationResponse": {
"type": "object",
"properties": {
"synthetic_data": {
"type": "array",
"items": {
"type": "array",
"minItems": 3,
"maxItems": 3,
"prefixItems": [
{
"type": "string"
},
{
"type": "string"
},
{
"type": "number"
}
]
}
},
"statistics": {
"type": "object",
"additionalProperties": {
"type": "number"
}
}
},
"additionalProperties": false,
"required": [
"synthetic_data",
"statistics"
],
"title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."
},
"RewardScoringRequest": {
"type": "object",
"properties": {
"prompt_generations": {
"type": "array",
"items": {
"type": "object",
"properties": {
"prompt": {
"$ref": "#/components/schemas/Message"
},
"message_history": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
}
},
"generation": {
"$ref": "#/components/schemas/Message"
}
},
"additionalProperties": false,
"required": [
"prompt",
"message_history",
"generation"
]
}
},
"model": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"prompt_generations",
"model"
],
"title": "Request to score a reward function. A list of prompts and a list of responses per prompt."
},
"RewardScoringResponse": {
"type": "object",
"properties": {
"scored_generations": {
"type": "array",
"items": {
"type": "object",
"properties": {
"prompt_generation": {
"type": "object",
"properties": {
"prompt": {
"$ref": "#/components/schemas/Message"
},
"message_history": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
}
},
"generation": {
"$ref": "#/components/schemas/Message"
}
},
"additionalProperties": false,
"required": [
"prompt",
"message_history",
"generation"
]
},
"score": {
"type": "number"
}
},
"additionalProperties": false,
"required": [
"prompt_generation",
"score"
]
}
}
},
"additionalProperties": false,
"required": [
"scored_generations"
],
"title": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."
}
},
"responses": {}
@ -1462,11 +1677,17 @@
],
"tags": [
{
"name": "AgenticSystem"
"name": "RewardScoring"
},
{
"name": "Inference"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "AgenticSystem"
},
{
"name": "ShieldConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldConfig\" />"
@ -1530,6 +1751,22 @@
{
"name": "CompletionResponseStreamChunk",
"description": "streamed completion response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/CompletionResponseStreamChunk\" />"
},
{
"name": "SyntheticDataGenerationRequest",
"description": "Request to generate synthetic data. A small batch of prompts and a filtering function\n\n<SchemaDefinition schemaRef=\"#/components/schemas/SyntheticDataGenerationRequest\" />"
},
{
"name": "SyntheticDataGenerationResponse",
"description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/SyntheticDataGenerationResponse\" />"
},
{
"name": "RewardScoringRequest",
"description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringRequest\" />"
},
{
"name": "RewardScoringResponse",
"description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/RewardScoringResponse\" />"
}
],
"x-tagGroups": [
@ -1537,7 +1774,9 @@
"name": "Operations",
"tags": [
"AgenticSystem",
"Inference"
"Inference",
"RewardScoring",
"SyntheticDataGeneration"
]
},
{
@ -1557,7 +1796,11 @@
"CompletionResponse",
"CompletionResponseStreamChunk",
"Message",
"RewardScoringRequest",
"RewardScoringResponse",
"ShieldConfig",
"SyntheticDataGenerationRequest",
"SyntheticDataGenerationResponse",
"URL"
]
}

View file

@ -750,6 +750,70 @@ components:
- tool_calls
- tool_responses
type: object
RewardScoringRequest:
additionalProperties: false
properties:
model:
type: string
prompt_generations:
items:
additionalProperties: false
properties:
generation:
$ref: '#/components/schemas/Message'
message_history:
items:
$ref: '#/components/schemas/Message'
type: array
prompt:
$ref: '#/components/schemas/Message'
required:
- prompt
- message_history
- generation
type: object
type: array
required:
- prompt_generations
- model
title: Request to score a reward function. A list of prompts and a list of responses
per prompt.
type: object
RewardScoringResponse:
additionalProperties: false
properties:
scored_generations:
items:
additionalProperties: false
properties:
prompt_generation:
additionalProperties: false
properties:
generation:
$ref: '#/components/schemas/Message'
message_history:
items:
$ref: '#/components/schemas/Message'
type: array
prompt:
$ref: '#/components/schemas/Message'
required:
- prompt
- message_history
- generation
type: object
score:
type: number
required:
- prompt_generation
- score
type: object
type: array
required:
- scored_generations
title: Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
type: object
ShieldConfig:
additionalProperties: false
properties:
@ -774,6 +838,53 @@ components:
- shield_type
- params
type: object
SyntheticDataGenerationRequest:
additionalProperties: false
properties:
filtering_function:
default: none
enum:
- none
- random
- top_k
- top_p
- top_k_top_p
- sigmoid
title: The type of filtering function.
type: string
prompts:
items:
type: string
type: array
required:
- prompts
- filtering_function
title: Request to generate synthetic data. A small batch of prompts and a filtering
function
type: object
SyntheticDataGenerationResponse:
additionalProperties: false
properties:
statistics:
additionalProperties:
type: number
type: object
synthetic_data:
items:
maxItems: 3
minItems: 3
prefixItems:
- type: string
- type: string
- type: number
type: array
type: array
required:
- synthetic_data
- statistics
title: Response from the synthetic data generation. Batch of (prompt, response,
score) tuples that pass the threshold.
type: object
URL:
format: uri
pattern: ^(https?://|file://|data:)
@ -878,13 +989,51 @@ paths:
description: Normal completion response. **OR** streamed completion response.
tags:
- Inference
/reward_scoring/score:
post:
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/RewardScoringRequest'
required: true
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/RewardScoringResponse'
description: OK
tags:
- RewardScoring
/synthetic_data_generation/generate:
post:
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/SyntheticDataGenerationRequest'
required: true
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/SyntheticDataGenerationResponse'
description: OK
tags:
- SyntheticDataGeneration
security:
- Default: []
servers:
- url: http://llama.meta.com
tags:
- name: AgenticSystem
- name: RewardScoring
- name: Inference
- name: SyntheticDataGeneration
- name: AgenticSystem
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldConfig" />
name: ShieldConfig
- description: <SchemaDefinition schemaRef="#/components/schemas/AgenticSystemCreateRequest"
@ -952,11 +1101,39 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/CompletionResponseStreamChunk"
/>'
name: CompletionResponseStreamChunk
- description: 'Request to generate synthetic data. A small batch of prompts and a
filtering function
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationRequest"
/>'
name: SyntheticDataGenerationRequest
- description: 'Response from the synthetic data generation. Batch of (prompt, response,
score) tuples that pass the threshold.
<SchemaDefinition schemaRef="#/components/schemas/SyntheticDataGenerationResponse"
/>'
name: SyntheticDataGenerationResponse
- description: 'Request to score a reward function. A list of prompts and a list of
responses per prompt.
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringRequest" />'
name: RewardScoringRequest
- description: 'Response from the reward scoring. Batch of (prompt, response, score)
tuples that pass the threshold.
<SchemaDefinition schemaRef="#/components/schemas/RewardScoringResponse" />'
name: RewardScoringResponse
x-tagGroups:
- name: Operations
tags:
- AgenticSystem
- Inference
- RewardScoring
- SyntheticDataGeneration
- name: Types
tags:
- AgenticSystemCreateRequest
@ -973,5 +1150,9 @@ x-tagGroups:
- CompletionResponse
- CompletionResponseStreamChunk
- Message
- RewardScoringRequest
- RewardScoringResponse
- ShieldConfig
- SyntheticDataGenerationRequest
- SyntheticDataGenerationResponse
- URL