From 92e32f80ad4133ff025f7f0ba0561569c12d96ab Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 23 Oct 2024 13:01:49 -0700
Subject: [PATCH] test_scoring

---
 .../scoring_functions/scoring_functions.py    |  2 +-
 llama_stack/distribution/datatypes.py         |  5 ++
 llama_stack/distribution/distribution.py      |  4 ++
 llama_stack/distribution/resolver.py          | 10 ++++
 llama_stack/distribution/routers/__init__.py  | 18 ++++++-
 llama_stack/distribution/routers/routers.py   | 26 ++++++++++
 .../distribution/routers/routing_tables.py    | 20 +++++++-
 llama_stack/providers/datatypes.py            | 13 ++++-
 .../impls/meta_reference/scoring/__init__.py  | 18 +++++++
 .../impls/meta_reference/scoring/config.py    |  9 ++++
 .../impls/meta_reference/scoring/scoring.py   | 33 +++++++++++++
 llama_stack/providers/registry/scoring.py     | 24 +++++++++
 .../providers/tests/scoring/__init__.py       |  5 ++
 .../scoring/provider_config_example.yaml      |  9 ++++
 .../providers/tests/scoring/test_scoring.py   | 49 +++++++++++++++++++
 15 files changed, 240 insertions(+), 5 deletions(-)
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/__init__.py
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/config.py
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring.py
 create mode 100644 llama_stack/providers/registry/scoring.py
 create mode 100644 llama_stack/providers/tests/scoring/__init__.py
 create mode 100644 llama_stack/providers/tests/scoring/provider_config_example.yaml
 create mode 100644 llama_stack/providers/tests/scoring/test_scoring.py

diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 1d71c51f3..025c62c94 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -84,5 +84,5 @@ class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/register", method="POST")
     async def register_scoring_function(
-        self, function: ScoringFunctionDefWithProvider
+        self, function_def: ScoringFunctionDefWithProvider
     ) -> None: ...
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 10f78b78f..318809baf 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -15,10 +15,12 @@ from llama_stack.apis.models import *  # noqa: F403
 from llama_stack.apis.shields import *  # noqa: F403
 from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.scoring import Scoring
 
 LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
 LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@@ -32,6 +34,7 @@ RoutableObject = Union[
     ShieldDef,
     MemoryBankDef,
     DatasetDef,
+    ScoringFunctionDef,
 ]
 
 RoutableObjectWithProvider = Union[
@@ -39,6 +42,7 @@ RoutableObjectWithProvider = Union[
     ShieldDefWithProvider,
     MemoryBankDefWithProvider,
     DatasetDefWithProvider,
+    ScoringFunctionDefWithProvider,
 ]
 
 RoutedProtocol = Union[
@@ -46,6 +50,7 @@ RoutedProtocol = Union[
     Safety,
     Memory,
     DatasetIO,
+    Scoring,
 ]
 
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 53d544471..2149162a6 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
             routing_table_api=Api.datasets,
             router_api=Api.datasetio,
         ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.scoring_functions,
+            router_api=Api.scoring,
+        ),
     ]
 
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 2e6b64a53..53da099ce 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -20,6 +20,8 @@ from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.scoring import Scoring
+from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.distribution.distribution import (
@@ -42,6 +44,8 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.telemetry: Telemetry,
         Api.datasets: Datasets,
         Api.datasetio: DatasetIO,
+        Api.scoring_functions: ScoringFunctions,
+        Api.scoring: Scoring,
     }
 
@@ -126,6 +130,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
             )
         }
 
+        if info.router_api.value == "scoring":
+            print("SCORING API")
+
+            # p = all_api_providers[api][provider.provider_type]
+            # p.deps__ = [a.value for a in p.api_dependencies]
+
         providers_with_specs[info.router_api.value] = {
             "__builtin__": ProviderWithSpec(
                 provider_id="__autorouted__",
diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index 4970e93e1..2cb6004b3 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -11,6 +11,7 @@ from .routing_tables import (
     DatasetsRoutingTable,
     MemoryBanksRoutingTable,
     ModelsRoutingTable,
+    ScoringFunctionsRoutingTable,
     ShieldsRoutingTable,
 )
 
@@ -25,7 +26,9 @@ async def get_routing_table_impl(
         "models": ModelsRoutingTable,
         "shields": ShieldsRoutingTable,
         "datasets": DatasetsRoutingTable,
+        "scoring_functions": ScoringFunctionsRoutingTable,
     }
+
     if api.value not in api_to_tables:
         raise ValueError(f"API {api.value} not found in router map")
 
@@ -35,17 +38,30 @@
 async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
-    from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter
+    from .routers import (
+        DatasetIORouter,
+        InferenceRouter,
+        MemoryRouter,
+        SafetyRouter,
+        ScoringRouter,
+    )
 
     api_to_routers = {
         "memory": MemoryRouter,
         "inference": InferenceRouter,
         "safety": SafetyRouter,
         "datasetio": DatasetIORouter,
+        "scoring": ScoringRouter,
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")
 
+    # api_with_deps = {"scoring"}
+    # if api.value in api_with_deps:
+    #     impl = api_to_routers[api.value](routing_table, _deps)
+    # else:
+    #     impl = api_to_routers[api.value](routing_table)
+
     impl = api_to_routers[api.value](routing_table)
     await impl.initialize()
     return impl
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 31b8efa48..ab058ca8a 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -13,6 +13,7 @@ from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.apis.scoring import *  # noqa: F403
 
 
 class MemoryRouter(Memory):
@@ -192,3 +193,28 @@ class DatasetIORouter(DatasetIO):
             page_token=page_token,
             filter_condition=filter_condition,
         )
+
+
+class ScoringRouter(Scoring):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def score_batch(
+        self, dataset_id: str, scoring_functions: List[str]
+    ) -> ScoreBatchResponse:
+        # TODO
+        pass
+
+    async def score(
+        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+    ) -> ScoreResponse:
+        pass
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index db0946d81..f13a046c0 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -218,7 +218,25 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def get_dataset(
         self, dataset_identifier: str
     ) -> Optional[DatasetDefWithProvider]:
-        return self.get_object_by_identifier(identifier)
+        return self.get_object_by_identifier(dataset_identifier)
 
     async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
         await self.register_object(dataset_def)
+
+
+class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
+    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
+        objects = []
+        for objs in self.registry.values():
+            objects.extend(objs)
+        return objects
+
+    async def get_scoring_function(
+        self, name: str
+    ) -> Optional[ScoringFunctionDefWithProvider]:
+        return self.get_object_by_identifier(name)
+
+    async def register_scoring_function(
+        self, function_def: ScoringFunctionDefWithProvider
+    ) -> None:
+        await self.register_object(function_def)
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index d7e2d4d0c..903ff5438 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -11,10 +11,9 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.datasets import DatasetDef
-
 from llama_stack.apis.memory_banks import MemoryBankDef
-
 from llama_stack.apis.models import ModelDef
+from llama_stack.apis.scoring_functions import ScoringFunctionDef
 from llama_stack.apis.shields import ShieldDef
@@ -25,6 +24,7 @@ class Api(Enum):
     agents = "agents"
     memory = "memory"
     datasetio = "datasetio"
+    scoring = "scoring"
     telemetry = "telemetry"
 
@@ -32,6 +32,7 @@
     shields = "shields"
     memory_banks = "memory_banks"
     datasets = "datasets"
+    scoring_functions = "scoring_functions"
 
     # built-in API
     inspect = "inspect"
@@ -61,6 +62,14 @@ class DatasetsProtocolPrivate(Protocol):
     async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
 
 
+class ScoringFunctionsProtocolPrivate(Protocol):
+    async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ...
+
+    async def register_scoring_function(
+        self, function_def: ScoringFunctionDef
+    ) -> None: ...
+
+
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py
new file mode 100644
index 000000000..31c93faef
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import MetaReferenceScoringConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceScoringConfig,
+    _deps,
+):
+    from .scoring import MetaReferenceScoringImpl
+
+    impl = MetaReferenceScoringImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/meta_reference/scoring/config.py b/llama_stack/providers/impls/meta_reference/scoring/config.py
new file mode 100644
index 000000000..bd4dcb9f0
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/config.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+
+
+class MetaReferenceScoringConfig(BaseModel): ...
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
new file mode 100644
index 000000000..39ae40c13
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import List
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.scoring import *  # noqa: F403
+
+from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+
+from .config import MetaReferenceScoringConfig
+
+
+class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+    def __init__(self, config: MetaReferenceScoringConfig) -> None:
+        self.config = config
+        self.dataset_infos = {}
+
+    async def initialize(self) -> None: ...
+
+    async def shutdown(self) -> None: ...
+
+    async def score_batch(
+        self, dataset_id: str, scoring_functions: List[str]
+    ) -> ScoreBatchResponse:
+        print("score_batch")
+
+    async def score(
+        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+    ) -> ScoreResponse:
+        print("score")
diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py
new file mode 100644
index 000000000..69af25839
--- /dev/null
+++ b/llama_stack/providers/registry/scoring.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.scoring,
+            provider_type="meta-reference",
+            pip_packages=[],
+            module="llama_stack.providers.impls.meta_reference.scoring",
+            config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
+            api_dependencies=[
+                Api.datasetio,
+            ],
+        ),
+    ]
diff --git a/llama_stack/providers/tests/scoring/__init__.py b/llama_stack/providers/tests/scoring/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/tests/scoring/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml
new file mode 100644
index 000000000..9a8895149
--- /dev/null
+++ b/llama_stack/providers/tests/scoring/provider_config_example.yaml
@@ -0,0 +1,9 @@
+providers:
+  datasetio:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
+  scoring:
+    - provider_id: test-meta
+      provider_type: meta-reference
+      config: {}
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
new file mode 100644
index 000000000..dccfc78fc
--- /dev/null
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+from llama_stack.providers.tests.resolver import resolve_impls_for_test
+
+# How to run this test:
+#
+# 1. Ensure you have a conda environment with the right dependencies installed.
+#    This is a bit tricky since it depends on the provider you are testing. On
+#    top of that you need `pytest` and `pytest-asyncio` installed.
+#
+# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
+#
+# 3. Run:
Run: +# +# ```bash +# PROVIDER_ID= \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \ +# --tb=short --disable-warnings +# ``` + + +@pytest_asyncio.fixture(scope="session") +async def scoring_settings(): + impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio]) + return { + "scoring_impl": impls[Api.scoring], + "scoring_functions_impl": impls[Api.scoring_functions], + } + + +@pytest.mark.asyncio +async def test_scoring_functions_list(scoring_settings): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + scoring_functions_impl = scoring_settings["scoring_functions_impl"] + response = await scoring_functions_impl.list_scoring_functions() + assert isinstance(response, list) + assert len(response) == 0