Fix reward scoring and synthetic data generation type definitions

This commit is contained in:
Ashwin Bharambe 2024-07-10 19:22:33 -07:00
parent eb12bfbef0
commit 956f07b04c
3 changed files with 461 additions and 11 deletions

View file

@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Set, Union
from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple
import yaml
from agentic_system_types import (
@ -17,8 +17,8 @@ from model_types import (
Message,
PretrainedModel,
SamplingParams,
StopReason,
ShieldConfig,
StopReason,
ToolCall,
ToolDefinition,
ToolResponse,
@ -203,19 +203,37 @@ class AgenticSystem(Protocol):
) -> None: ...
@dataclass
class PromptGeneration:
    """A prompt, the message history that accompanies it, and one generation."""

    # The prompt message.
    prompt: Message
    # Messages associated with the prompt -- presumably the conversation that
    # preceded it; TODO confirm against the code that builds these records.
    message_history: List[Message]
    # The generated message -- presumably the model output for `prompt`.
    generation: Message
@dataclass
class ScoredPromptGeneration:
    """A prompt generation together with the score assigned to it."""

    # The generation that was scored.
    prompt_generation: PromptGeneration
    # The score for `prompt_generation`; scale and range are not defined here.
    score: float
@json_schema_type
@dataclass
class RewardScoringRequest:
    """Request to score a reward function. A list of prompt generations to be scored by the given model."""

    # The generations to score.
    prompt_generations: List[PromptGeneration]

    # Identifies the model used for scoring; a plain string for now.
    # TODO(ragho): create a RewardModel enum type
    model: str
@json_schema_type
@dataclass
class RewardScoringResponse:
    """Response from the reward scoring. Batch of scored prompt generations that pass the threshold."""

    # One entry per returned generation, each carrying its score.
    scored_generations: List[ScoredPromptGeneration]
class RewardScoring(Protocol):
@webmethod(route="/reward_scoring/score")
@ -224,6 +242,7 @@ class RewardScoring(Protocol):
request: RewardScoringRequest,
) -> Union[RewardScoringResponse]: ...
class FilteringFunction(Enum):
"""The type of filtering function."""
@ -234,6 +253,7 @@ class FilteringFunction(Enum):
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"
@json_schema_type
@dataclass
class SyntheticDataGenerationRequest:
@ -241,18 +261,23 @@ class SyntheticDataGenerationRequest:
prompts: List[str]
filtering_function: FilteringFunction = FilteringFunction.none
reward_scoring: RewardScoring
# TODO(ragho): fix this
# reward_scoring: RewardScoring
@json_schema_type
@dataclass
class SyntheticDataGenerationResponse:
    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

    # Each entry is a (prompt, response, score) tuple.
    synthetic_data: List[Tuple[str, str, float]]
    """The actual synthetic data"""

    # Maps a statistic name to its numeric value.
    statistics: Dict[str, float]
    """Statistics on how many prompts were generated and how many were filtered out"""
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate")
def post_generate(
@ -261,8 +286,9 @@ class SyntheticDataGeneration(Protocol):
) -> Union[SyntheticDataGenerationResponse]: ...
# Single class that inherits every endpoint interface listed as a base, so the
# whole API surface can be referred to through one type. It adds no members.
class LlamaStackEndpoints(
    Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
): ...
if __name__ == "__main__":