mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
fixes to reward stuff
This commit is contained in:
parent
eb12bfbef0
commit
956f07b04c
3 changed files with 461 additions and 11 deletions
|
@@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple
|
||||
|
||||
import yaml
|
||||
from agentic_system_types import (
|
||||
|
@@ -17,8 +17,8 @@ from model_types import (
|
|||
Message,
|
||||
PretrainedModel,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ShieldConfig,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolResponse,
|
||||
|
@@ -203,19 +203,37 @@ class AgenticSystem(Protocol):
|
|||
) -> None: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptGeneration:
|
||||
prompt: Message
|
||||
message_history: List[Message]
|
||||
generation: Message
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredPromptGeneration:
|
||||
prompt_generation: PromptGeneration
|
||||
score: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class RewardScoringRequest:
|
||||
"""Request to score a reward function. A list of prompts and a list of responses per prompt."""
|
||||
|
||||
prompt_responses: List[Tuple(prompt:str, List[response:str])]
|
||||
prompt_generations: List[PromptGeneration]
|
||||
|
||||
# TODO(ragho): create a RewardModel enum type
|
||||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class RewardScoringResponse:
|
||||
"""Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
prompt_responses: List[Tuple(prompt:str, List[Tuple(response:str, score:float)])]
|
||||
scored_generations: List[ScoredPromptGeneration]
|
||||
|
||||
|
||||
class RewardScoring(Protocol):
|
||||
@webmethod(route="/reward_scoring/score")
|
||||
|
@@ -224,6 +242,7 @@ class RewardScoring(Protocol):
|
|||
request: RewardScoringRequest,
|
||||
) -> Union[RewardScoringResponse]: ...
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
"""The type of filtering function."""
|
||||
|
||||
|
@@ -234,6 +253,7 @@ class FilteringFunction(Enum):
|
|||
top_k_top_p = "top_k_top_p"
|
||||
sigmoid = "sigmoid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class SyntheticDataGenerationRequest:
|
||||
|
@@ -241,18 +261,23 @@ class SyntheticDataGenerationRequest:
|
|||
|
||||
prompts: List[str]
|
||||
filtering_function: FilteringFunction = FilteringFunction.none
|
||||
reward_scoring: RewardScoring
|
||||
|
||||
# TODO(ragho): fix this
|
||||
# reward_scoring: RewardScoring
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class SyntheticDataGenerationResponse:
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
|
||||
synthetic_data: List[Tuple(prompt:str, response:str, score:float)]
|
||||
synthetic_data: List[Tuple[str, str, float]]
|
||||
|
||||
"""The actual synthetic data"""
|
||||
statistics: Dict[str, float]
|
||||
"""Statistics on how many prompts were generated and how many were filtered out"""
|
||||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic_data_generation/generate")
|
||||
def post_generate(
|
||||
|
@@ -261,8 +286,9 @@ class SyntheticDataGeneration(Protocol):
|
|||
) -> Union[SyntheticDataGenerationResponse]: ...
|
||||
|
||||
|
||||
|
||||
class LlamaStackEndpoints(Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration): ...
|
||||
class LlamaStackEndpoints(
|
||||
Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
|
||||
): ...
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue