mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
add reward scoring and synthetic data gen
This commit is contained in:
parent
beb2870750
commit
eb12bfbef0
1 changed files with 60 additions and 1 deletions
|
@ -203,7 +203,66 @@ class AgenticSystem(Protocol):
|
|||
) -> None: ...
|
||||
|
||||
|
||||
class LlamaStackEndpoints(Inference, AgenticSystem): ...
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class RewardScoringRequest:
|
||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||
|
||||
prompt_responses: List[Tuple(prompt:str, List[response:str])]
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class RewardScoringResponse:
|
||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
prompt_responses: List[Tuple(prompt:str, List[Tuple(response:str, score:float)])]
|
||||
|
||||
class RewardScoring(Protocol):
|
||||
@webmethod(route="/reward_scoring/score")
|
||||
def post_score(
|
||||
self,
|
||||
request: RewardScoringRequest,
|
||||
) -> Union[RewardScoringResponse]: ...
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
"""The type of filtering function."""
|
||||
|
||||
none = "none"
|
||||
random = "random"
|
||||
top_k = "top_k"
|
||||
top_p = "top_p"
|
||||
top_k_top_p = "top_k_top_p"
|
||||
sigmoid = "sigmoid"
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class SyntheticDataGenerationRequest:
|
||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function"""
|
||||
|
||||
prompts: List[str]
|
||||
filtering_function: FilteringFunction = FilteringFunction.none
|
||||
reward_scoring: RewardScoring
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class SyntheticDataGenerationResponse:
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
synthetic_data: List[Tuple(prompt:str, response:str, score:float)]
|
||||
"""The actual synthetic data"""
|
||||
statistics: Dict[str, float]
|
||||
"""Statistics on how many prompts were generated and how many were filtered out"""
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic_data_generation/generate")
|
||||
def post_generate(
|
||||
self,
|
||||
request: SyntheticDataGenerationRequest,
|
||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
||||
|
||||
|
||||
|
||||
class LlamaStackEndpoints(Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration): ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue