diff --git a/source/api_definitions.py b/source/api_definitions.py
index e5697f86f..0f73cc8d3 100644
--- a/source/api_definitions.py
+++ b/source/api_definitions.py
@@ -283,23 +283,18 @@ class FilteringFunction(Enum):
 class SyntheticDataGenerationRequest:
     """Request to generate synthetic data. A small batch of prompts and a filtering function"""
 
-    prompts: List[str]
+    prompts: List[Message]
     filtering_function: FilteringFunction = FilteringFunction.none
-
-    # TODO(ragho): fix this
-    # reward_scoring: RewardScoring
+    reward_scoring: Optional[RewardScoring] = None
 
 
 @json_schema_type
 @dataclass
 class SyntheticDataGenerationResponse:
-    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
+    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
 
-    synthetic_data: List[Tuple[str, str, float]]
-
-    """The actual synthetic data"""
-    statistics: Dict[str, float]
-    """Statistics on how many prompts were generated and how many were filtered out"""
+    synthetic_data: List[KScoredPromptGenerations]
+    statistics: Optional[Dict[str, Any]] = None
 
 
 class SyntheticDataGeneration(Protocol):