mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
add reward scoring and synthetic data gen
This commit is contained in:
parent beb2870750
commit eb12bfbef0

1 changed file with 60 additions and 1 deletion
@@ -203,7 +203,66 @@ class AgenticSystem(Protocol):
    ) -> None: ...


# The commit's single deletion is the old one-line endpoint registry,
#   class LlamaStackEndpoints(Inference, AgenticSystem): ...
# which is superseded by the expanded definition at the end of this hunk.
@json_schema_type
@dataclass
class RewardScoringRequest:
    """Request to score with a reward function. A list of prompts and a list of responses per prompt."""

    # (prompt, [response, ...]) pairs to be scored
    prompt_responses: List[Tuple[str, List[str]]]


@json_schema_type
@dataclass
class RewardScoringResponse:
    """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""

    # (prompt, [(response, score), ...]) pairs that passed the threshold
    prompt_responses: List[Tuple[str, List[Tuple[str, float]]]]
class RewardScoring(Protocol):
    @webmethod(route="/reward_scoring/score")
    def post_score(
        self,
        request: RewardScoringRequest,
    ) -> RewardScoringResponse: ...
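The hunk defines the scoring surface but no implementation. As a point of reference, here is a minimal sketch of an object that matches the RewardScoring shape, assuming the dataclasses above are in scope; the LengthHeuristicScorer name and its length-based scoring rule are invented purely for illustration.

# --- illustrative sketch, not part of this commit ---
class LengthHeuristicScorer:
    """Hypothetical scorer: longer responses get higher scores (illustration only)."""

    def post_score(self, request: RewardScoringRequest) -> RewardScoringResponse:
        scored = [
            (prompt, [(response, float(len(response))) for response in responses])
            for prompt, responses in request.prompt_responses
        ]
        return RewardScoringResponse(prompt_responses=scored)


# Example payload, mirroring what a client would POST to /reward_scoring/score:
req = RewardScoringRequest(
    prompt_responses=[("What is 2 + 2?", ["4", "four", "I am not sure"])]
)
print(LengthHeuristicScorer().post_score(req).prompt_responses)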
class FilteringFunction(Enum):
    """The type of filtering function."""

    none = "none"
    random = "random"
    top_k = "top_k"
    top_p = "top_p"
    top_k_top_p = "top_k_top_p"
    sigmoid = "sigmoid"
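The enum only names the filtering strategies; their behaviour is not part of this hunk. One plausible reading of top_k is sketched below, where the apply_top_k helper (hypothetical, not part of the API) keeps the k highest-scoring responses per prompt.

# --- illustrative sketch, not part of this commit ---
def apply_top_k(
    scored: List[Tuple[str, List[Tuple[str, float]]]], k: int
) -> List[Tuple[str, str, float]]:
    """Keep the k best-scoring responses per prompt as (prompt, response, score) tuples."""
    kept: List[Tuple[str, str, float]] = []
    for prompt, response_scores in scored:
        best = sorted(response_scores, key=lambda rs: rs[1], reverse=True)[:k]
        kept.extend((prompt, response, score) for response, score in best)
    return kept


# apply_top_k([("2 + 2?", [("4", 0.9), ("five", 0.1), ("four", 0.7)])], k=2)
# -> [("2 + 2?", "4", 0.9), ("2 + 2?", "four", 0.7)]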
@json_schema_type
@dataclass
class SyntheticDataGenerationRequest:
    """Request to generate synthetic data. A small batch of prompts and a filtering function."""

    prompts: List[str]
    # reward_scoring has no default, so it must come before the defaulted filtering_function
    reward_scoring: RewardScoring
    filtering_function: FilteringFunction = FilteringFunction.none


@json_schema_type
@dataclass
class SyntheticDataGenerationResponse:
    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""

    synthetic_data: List[Tuple[str, str, float]]
    """The actual synthetic data as (prompt, response, score) tuples"""

    statistics: Dict[str, float]
    """Statistics on how many prompts were generated and how many were filtered out"""
class SyntheticDataGeneration(Protocol):
    @webmethod(route="/synthetic_data_generation/generate")
    def post_generate(
        self,
        request: SyntheticDataGenerationRequest,
    ) -> SyntheticDataGenerationResponse: ...
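Putting the pieces together, the flow implied by these types is: generate candidate responses for each prompt, score them through the supplied RewardScoring implementation, filter, and return the survivors plus statistics. A minimal end-to-end sketch, assuming the definitions above are in scope; EchoGenerator, its fixed candidate responses, and the LengthHeuristicScorer from the earlier sketch are all invented for illustration.

# --- illustrative sketch, not part of this commit ---
class EchoGenerator:
    """Hypothetical generator: fakes model output so the request/response flow can be traced."""

    def post_generate(
        self, request: SyntheticDataGenerationRequest
    ) -> SyntheticDataGenerationResponse:
        # "Generate" two candidate responses per prompt (stand-in for a real model call).
        candidates = [(p, [p.upper(), p[::-1]]) for p in request.prompts]
        # Score the candidates with the reward scorer supplied in the request.
        scored = request.reward_scoring.post_score(
            RewardScoringRequest(prompt_responses=candidates)
        ).prompt_responses
        # Keep only the single best response per prompt (ignoring filtering_function for brevity).
        kept = [(prompt, *max(pairs, key=lambda rs: rs[1])) for prompt, pairs in scored]
        return SyntheticDataGenerationResponse(
            synthetic_data=kept,
            statistics={
                "generated": float(sum(len(responses) for _, responses in candidates)),
                "kept": float(len(kept)),
            },
        )


request = SyntheticDataGenerationRequest(
    prompts=["Explain gradient descent in one sentence."],
    reward_scoring=LengthHeuristicScorer(),  # hypothetical scorer from the earlier sketch
    filtering_function=FilteringFunction.top_k,
)
print(EchoGenerator().post_generate(request).synthetic_data)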
class LlamaStackEndpoints(
    Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
): ...


if __name__ == "__main__":