From d9367054dfebc37f4807d548565a34d5d126dad4 Mon Sep 17 00:00:00 2001 From: Raghotham Murthy Date: Wed, 10 Jul 2024 22:58:29 -0700 Subject: [PATCH] sdg improvements --- source/api_definitions.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/source/api_definitions.py b/source/api_definitions.py index e5697f86f..0f73cc8d3 100644 --- a/source/api_definitions.py +++ b/source/api_definitions.py @@ -283,23 +283,18 @@ class FilteringFunction(Enum): class SyntheticDataGenerationRequest: """Request to generate synthetic data. A small batch of prompts and a filtering function""" - prompts: List[str] + prompts: List[Message] filtering_function: FilteringFunction = FilteringFunction.none - - # TODO(ragho): fix this - # reward_scoring: RewardScoring + reward_scoring: Optional[RewardScoring] = None @json_schema_type @dataclass class SyntheticDataGenerationResponse: - """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" + """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" - synthetic_data: List[Tuple[str, str, float]] - - """The actual synthetic data""" - statistics: Dict[str, float] - """Statistics on how many prompts were generated and how many were filtered out""" + synthetic_data: List[KScoredPromptGenerations] + statistics: Optional[Dict[str, Any]] = None class SyntheticDataGeneration(Protocol):