diff --git a/source/api_definitions.py b/source/api_definitions.py
index 2f6944007..266bbcf3c 100644
--- a/source/api_definitions.py
+++ b/source/api_definitions.py
@@ -203,7 +203,78 @@ class AgenticSystem(Protocol):
     ) -> None: ...
 
 
-class LlamaStackEndpoints(Inference, AgenticSystem): ...
+@json_schema_type
+@dataclass
+class RewardScoringRequest:
+    """Request to score a reward function. A list of prompts and a list of responses per prompt."""
+
+    # Each entry pairs a prompt with its candidate responses: (prompt, [response, ...]).
+    prompt_responses: List[Tuple[str, List[str]]]
+
+
+@json_schema_type
+@dataclass
+class RewardScoringResponse:
+    """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
+
+    # Each entry is (prompt, [(response, score), ...]).
+    prompt_responses: List[Tuple[str, List[Tuple[str, float]]]]
+
+
+# Protocol exposing the /reward_scoring/score web method.
+class RewardScoring(Protocol):
+    @webmethod(route="/reward_scoring/score")
+    def post_score(
+        self,
+        request: RewardScoringRequest,
+    ) -> RewardScoringResponse: ...
+
+
+class FilteringFunction(Enum):
+    """The type of filtering function."""
+
+    none = "none"
+    random = "random"
+    top_k = "top_k"
+    top_p = "top_p"
+    top_k_top_p = "top_k_top_p"
+    sigmoid = "sigmoid"
+
+
+@json_schema_type
+@dataclass
+class SyntheticDataGenerationRequest:
+    """Request to generate synthetic data. A small batch of prompts and a filtering function"""
+
+    prompts: List[str]
+    # Declared before filtering_function: dataclass fields without a default
+    # must precede fields that have one, otherwise class creation raises TypeError.
+    reward_scoring: RewardScoring
+    filtering_function: FilteringFunction = FilteringFunction.none
+
+
+@json_schema_type
+@dataclass
+class SyntheticDataGenerationResponse:
+    """Response from the synthetic data generation.
+    Batch of (prompt, response, score) tuples that pass the threshold."""
+
+    synthetic_data: List[Tuple[str, str, float]]
+    """The actual synthetic data"""
+    statistics: Dict[str, float]
+    """Statistics on how many prompts were generated and how many were filtered out"""
+
+
+# Protocol exposing the /synthetic_data_generation/generate web method.
+class SyntheticDataGeneration(Protocol):
+    @webmethod(route="/synthetic_data_generation/generate")
+    def post_generate(
+        self,
+        request: SyntheticDataGenerationRequest,
+    ) -> SyntheticDataGenerationResponse: ...
+
+
+class LlamaStackEndpoints(Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration): ...
 
 
 if __name__ == "__main__":