mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 11:50:41 +00:00
registry refactor
This commit is contained in:
parent
c50686b6fe
commit
95fd53d292
8 changed files with 39 additions and 71 deletions
|
|
@ -29,8 +29,7 @@ class GenerationOutput(BaseModel):
|
|||
@json_schema_type
|
||||
class PostprocessedGeneration(BaseModel):
|
||||
completion_message: str
|
||||
# structured transformed output from raw_completion_message to compute scorer metrics
|
||||
transformed_generation: Optional[Any] = None
|
||||
logprobs: Optional[List[TokenLogProbs]] = None
|
||||
|
||||
|
||||
# A sample (row) from dataset
|
||||
|
|
@ -70,8 +69,15 @@ TScorerInputSample = TypeVar("TScorerInputSample")
|
|||
|
||||
@json_schema_type
|
||||
class ScorerInputSample(DatasetSample):
|
||||
generation_output: PostprocessedGeneration
|
||||
expected_output: Union[str, List[str]]
|
||||
"""
|
||||
A dataset is required to have the following columns to be used for scoring:
|
||||
- generated_answer: str
|
||||
- expected_answer: Union[str, List[str]]
|
||||
"""
|
||||
|
||||
generated_answer: str
|
||||
expected_answer: Union[str, List[str]]
|
||||
generation_output: Optional[PostprocessedGeneration] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -217,18 +217,8 @@ class BaseScorer(ABC, Generic[TScorerInputSample]):
|
|||
|
||||
|
||||
class BaseTask(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
generator_processor: Optional[BaseGeneratorProcessor] = None,
|
||||
generator: Optional[BaseGenerator] = None,
|
||||
scorer: Optional[BaseScorer] = None,
|
||||
*args,
|
||||
**kwargs
|
||||
) -> None:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.generator_processor = generator_processor
|
||||
self.generator = generator
|
||||
self.scorer = scorer
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, *args, **kwargs) -> EvalResult:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue