add reward scoring and synthetic data gen

This commit is contained in:
Raghotham Murthy 2024-07-10 15:42:56 -07:00
parent beb2870750
commit eb12bfbef0

View file

@ -203,7 +203,66 @@ class AgenticSystem(Protocol):
) -> None: ...
# Aggregate class that combines the Inference and AgenticSystem bases.
# NOTE(review): a class with this same name is declared again further down with
# additional bases; that later statement rebinds the name. This looks like the
# pre-change line of a diff — confirm it can be dropped.
class LlamaStackEndpoints(Inference, AgenticSystem): ...
@json_schema_type
@dataclass
class RewardScoringRequest:
    """Request to score a reward function. A list of prompts and a list of responses per prompt."""

    # Each entry is a pair ``(prompt, responses)`` where ``prompt`` is a str and
    # ``responses`` is the list of str responses for that prompt. Element names
    # cannot be written inside a ``Tuple[...]`` subscript, so the layout is
    # documented here rather than in the annotation.
    prompt_responses: List[Tuple[str, List[str]]]
@json_schema_type
@dataclass
class RewardScoringResponse:
    """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""

    # Each entry is a pair ``(prompt, scored_responses)`` where
    # ``scored_responses`` is a list of ``(response, score)`` pairs with a str
    # response and a float score. Element names cannot be written inside a
    # ``Tuple[...]`` subscript, so the layout is documented here instead.
    prompt_responses: List[Tuple[str, List[Tuple[str, float]]]]
class RewardScoring(Protocol):
    """Structural interface for the reward-scoring endpoint."""

    @webmethod(route="/reward_scoring/score")
    def post_score(self, request: RewardScoringRequest) -> Union[RewardScoringResponse]:
        """Take a RewardScoringRequest and return a RewardScoringResponse.

        Declared with the route ``/reward_scoring/score``; this is a stub with
        no implementation.
        """
        ...
class FilteringFunction(Enum):
    """Names the filtering function to apply.

    Every member's value is the same string as its name. What each filter
    actually does is not defined in this block.
    """

    # Used as the default by SyntheticDataGenerationRequest.filtering_function.
    none = "none"

    random = "random"

    # top-k / top-p family
    top_k = "top_k"
    top_p = "top_p"
    top_k_top_p = "top_k_top_p"

    sigmoid = "sigmoid"
@json_schema_type
@dataclass
class SyntheticDataGenerationRequest:
    """Request to generate synthetic data. A small batch of prompts and a filtering function"""

    # Prompts to generate data for.
    prompts: List[str]

    # Required, so it must be declared before any field that has a default:
    # @dataclass raises "non-default argument follows default argument" at
    # class-creation time otherwise. Keyword construction is unaffected by
    # this ordering.
    # NOTE(review): this is typed as the RewardScoring protocol itself rather
    # than a plain data value — confirm that is the intended request payload.
    reward_scoring: RewardScoring

    # Filtering function to use; defaults to FilteringFunction.none.
    filtering_function: FilteringFunction = FilteringFunction.none
@json_schema_type
@dataclass
class SyntheticDataGenerationResponse:
    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

    # Each entry is ``(prompt, response, score)``: str, str, float. Element
    # names cannot be written inside a ``Tuple[...]`` subscript, so the layout
    # is documented here rather than in the annotation.
    synthetic_data: List[Tuple[str, str, float]]
    """The actual synthetic data"""

    statistics: Dict[str, float]
    """Statistics on how many prompts were generated and how many were filtered out"""
class SyntheticDataGeneration(Protocol):
    """Structural interface for the synthetic-data-generation endpoint."""

    @webmethod(route="/synthetic_data_generation/generate")
    def post_generate(
        self, request: SyntheticDataGenerationRequest
    ) -> Union[SyntheticDataGenerationResponse]:
        """Take a SyntheticDataGenerationRequest and return a SyntheticDataGenerationResponse.

        Declared with the route ``/synthetic_data_generation/generate``; this is
        a stub with no implementation.
        """
        ...
class LlamaStackEndpoints(
    Inference,
    AgenticSystem,
    RewardScoring,
    SyntheticDataGeneration,
):
    """Aggregate of the Inference, AgenticSystem, RewardScoring and SyntheticDataGeneration interfaces."""
if __name__ == "__main__":