mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
remove evals from top-level
This commit is contained in:
parent
a475d72155
commit
86486a94ce
10 changed files with 166 additions and 263 deletions
|
@ -11,12 +11,9 @@ from pydantic import BaseModel, Field
|
||||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||||
from llama_stack.apis.eval import Eval
|
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.models import Model, ModelInput
|
from llama_stack.apis.models import Model, ModelInput
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.scoring import Scoring
|
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
|
||||||
from llama_stack.apis.shields import Shield, ShieldInput
|
from llama_stack.apis.shields import Shield, ShieldInput
|
||||||
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||||
|
@ -36,7 +33,6 @@ RoutableObject = Union[
|
||||||
Shield,
|
Shield,
|
||||||
VectorDB,
|
VectorDB,
|
||||||
Dataset,
|
Dataset,
|
||||||
ScoringFn,
|
|
||||||
Benchmark,
|
Benchmark,
|
||||||
Tool,
|
Tool,
|
||||||
ToolGroup,
|
ToolGroup,
|
||||||
|
@ -49,7 +45,6 @@ RoutableObjectWithProvider = Annotated[
|
||||||
Shield,
|
Shield,
|
||||||
VectorDB,
|
VectorDB,
|
||||||
Dataset,
|
Dataset,
|
||||||
ScoringFn,
|
|
||||||
Benchmark,
|
Benchmark,
|
||||||
Tool,
|
Tool,
|
||||||
ToolGroup,
|
ToolGroup,
|
||||||
|
@ -62,8 +57,6 @@ RoutedProtocol = Union[
|
||||||
Safety,
|
Safety,
|
||||||
VectorIO,
|
VectorIO,
|
||||||
DatasetIO,
|
DatasetIO,
|
||||||
Scoring,
|
|
||||||
Eval,
|
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -195,7 +188,9 @@ a default SQLite store will be used.""",
|
||||||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||||
|
|
||||||
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
|
logging: Optional[LoggingConfig] = Field(
|
||||||
|
default=None, description="Configuration for Llama Stack Logging"
|
||||||
|
)
|
||||||
|
|
||||||
server: ServerConfig = Field(
|
server: ServerConfig = Field(
|
||||||
default_factory=ServerConfig,
|
default_factory=ServerConfig,
|
||||||
|
@ -206,7 +201,9 @@ a default SQLite store will be used.""",
|
||||||
class BuildConfig(BaseModel):
|
class BuildConfig(BaseModel):
|
||||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||||
|
|
||||||
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
|
distribution_spec: DistributionSpec = Field(
|
||||||
|
description="The distribution spec to build including API providers. "
|
||||||
|
)
|
||||||
image_type: str = Field(
|
image_type: str = Field(
|
||||||
default="conda",
|
default="conda",
|
||||||
description="Type of package to build (conda | container | venv)",
|
description="Type of package to build (conda | container | venv)",
|
||||||
|
|
|
@ -11,15 +11,12 @@ from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
from llama_stack.apis.eval import Eval
|
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.inspect import Inspect
|
from llama_stack.apis.inspect import Inspect
|
||||||
from llama_stack.apis.models import Models
|
from llama_stack.apis.models import Models
|
||||||
from llama_stack.apis.post_training import PostTraining
|
from llama_stack.apis.post_training import PostTraining
|
||||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.scoring import Scoring
|
|
||||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
|
@ -72,9 +69,6 @@ def api_protocol_map() -> Dict[Api, Any]:
|
||||||
Api.telemetry: Telemetry,
|
Api.telemetry: Telemetry,
|
||||||
Api.datasetio: DatasetIO,
|
Api.datasetio: DatasetIO,
|
||||||
Api.datasets: Datasets,
|
Api.datasets: Datasets,
|
||||||
Api.scoring: Scoring,
|
|
||||||
Api.scoring_functions: ScoringFunctions,
|
|
||||||
Api.eval: Eval,
|
|
||||||
Api.benchmarks: Benchmarks,
|
Api.benchmarks: Benchmarks,
|
||||||
Api.post_training: PostTraining,
|
Api.post_training: PostTraining,
|
||||||
Api.tool_groups: ToolGroups,
|
Api.tool_groups: ToolGroups,
|
||||||
|
@ -89,12 +83,6 @@ def additional_protocols_map() -> Dict[Api, Any]:
|
||||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||||
Api.scoring: (
|
|
||||||
ScoringFunctionsProtocolPrivate,
|
|
||||||
ScoringFunctions,
|
|
||||||
Api.scoring_functions,
|
|
||||||
),
|
|
||||||
Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,7 +105,9 @@ async def resolve_impls(
|
||||||
2. Sorting them in dependency order.
|
2. Sorting them in dependency order.
|
||||||
3. Instantiating them with required dependencies.
|
3. Instantiating them with required dependencies.
|
||||||
"""
|
"""
|
||||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
routing_table_apis = {
|
||||||
|
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||||
|
}
|
||||||
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||||
|
|
||||||
providers_with_specs = validate_and_prepare_providers(
|
providers_with_specs = validate_and_prepare_providers(
|
||||||
|
@ -125,7 +115,9 @@ async def resolve_impls(
|
||||||
)
|
)
|
||||||
|
|
||||||
apis_to_serve = run_config.apis or set(
|
apis_to_serve = run_config.apis or set(
|
||||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
list(providers_with_specs.keys())
|
||||||
|
+ [x.value for x in routing_table_apis]
|
||||||
|
+ [x.value for x in router_apis]
|
||||||
)
|
)
|
||||||
|
|
||||||
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
||||||
|
@ -135,7 +127,9 @@ async def resolve_impls(
|
||||||
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
||||||
|
|
||||||
|
|
||||||
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
def specs_for_autorouted_apis(
|
||||||
|
apis_to_serve: List[str] | Set[str],
|
||||||
|
) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||||
"""Generates specifications for automatically routed APIs."""
|
"""Generates specifications for automatically routed APIs."""
|
||||||
specs = {}
|
specs = {}
|
||||||
for info in builtin_automatically_routed_apis():
|
for info in builtin_automatically_routed_apis():
|
||||||
|
@ -177,7 +171,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
|
||||||
|
|
||||||
|
|
||||||
def validate_and_prepare_providers(
|
def validate_and_prepare_providers(
|
||||||
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
|
run_config: StackRunConfig,
|
||||||
|
provider_registry: ProviderRegistry,
|
||||||
|
routing_table_apis: Set[Api],
|
||||||
|
router_apis: Set[Api],
|
||||||
) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||||
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
|
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
|
||||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
|
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
|
||||||
|
@ -185,17 +182,23 @@ def validate_and_prepare_providers(
|
||||||
for api_str, providers in run_config.providers.items():
|
for api_str, providers in run_config.providers.items():
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
if api in routing_table_apis:
|
if api in routing_table_apis:
|
||||||
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
|
raise ValueError(
|
||||||
|
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
||||||
|
)
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||||
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
logger.warning(
|
||||||
|
f"Provider `{provider.provider_type}` for API `{api}` is disabled"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
validate_provider(provider, api, provider_registry)
|
validate_provider(provider, api, provider_registry)
|
||||||
p = provider_registry[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
p.deps__ = [a.value for a in p.api_dependencies] + [
|
||||||
|
a.value for a in p.optional_api_dependencies
|
||||||
|
]
|
||||||
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
||||||
specs[provider.provider_id] = spec
|
specs[provider.provider_id] = spec
|
||||||
|
|
||||||
|
@ -205,10 +208,14 @@ def validate_and_prepare_providers(
|
||||||
return providers_with_specs
|
return providers_with_specs
|
||||||
|
|
||||||
|
|
||||||
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
|
def validate_provider(
|
||||||
|
provider: Provider, api: Api, provider_registry: ProviderRegistry
|
||||||
|
):
|
||||||
"""Validates if the provider is allowed and handles deprecations."""
|
"""Validates if the provider is allowed and handles deprecations."""
|
||||||
if provider.provider_type not in provider_registry[api]:
|
if provider.provider_type not in provider_registry[api]:
|
||||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
raise ValueError(
|
||||||
|
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||||
|
)
|
||||||
|
|
||||||
p = provider_registry[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
if p.deprecation_error:
|
if p.deprecation_error:
|
||||||
|
@ -221,7 +228,8 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
|
||||||
|
|
||||||
|
|
||||||
def sort_providers_by_deps(
|
def sort_providers_by_deps(
|
||||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
|
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]],
|
||||||
|
run_config: StackRunConfig,
|
||||||
) -> List[Tuple[str, ProviderWithSpec]]:
|
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||||
"""Sorts providers based on their dependencies."""
|
"""Sorts providers based on their dependencies."""
|
||||||
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
|
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
|
||||||
|
@ -276,11 +284,15 @@ def sort_providers_by_deps(
|
||||||
|
|
||||||
|
|
||||||
async def instantiate_providers(
|
async def instantiate_providers(
|
||||||
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
|
sorted_providers: List[Tuple[str, ProviderWithSpec]],
|
||||||
|
router_apis: Set[Api],
|
||||||
|
dist_registry: DistributionRegistry,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Instantiates providers asynchronously while managing dependencies."""
|
"""Instantiates providers asynchronously while managing dependencies."""
|
||||||
impls: Dict[Api, Any] = {}
|
impls: Dict[Api, Any] = {}
|
||||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {
|
||||||
|
f"inner-{x.value}": {} for x in router_apis
|
||||||
|
}
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||||
for a in provider.spec.optional_api_dependencies:
|
for a in provider.spec.optional_api_dependencies:
|
||||||
|
@ -289,7 +301,9 @@ async def instantiate_providers(
|
||||||
|
|
||||||
inner_impls = {}
|
inner_impls = {}
|
||||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
inner_impls = inner_impls_by_provider_id[
|
||||||
|
f"inner-{provider.spec.router_api.value}"
|
||||||
|
]
|
||||||
|
|
||||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||||
|
|
||||||
|
@ -347,7 +361,9 @@ async def instantiate_provider(
|
||||||
|
|
||||||
provider_spec = provider.spec
|
provider_spec = provider.spec
|
||||||
if not hasattr(provider_spec, "module"):
|
if not hasattr(provider_spec, "module"):
|
||||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
raise AttributeError(
|
||||||
|
f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute"
|
||||||
|
)
|
||||||
|
|
||||||
module = importlib.import_module(provider_spec.module)
|
module = importlib.import_module(provider_spec.module)
|
||||||
args = []
|
args = []
|
||||||
|
@ -384,7 +400,10 @@ async def instantiate_provider(
|
||||||
# TODO: check compliance for special tool groups
|
# TODO: check compliance for special tool groups
|
||||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||||
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
|
if (
|
||||||
|
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||||
|
and provider_spec.api in additional_protocols
|
||||||
|
):
|
||||||
additional_api, _, _ = additional_protocols[provider_spec.api]
|
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||||
check_protocol_compliance(impl, additional_api)
|
check_protocol_compliance(impl, additional_api)
|
||||||
|
|
||||||
|
@ -412,12 +431,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
obj_params = set(obj_sig.parameters)
|
obj_params = set(obj_sig.parameters)
|
||||||
obj_params.discard("self")
|
obj_params.discard("self")
|
||||||
if not (proto_params <= obj_params):
|
if not (proto_params <= obj_params):
|
||||||
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
logger.error(
|
||||||
|
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||||
|
)
|
||||||
missing_methods.append((name, "signature_mismatch"))
|
missing_methods.append((name, "signature_mismatch"))
|
||||||
else:
|
else:
|
||||||
# Check if the method is actually implemented in the class
|
# Check if the method is actually implemented in the class
|
||||||
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
|
method_owner = next(
|
||||||
if method_owner is None or method_owner.__name__ == protocol.__name__:
|
(cls for cls in mro if name in cls.__dict__), None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
method_owner is None
|
||||||
|
or method_owner.__name__ == protocol.__name__
|
||||||
|
):
|
||||||
missing_methods.append((name, "not_actually_implemented"))
|
missing_methods.append((name, "not_actually_implemented"))
|
||||||
|
|
||||||
if missing_methods:
|
if missing_methods:
|
||||||
|
|
|
@ -32,7 +32,6 @@ async def get_routing_table_impl(
|
||||||
"models": ModelsRoutingTable,
|
"models": ModelsRoutingTable,
|
||||||
"shields": ShieldsRoutingTable,
|
"shields": ShieldsRoutingTable,
|
||||||
"datasets": DatasetsRoutingTable,
|
"datasets": DatasetsRoutingTable,
|
||||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
|
||||||
"benchmarks": BenchmarksRoutingTable,
|
"benchmarks": BenchmarksRoutingTable,
|
||||||
"tool_groups": ToolGroupsRoutingTable,
|
"tool_groups": ToolGroupsRoutingTable,
|
||||||
}
|
}
|
||||||
|
@ -45,7 +44,9 @@ async def get_routing_table_impl(
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
async def get_auto_router_impl(
|
||||||
|
api: Api, routing_table: RoutingTable, deps: Dict[str, Any]
|
||||||
|
) -> Any:
|
||||||
from .routers import (
|
from .routers import (
|
||||||
DatasetIORouter,
|
DatasetIORouter,
|
||||||
EvalRouter,
|
EvalRouter,
|
||||||
|
|
|
@ -8,9 +8,9 @@ import time
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
URL,
|
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
|
URL,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||||
|
@ -94,7 +94,9 @@ class VectorIORouter(VectorIO):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
provider_vector_db_id: Optional[str] = None,
|
provider_vector_db_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
logger.debug(
|
||||||
|
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
|
||||||
|
)
|
||||||
await self.routing_table.register_vector_db(
|
await self.routing_table.register_vector_db(
|
||||||
vector_db_id,
|
vector_db_id,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
|
@ -112,7 +114,9 @@ class VectorIORouter(VectorIO):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||||
)
|
)
|
||||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
|
||||||
|
vector_db_id, chunks, ttl_seconds
|
||||||
|
)
|
||||||
|
|
||||||
async def query_chunks(
|
async def query_chunks(
|
||||||
self,
|
self,
|
||||||
|
@ -121,7 +125,9 @@ class VectorIORouter(VectorIO):
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
|
||||||
|
vector_db_id, query, params
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InferenceRouter(Inference):
|
class InferenceRouter(Inference):
|
||||||
|
@ -158,7 +164,9 @@ class InferenceRouter(Inference):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||||
)
|
)
|
||||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
await self.routing_table.register_model(
|
||||||
|
model_id, provider_model_id, provider_id, metadata, model_type
|
||||||
|
)
|
||||||
|
|
||||||
def _construct_metrics(
|
def _construct_metrics(
|
||||||
self,
|
self,
|
||||||
|
@ -212,11 +220,16 @@ class InferenceRouter(Inference):
|
||||||
total_tokens: int,
|
total_tokens: int,
|
||||||
model: Model,
|
model: Model,
|
||||||
) -> List[MetricInResponse]:
|
) -> List[MetricInResponse]:
|
||||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
metrics = self._construct_metrics(
|
||||||
|
prompt_tokens, completion_tokens, total_tokens, model
|
||||||
|
)
|
||||||
if self.telemetry:
|
if self.telemetry:
|
||||||
for metric in metrics:
|
for metric in metrics:
|
||||||
await self.telemetry.log_event(metric)
|
await self.telemetry.log_event(metric)
|
||||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
return [
|
||||||
|
MetricInResponse(metric=metric.metric, value=metric.value)
|
||||||
|
for metric in metrics
|
||||||
|
]
|
||||||
|
|
||||||
async def _count_tokens(
|
async def _count_tokens(
|
||||||
self,
|
self,
|
||||||
|
@ -241,7 +254,9 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
) -> Union[
|
||||||
|
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||||
|
]:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||||
)
|
)
|
||||||
|
@ -251,12 +266,19 @@ class InferenceRouter(Inference):
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||||
|
)
|
||||||
if tool_config:
|
if tool_config:
|
||||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||||
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
|
if (
|
||||||
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
|
tool_prompt_format
|
||||||
|
and tool_prompt_format != tool_config.tool_prompt_format
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"tool_prompt_format and tool_config.tool_prompt_format must match"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
params = {}
|
params = {}
|
||||||
if tool_choice:
|
if tool_choice:
|
||||||
|
@ -274,9 +296,14 @@ class InferenceRouter(Inference):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# verify tool_choice is one of the tools
|
# verify tool_choice is one of the tools
|
||||||
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
|
tool_names = [
|
||||||
|
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
|
||||||
|
for t in tools
|
||||||
|
]
|
||||||
if tool_config.tool_choice not in tool_names:
|
if tool_config.tool_choice not in tool_names:
|
||||||
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
|
raise ValueError(
|
||||||
|
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
|
||||||
|
)
|
||||||
|
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -291,17 +318,25 @@ class InferenceRouter(Inference):
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
prompt_tokens = await self._count_tokens(
|
||||||
|
messages, tool_config.tool_prompt_format
|
||||||
|
)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
|
||||||
async def stream_generator():
|
async def stream_generator():
|
||||||
completion_text = ""
|
completion_text = ""
|
||||||
async for chunk in await provider.chat_completion(**params):
|
async for chunk in await provider.chat_completion(**params):
|
||||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
if (
|
||||||
|
chunk.event.event_type
|
||||||
|
== ChatCompletionResponseEventType.progress
|
||||||
|
):
|
||||||
if chunk.event.delta.type == "text":
|
if chunk.event.delta.type == "text":
|
||||||
completion_text += chunk.event.delta.text
|
completion_text += chunk.event.delta.text
|
||||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
if (
|
||||||
|
chunk.event.event_type
|
||||||
|
== ChatCompletionResponseEventType.complete
|
||||||
|
):
|
||||||
completion_tokens = await self._count_tokens(
|
completion_tokens = await self._count_tokens(
|
||||||
[
|
[
|
||||||
CompletionMessage(
|
CompletionMessage(
|
||||||
|
@ -318,7 +353,11 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
chunk.metrics = (
|
||||||
|
metrics
|
||||||
|
if chunk.metrics is None
|
||||||
|
else chunk.metrics + metrics
|
||||||
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return stream_generator()
|
return stream_generator()
|
||||||
|
@ -335,7 +374,9 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
response.metrics = (
|
||||||
|
metrics if response.metrics is None else response.metrics + metrics
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
|
@ -356,7 +397,9 @@ class InferenceRouter(Inference):
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||||
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -376,7 +419,11 @@ class InferenceRouter(Inference):
|
||||||
async for chunk in await provider.completion(**params):
|
async for chunk in await provider.completion(**params):
|
||||||
if hasattr(chunk, "delta"):
|
if hasattr(chunk, "delta"):
|
||||||
completion_text += chunk.delta
|
completion_text += chunk.delta
|
||||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
if (
|
||||||
|
hasattr(chunk, "stop_reason")
|
||||||
|
and chunk.stop_reason
|
||||||
|
and self.telemetry
|
||||||
|
):
|
||||||
completion_tokens = await self._count_tokens(completion_text)
|
completion_tokens = await self._count_tokens(completion_text)
|
||||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||||
metrics = await self._compute_and_log_token_usage(
|
metrics = await self._compute_and_log_token_usage(
|
||||||
|
@ -385,7 +432,11 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
chunk.metrics = (
|
||||||
|
metrics
|
||||||
|
if chunk.metrics is None
|
||||||
|
else chunk.metrics + metrics
|
||||||
|
)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return stream_generator()
|
return stream_generator()
|
||||||
|
@ -399,7 +450,9 @@ class InferenceRouter(Inference):
|
||||||
total_tokens,
|
total_tokens,
|
||||||
model,
|
model,
|
||||||
)
|
)
|
||||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
response.metrics = (
|
||||||
|
metrics if response.metrics is None else response.metrics + metrics
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
|
@ -415,7 +468,9 @@ class InferenceRouter(Inference):
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
if model.model_type == ModelType.llm:
|
if model.model_type == ModelType.llm:
|
||||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||||
|
)
|
||||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
|
@ -449,7 +504,9 @@ class SafetyRouter(Safety):
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> Shield:
|
) -> Shield:
|
||||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
return await self.routing_table.register_shield(
|
||||||
|
shield_id, provider_shield_id, provider_id, params
|
||||||
|
)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
@ -521,135 +578,6 @@ class DatasetIORouter(DatasetIO):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ScoringRouter(Scoring):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing ScoringRouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("ScoringRouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("ScoringRouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def score_batch(
|
|
||||||
self,
|
|
||||||
dataset_id: str,
|
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
|
||||||
save_results_dataset: bool = False,
|
|
||||||
) -> ScoreBatchResponse:
|
|
||||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
|
||||||
res = {}
|
|
||||||
for fn_identifier in scoring_functions.keys():
|
|
||||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
|
||||||
)
|
|
||||||
res.update(score_response.results)
|
|
||||||
|
|
||||||
if save_results_dataset:
|
|
||||||
raise NotImplementedError("Save results dataset not implemented yet")
|
|
||||||
|
|
||||||
return ScoreBatchResponse(
|
|
||||||
results=res,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def score(
|
|
||||||
self,
|
|
||||||
input_rows: List[Dict[str, Any]],
|
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
|
||||||
) -> ScoreResponse:
|
|
||||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
|
||||||
res = {}
|
|
||||||
# look up and map each scoring function to its provider impl
|
|
||||||
for fn_identifier in scoring_functions.keys():
|
|
||||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
|
|
||||||
input_rows=input_rows,
|
|
||||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
|
||||||
)
|
|
||||||
res.update(score_response.results)
|
|
||||||
|
|
||||||
return ScoreResponse(results=res)
|
|
||||||
|
|
||||||
|
|
||||||
class EvalRouter(Eval):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
routing_table: RoutingTable,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("Initializing EvalRouter")
|
|
||||||
self.routing_table = routing_table
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
|
||||||
logger.debug("EvalRouter.initialize")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
|
||||||
logger.debug("EvalRouter.shutdown")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def run_eval(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
benchmark_config: BenchmarkConfig,
|
|
||||||
) -> Job:
|
|
||||||
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
|
||||||
benchmark_id=benchmark_id,
|
|
||||||
benchmark_config=benchmark_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def evaluate_rows(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
input_rows: List[Dict[str, Any]],
|
|
||||||
scoring_functions: List[str],
|
|
||||||
benchmark_config: BenchmarkConfig,
|
|
||||||
) -> EvaluateResponse:
|
|
||||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
|
||||||
benchmark_id=benchmark_id,
|
|
||||||
input_rows=input_rows,
|
|
||||||
scoring_functions=scoring_functions,
|
|
||||||
benchmark_config=benchmark_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def job_status(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
job_id: str,
|
|
||||||
) -> Optional[JobStatus]:
|
|
||||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
|
||||||
|
|
||||||
async def job_cancel(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
job_id: str,
|
|
||||||
) -> None:
|
|
||||||
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
|
||||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
|
||||||
benchmark_id,
|
|
||||||
job_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def job_result(
|
|
||||||
self,
|
|
||||||
benchmark_id: str,
|
|
||||||
job_id: str,
|
|
||||||
) -> EvaluateResponse:
|
|
||||||
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
|
||||||
benchmark_id,
|
|
||||||
job_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolRuntimeRouter(ToolRuntime):
|
class ToolRuntimeRouter(ToolRuntime):
|
||||||
class RagToolImpl(RAGToolRuntime):
|
class RagToolImpl(RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -679,9 +607,9 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||||
)
|
)
|
||||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
return await self.routing_table.get_provider_impl(
|
||||||
documents, vector_db_id, chunk_size_in_tokens
|
"insert_into_memory"
|
||||||
)
|
).insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -714,4 +642,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
|
||||||
|
tool_group_id, mcp_endpoint
|
||||||
|
)
|
||||||
|
|
|
@ -418,50 +418,6 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
await self.unregister_object(dataset)
|
await self.unregister_object(dataset)
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|
||||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
|
||||||
return ListScoringFunctionsResponse(
|
|
||||||
data=await self.get_all_with_type(ResourceType.scoring_function.value)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
|
|
||||||
scoring_fn = await self.get_object_by_identifier(
|
|
||||||
"scoring_function", scoring_fn_id
|
|
||||||
)
|
|
||||||
if scoring_fn is None:
|
|
||||||
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
|
|
||||||
return scoring_fn
|
|
||||||
|
|
||||||
async def register_scoring_function(
|
|
||||||
self,
|
|
||||||
scoring_fn_id: str,
|
|
||||||
description: str,
|
|
||||||
return_type: ParamType,
|
|
||||||
provider_scoring_fn_id: Optional[str] = None,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
params: Optional[ScoringFnParams] = None,
|
|
||||||
) -> None:
|
|
||||||
if provider_scoring_fn_id is None:
|
|
||||||
provider_scoring_fn_id = scoring_fn_id
|
|
||||||
if provider_id is None:
|
|
||||||
if len(self.impls_by_provider_id) == 1:
|
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
|
||||||
)
|
|
||||||
scoring_fn = ScoringFn(
|
|
||||||
identifier=scoring_fn_id,
|
|
||||||
description=description,
|
|
||||||
return_type=return_type,
|
|
||||||
provider_resource_id=provider_scoring_fn_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
scoring_fn.provider_id = provider_id
|
|
||||||
await self.register_object(scoring_fn)
|
|
||||||
|
|
||||||
|
|
||||||
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||||
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
|
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
|
||||||
|
|
|
@ -78,12 +78,6 @@ RESOURCES = [
|
||||||
("shields", Api.shields, "register_shield", "list_shields"),
|
("shields", Api.shields, "register_shield", "list_shields"),
|
||||||
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
|
("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"),
|
||||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||||
(
|
|
||||||
"scoring_fns",
|
|
||||||
Api.scoring_functions,
|
|
||||||
"register_scoring_function",
|
|
||||||
"list_scoring_functions",
|
|
||||||
),
|
|
||||||
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
|
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
|
||||||
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
|
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
|
||||||
]
|
]
|
||||||
|
|
|
@ -22,11 +22,16 @@ class LlamaStackApi:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
|
def run_scoring(
|
||||||
|
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
|
||||||
|
):
|
||||||
"""Run scoring on a single row"""
|
"""Run scoring on a single row"""
|
||||||
if not scoring_params:
|
if not scoring_params:
|
||||||
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
|
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
|
||||||
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
|
|
||||||
|
# TODO(xiyan): fix this
|
||||||
|
# return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
|
||||||
|
raise NotImplementedError("Scoring is not implemented")
|
||||||
|
|
||||||
|
|
||||||
llama_stack_api = LlamaStackApi()
|
llama_stack_api = LlamaStackApi()
|
||||||
|
|
|
@ -4,14 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from streamlit_option_menu import option_menu
|
|
||||||
|
|
||||||
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
from llama_stack.distribution.ui.page.distribution.datasets import datasets
|
||||||
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks
|
||||||
from llama_stack.distribution.ui.page.distribution.models import models
|
from llama_stack.distribution.ui.page.distribution.models import models
|
||||||
from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions
|
|
||||||
from llama_stack.distribution.ui.page.distribution.shields import shields
|
from llama_stack.distribution.ui.page.distribution.shields import shields
|
||||||
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs
|
||||||
|
from streamlit_option_menu import option_menu
|
||||||
|
|
||||||
|
|
||||||
def resources_page():
|
def resources_page():
|
||||||
|
@ -43,8 +41,9 @@ def resources_page():
|
||||||
datasets()
|
datasets()
|
||||||
elif selected_resource == "Models":
|
elif selected_resource == "Models":
|
||||||
models()
|
models()
|
||||||
elif selected_resource == "Scoring Functions":
|
# TODO(xiyan): fix this
|
||||||
scoring_functions()
|
# elif selected_resource == "Scoring Functions":
|
||||||
|
# scoring_functions()
|
||||||
elif selected_resource == "Shields":
|
elif selected_resource == "Shields":
|
||||||
shields()
|
shields()
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@ from llama_stack.apis.benchmarks import Benchmark
|
||||||
from llama_stack.apis.datasets import Dataset
|
from llama_stack.apis.datasets import Dataset
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.tools import Tool
|
from llama_stack.apis.tools import Tool
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
|
@ -42,12 +41,6 @@ class DatasetsProtocolPrivate(Protocol):
|
||||||
async def unregister_dataset(self, dataset_id: str) -> None: ...
|
async def unregister_dataset(self, dataset_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
|
||||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
|
||||||
|
|
||||||
async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
class BenchmarksProtocolPrivate(Protocol):
|
class BenchmarksProtocolPrivate(Protocol):
|
||||||
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
|
||||||
|
|
||||||
|
|
|
@ -20,5 +20,7 @@ context_entity_recall_fn_def = ScoringFn(
|
||||||
provider_id="braintrust",
|
provider_id="braintrust",
|
||||||
provider_resource_id="context-entity-recall",
|
provider_resource_id="context-entity-recall",
|
||||||
return_type=NumberType(),
|
return_type=NumberType(),
|
||||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]),
|
params=BasicScoringFnParams(
|
||||||
|
aggregation_functions=[AggregationFunctionType.average]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue