From a4f5f1f8902e97529ed2341073684f8074819a55 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 22 Oct 2024 09:30:08 -0700 Subject: [PATCH] move jobs, fix errors --- llama_stack/apis/common/job_types.py | 12 ++++++++++++ llama_stack/apis/datasets/datasets.py | 7 +------ llama_stack/apis/eval/eval.py | 8 +++----- .../apis/scoring_functions/scoring_functions.py | 13 +++++++------ 4 files changed, 23 insertions(+), 17 deletions(-) create mode 100644 llama_stack/apis/common/job_types.py diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py new file mode 100644 index 000000000..ab203ebb8 --- /dev/null +++ b/llama_stack/apis/common/job_types.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel + + +@json_schema_type +class Job(BaseModel): + job_id: str diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 9bb7e6f7f..9160e1e13 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -15,17 +15,12 @@ from pydantic import BaseModel, Field from llama_stack.apis.common.type_system import ParamType -@json_schema_type -class DatasetSchema(BaseModel): - columns: Dict[str, ParamType] - - @json_schema_type class DatasetDef(BaseModel): identifier: str = Field( description="A unique name for the dataset", ) - dataset_schema: DatasetSchema = Field( + columns_schema: Dict[str, ParamType] = Field( description="The schema definition for this dataset", ) url: URL diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 5fcd267d9..a97af1fc0 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -11,6 +11,9 @@ from typing_extensions import Annotated from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.schema_utils import json_schema_type, webmethod from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.agents import AgentConfig +from llama_stack.apis.common.job_types import Job +from llama_stack.apis.scoring import * # noqa: F403 @json_schema_type @@ -32,11 +35,6 @@ EvalCandidate = Annotated[ ] -@json_schema_type -class Job(BaseModel): - job_id: str - - @json_schema_type class EvaluateResponse(BaseModel): generations: List[Dict[str, Any]] diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index a5aca34fe..1d71c51f3 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -22,6 +22,7 @@ from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType +@json_schema_type class Parameter(BaseModel): name: str type: ParamType @@ -32,6 +33,7 @@ class Parameter(BaseModel): # with standard metrics so they can be rolled up? +@json_schema_type class CommonDef(BaseModel): name: str description: Optional[str] = None @@ -39,8 +41,11 @@ class CommonDef(BaseModel): default_factory=dict, description="Any additional metadata for this definition", ) + # Hack: same with memory_banks for union defs + provider_id: str = "" +@json_schema_type class DeterministicFunctionDef(CommonDef): type: Literal["deterministic"] = "deterministic" parameters: List[Parameter] = Field( @@ -52,6 +57,7 @@ class DeterministicFunctionDef(CommonDef): # We can optionally add information here to support packaging of code, etc. +@json_schema_type class LLMJudgeFunctionDef(CommonDef): type: Literal["judge"] = "judge" model: str = Field( @@ -63,12 +69,7 @@ ScoringFunctionDef = Annotated[ Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type") ] - -@json_schema_type -class ScoringFunctionDefWithProvider(ScoringFunctionDef): - provider_id: str = Field( - description="The provider ID for this scoring function", - ) +ScoringFunctionDefWithProvider = ScoringFunctionDef @runtime_checkable