From 956f07b04c7cccb80f9d302125af935ea98abf39 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 10 Jul 2024 19:22:33 -0700 Subject: [PATCH] fixes to reward stuff --- source/api_definitions.py | 42 +++++-- source/openapi.html | 247 +++++++++++++++++++++++++++++++++++++- source/openapi.yaml | 183 +++++++++++++++++++++++++++- 3 files changed, 461 insertions(+), 11 deletions(-) diff --git a/source/api_definitions.py b/source/api_definitions.py index 266bbcf3c..4dded2166 100644 --- a/source/api_definitions.py +++ b/source/api_definitions.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Set, Union +from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple import yaml from agentic_system_types import ( @@ -17,8 +17,8 @@ from model_types import ( Message, PretrainedModel, SamplingParams, - StopReason, ShieldConfig, + StopReason, ToolCall, ToolDefinition, ToolResponse, @@ -203,19 +203,37 @@ class AgenticSystem(Protocol): ) -> None: ... +@dataclass +class PromptGeneration: + prompt: Message + message_history: List[Message] + generation: Message + + +@dataclass +class ScoredPromptGeneration: + prompt_generation: PromptGeneration + score: float + + @json_schema_type @dataclass class RewardScoringRequest: """Request to score a reward function. A list of prompts and a list of responses per prompt.""" - prompt_responses: List[Tuple(prompt:str, List[response:str])] + prompt_generations: List[PromptGeneration] + + # TODO(ragho): create a RewardModel enum tye + model: str + @json_schema_type @dataclass class RewardScoringResponse: """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.""" - prompt_responses: List[Tuple(prompt:str, List[Tuple(response:str, score:float)])] + scored_generations: List[ScoredPromptGeneration] + class RewardScoring(Protocol): @webmethod(route="/reward_scoring/score") @@ -224,6 +242,7 @@ class RewardScoring(Protocol): request: RewardScoringRequest, ) -> Union[RewardScoringResponse]: ... + class FilteringFunction(Enum): """The type of filtering function.""" @@ -234,6 +253,7 @@ class FilteringFunction(Enum): top_k_top_p = "top_k_top_p" sigmoid = "sigmoid" + @json_schema_type @dataclass class SyntheticDataGenerationRequest: @@ -241,18 +261,23 @@ class SyntheticDataGenerationRequest: prompts: List[str] filtering_function: FilteringFunction = FilteringFunction.none - reward_scoring: RewardScoring + + # TODO(ragho): fix this + # reward_scoring: RewardScoring + @json_schema_type @dataclass class SyntheticDataGenerationResponse: """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" - synthetic_data: List[Tuple(prompt:str, response:str, score:float)] + synthetic_data: List[Tuple[str, str, float]] + """The actual synthetic data""" statistics: Dict[str, float] """Statistics on how many prompts were generated and how many were filtered out""" + class SyntheticDataGeneration(Protocol): @webmethod(route="/synthetic_data_generation/generate") def post_generate( @@ -261,8 +286,9 @@ class SyntheticDataGeneration(Protocol): ) -> Union[SyntheticDataGenerationResponse]: ... - -class LlamaStackEndpoints(Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration): ... +class LlamaStackEndpoints( + Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration +): ... if __name__ == "__main__": diff --git a/source/openapi.html b/source/openapi.html index e52e21643..8f5f935d3 100644 --- a/source/openapi.html +++ b/source/openapi.html @@ -191,6 +191,66 @@ "required": true } } + }, + "/synthetic_data_generation/generate": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SyntheticDataGenerationResponse" + } + } + } + } + }, + "tags": [ + "SyntheticDataGeneration" + ], + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SyntheticDataGenerationRequest" + } + } + }, + "required": true + } + } + }, + "/reward_scoring/score": { + "post": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RewardScoringResponse" + } + } + } + } + }, + "tags": [ + "RewardScoring" + ], + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RewardScoringRequest" + } + } + }, + "required": true + } + } } }, "jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema", @@ -1451,6 +1511,161 @@ "text_delta" ], "title": "streamed completion response." + }, + "SyntheticDataGenerationRequest": { + "type": "object", + "properties": { + "prompts": { + "type": "array", + "items": { + "type": "string" + } + }, + "filtering_function": { + "type": "string", + "enum": [ + "none", + "random", + "top_k", + "top_p", + "top_k_top_p", + "sigmoid" + ], + "title": "The type of filtering function.", + "default": "none" + } + }, + "additionalProperties": false, + "required": [ + "prompts", + "filtering_function" + ], + "title": "Request to generate synthetic data. A small batch of prompts and a filtering function" + }, + "SyntheticDataGenerationResponse": { + "type": "object", + "properties": { + "synthetic_data": { + "type": "array", + "items": { + "type": "array", + "minItems": 3, + "maxItems": 3, + "prefixItems": [ + { + "type": "string" + }, + { + "type": "string" + }, + { + "type": "number" + } + ] + } + }, + "statistics": { + "type": "object", + "additionalProperties": { + "type": "number" + } + } + }, + "additionalProperties": false, + "required": [ + "synthetic_data", + "statistics" + ], + "title": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." + }, + "RewardScoringRequest": { + "type": "object", + "properties": { + "prompt_generations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "prompt": { + "$ref": "#/components/schemas/Message" + }, + "message_history": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Message" + } + }, + "generation": { + "$ref": "#/components/schemas/Message" + } + }, + "additionalProperties": false, + "required": [ + "prompt", + "message_history", + "generation" + ] + } + }, + "model": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "prompt_generations", + "model" + ], + "title": "Request to score a reward function. A list of prompts and a list of responses per prompt." + }, + "RewardScoringResponse": { + "type": "object", + "properties": { + "scored_generations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "prompt_generation": { + "type": "object", + "properties": { + "prompt": { + "$ref": "#/components/schemas/Message" + }, + "message_history": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Message" + } + }, + "generation": { + "$ref": "#/components/schemas/Message" + } + }, + "additionalProperties": false, + "required": [ + "prompt", + "message_history", + "generation" + ] + }, + "score": { + "type": "number" + } + }, + "additionalProperties": false, + "required": [ + "prompt_generation", + "score" + ] + } + } + }, + "additionalProperties": false, + "required": [ + "scored_generations" + ], + "title": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold." } }, "responses": {} @@ -1462,11 +1677,17 @@ ], "tags": [ { - "name": "AgenticSystem" + "name": "RewardScoring" }, { "name": "Inference" }, + { + "name": "SyntheticDataGeneration" + }, + { + "name": "AgenticSystem" + }, { "name": "ShieldConfig", "description": "" @@ -1530,6 +1751,22 @@ { "name": "CompletionResponseStreamChunk", "description": "streamed completion response.\n\n" + }, + { + "name": "SyntheticDataGenerationRequest", + "description": "Request to generate synthetic data. A small batch of prompts and a filtering function\n\n" + }, + { + "name": "SyntheticDataGenerationResponse", + "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" + }, + { + "name": "RewardScoringRequest", + "description": "Request to score a reward function. A list of prompts and a list of responses per prompt.\n\n" + }, + { + "name": "RewardScoringResponse", + "description": "Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.\n\n" } ], "x-tagGroups": [ @@ -1537,7 +1774,9 @@ "name": "Operations", "tags": [ "AgenticSystem", - "Inference" + "Inference", + "RewardScoring", + "SyntheticDataGeneration" ] }, { @@ -1557,7 +1796,11 @@ "CompletionResponse", "CompletionResponseStreamChunk", "Message", + "RewardScoringRequest", + "RewardScoringResponse", "ShieldConfig", + "SyntheticDataGenerationRequest", + "SyntheticDataGenerationResponse", "URL" ] } diff --git a/source/openapi.yaml b/source/openapi.yaml index f43164410..02e85beed 100644 --- a/source/openapi.yaml +++ b/source/openapi.yaml @@ -750,6 +750,70 @@ components: - tool_calls - tool_responses type: object + RewardScoringRequest: + additionalProperties: false + properties: + model: + type: string + prompt_generations: + items: + additionalProperties: false + properties: + generation: + $ref: '#/components/schemas/Message' + message_history: + items: + $ref: '#/components/schemas/Message' + type: array + prompt: + $ref: '#/components/schemas/Message' + required: + - prompt + - message_history + - generation + type: object + type: array + required: + - prompt_generations + - model + title: Request to score a reward function. A list of prompts and a list of responses + per prompt. + type: object + RewardScoringResponse: + additionalProperties: false + properties: + scored_generations: + items: + additionalProperties: false + properties: + prompt_generation: + additionalProperties: false + properties: + generation: + $ref: '#/components/schemas/Message' + message_history: + items: + $ref: '#/components/schemas/Message' + type: array + prompt: + $ref: '#/components/schemas/Message' + required: + - prompt + - message_history + - generation + type: object + score: + type: number + required: + - prompt_generation + - score + type: object + type: array + required: + - scored_generations + title: Response from the reward scoring. Batch of (prompt, response, score) + tuples that pass the threshold. + type: object ShieldConfig: additionalProperties: false properties: @@ -774,6 +838,53 @@ components: - shield_type - params type: object + SyntheticDataGenerationRequest: + additionalProperties: false + properties: + filtering_function: + default: none + enum: + - none + - random + - top_k + - top_p + - top_k_top_p + - sigmoid + title: The type of filtering function. + type: string + prompts: + items: + type: string + type: array + required: + - prompts + - filtering_function + title: Request to generate synthetic data. A small batch of prompts and a filtering + function + type: object + SyntheticDataGenerationResponse: + additionalProperties: false + properties: + statistics: + additionalProperties: + type: number + type: object + synthetic_data: + items: + maxItems: 3 + minItems: 3 + prefixItems: + - type: string + - type: string + - type: number + type: array + type: array + required: + - synthetic_data + - statistics + title: Response from the synthetic data generation. Batch of (prompt, response, + score) tuples that pass the threshold. + type: object URL: format: uri pattern: ^(https?://|file://|data:) @@ -878,13 +989,51 @@ paths: description: Normal completion response. **OR** streamed completion response. tags: - Inference + /reward_scoring/score: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RewardScoringRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/RewardScoringResponse' + description: OK + tags: + - RewardScoring + /synthetic_data_generation/generate: + post: + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/SyntheticDataGenerationRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/SyntheticDataGenerationResponse' + description: OK + tags: + - SyntheticDataGeneration security: - Default: [] servers: - url: http://llama.meta.com tags: -- name: AgenticSystem +- name: RewardScoring - name: Inference +- name: SyntheticDataGeneration +- name: AgenticSystem - description: name: ShieldConfig - description: ' name: CompletionResponseStreamChunk +- description: 'Request to generate synthetic data. A small batch of prompts and a + filtering function + + + ' + name: SyntheticDataGenerationRequest +- description: 'Response from the synthetic data generation. Batch of (prompt, response, + score) tuples that pass the threshold. + + + ' + name: SyntheticDataGenerationResponse +- description: 'Request to score a reward function. A list of prompts and a list of + responses per prompt. + + + ' + name: RewardScoringRequest +- description: 'Response from the reward scoring. Batch of (prompt, response, score) + tuples that pass the threshold. + + + ' + name: RewardScoringResponse x-tagGroups: - name: Operations tags: - AgenticSystem - Inference + - RewardScoring + - SyntheticDataGeneration - name: Types tags: - AgenticSystemCreateRequest @@ -973,5 +1150,9 @@ x-tagGroups: - CompletionResponse - CompletionResponseStreamChunk - Message + - RewardScoringRequest + - RewardScoringResponse - ShieldConfig + - SyntheticDataGenerationRequest + - SyntheticDataGenerationResponse - URL