forked from phoenix-oss/llama-stack-mirror
fix precommit
This commit is contained in:
parent 81bc051411
commit 97e7717c9b
3 changed files with 71 additions and 3 deletions
@@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.evaluation import Evaluation
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
@@ -35,6 +36,7 @@ from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     Api,
+    BenchmarksProtocolPrivate,
     DatasetsProtocolPrivate,
     InlineProviderSpec,
     ModelsProtocolPrivate,
@@ -71,6 +73,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.post_training: PostTraining,
         Api.tool_groups: ToolGroups,
         Api.tool_runtime: ToolRuntime,
+        Api.evaluation: Evaluation,
     }
 
 
@@ -81,6 +84,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
+        Api.evaluation: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
     }
 
 
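The hunks above wire the new evaluation API into the distribution resolver: Api.evaluation now maps to the Evaluation protocol in api_protocol_map(), and additional_protocols_map() pairs it with BenchmarksProtocolPrivate and the Benchmarks API so evaluation providers can also register benchmarks. The sketch below shows the protocol-map pattern in isolation; the Api enum, Evaluation protocol, and MyEvalProvider class are simplified stand-ins invented for illustration, not llama_stack's actual definitions.

# Simplified stand-ins for illustration only; not llama_stack's real Api enum or protocols.
from enum import Enum
from typing import Any, Dict, List, Protocol, runtime_checkable


class Api(Enum):
    inference = "inference"
    evaluation = "evaluation"


@runtime_checkable
class Evaluation(Protocol):
    async def register_benchmark(self, dataset_id: str, grader_ids: List[str]) -> Any: ...


def api_protocol_map() -> Dict[Api, Any]:
    # After this commit, Api.evaluation resolves to the Evaluation protocol,
    # so evaluation providers are validated the same way as every other API.
    return {Api.evaluation: Evaluation}


class MyEvalProvider:
    async def register_benchmark(self, dataset_id: str, grader_ids: List[str]) -> Any:
        return {"dataset_id": dataset_id, "grader_ids": grader_ids}


protocol = api_protocol_map()[Api.evaluation]
print(isinstance(MyEvalProvider(), protocol))  # True: the provider satisfies the protocol
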
@@ -46,6 +46,7 @@ async def get_routing_table_impl(
 async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
     from .routers import (
         DatasetIORouter,
+        EvaluationRouter,
         InferenceRouter,
         SafetyRouter,
         ToolRuntimeRouter,
@@ -58,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
         "safety": SafetyRouter,
         "datasetio": DatasetIORouter,
         "tool_runtime": ToolRuntimeRouter,
+        "evaluation": EvaluationRouter,
     }
     api_to_deps = {
         "inference": {"telemetry": Api.telemetry},
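This change registers the new router with the auto-router factory: EvaluationRouter is imported inside get_auto_router_impl and "evaluation" is added to its api-to-router table. A minimal sketch of that dispatch pattern follows, assuming bare stand-in classes for RoutingTable and the routers; the real factory also resolves per-API dependencies such as telemetry for inference.

# Illustrative stand-ins; the real get_auto_router_impl lives in llama_stack and also wires deps.
from typing import Any, Dict


class RoutingTable:
    pass


class DatasetIORouter:
    def __init__(self, routing_table: RoutingTable) -> None:
        self.routing_table = routing_table


class EvaluationRouter:
    def __init__(self, routing_table: RoutingTable) -> None:
        self.routing_table = routing_table


def get_auto_router_impl(api: str, routing_table: RoutingTable) -> Any:
    api_to_routers: Dict[str, Any] = {
        "datasetio": DatasetIORouter,
        "evaluation": EvaluationRouter,  # the entry this commit adds
    }
    # Look up the router class for the requested API and construct it around the routing table.
    return api_to_routers[api](routing_table)


router = get_auto_router_impl("evaluation", RoutingTable())
print(type(router).__name__)  # EvaluationRouter
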
@@ -7,13 +7,21 @@
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
+from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
     InterleavedContentItem,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
-from llama_stack.apis.datasets import DatasetPurpose, DataSource
+from llama_stack.apis.datasets import Dataset, DatasetPurpose, DataSource
+from llama_stack.apis.evaluation import (
+    Evaluation,
+    EvaluationCandidate,
+    EvaluationJob,
+    EvaluationResponse,
+    EvaluationTask,
+)
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -474,11 +482,11 @@ class DatasetIORouter(DatasetIO):
         source: DataSource,
         metadata: Optional[Dict[str, Any]] = None,
         dataset_id: Optional[str] = None,
-    ) -> None:
+    ) -> Dataset:
         logger.debug(
             f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
         )
-        await self.routing_table.register_dataset(
+        return await self.routing_table.register_dataset(
             purpose=purpose,
             source=source,
             metadata=metadata,
@@ -573,3 +581,57 @@ class ToolRuntimeRouter(ToolRuntime):
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+
+
+class EvaluationRouter(Evaluation):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        logger.debug("Initializing EvaluationRouter")
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        logger.debug("EvaluationRouter.initialize")
+        pass
+
+    async def shutdown(self) -> None:
+        logger.debug("EvaluationRouter.shutdown")
+        pass
+
+    async def register_benchmark(
+        self,
+        dataset_id: str,
+        grader_ids: List[str],
+        benchmark_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Benchmark:
+        logger.debug(
+            f"EvaluationRouter.register_benchmark: {benchmark_id=} {dataset_id=} {grader_ids=} {metadata=}",
+        )
+        return await self.routing_table.register_benchmark(
+            benchmark_id=benchmark_id,
+            dataset_id=dataset_id,
+            grader_ids=grader_ids,
+            metadata=metadata,
+        )
+
+    async def run(
+        self,
+        task: EvaluationTask,
+        candidate: EvaluationCandidate,
+    ) -> EvaluationJob:
+        raise NotImplementedError("Run is not implemented yet")
+
+    async def run_sync(
+        self,
+        task: EvaluationTask,
+        candidate: EvaluationCandidate,
+    ) -> EvaluationResponse:
+        raise NotImplementedError("Run sync is not implemented yet")
+
+    async def grade(self, task: EvaluationTask) -> EvaluationJob:
+        raise NotImplementedError("Grade is not implemented yet")
+
+    async def grade_sync(self, task: EvaluationTask) -> EvaluationResponse:
+        raise NotImplementedError("Grade sync is not implemented yet")
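The final hunks change DatasetIORouter.register_dataset to return the Dataset produced by the routing table instead of None, and add an EvaluationRouter whose register_benchmark delegates straight to the routing table while run, run_sync, grade, and grade_sync remain NotImplementedError stubs. Below is a hypothetical usage sketch of that delegation; FakeRoutingTable and the asyncio driver are invented here for illustration and are not part of llama_stack.

# FakeRoutingTable is a made-up stand-in; a real routing table selects a provider
# and persists the benchmark before returning it.
import asyncio
from typing import Any, Dict, List, Optional


class FakeRoutingTable:
    async def register_benchmark(self, benchmark_id, dataset_id, grader_ids, metadata) -> Dict[str, Any]:
        return {"identifier": benchmark_id or "benchmark-0", "dataset_id": dataset_id, "grader_ids": grader_ids}


class EvaluationRouter:
    def __init__(self, routing_table) -> None:
        self.routing_table = routing_table

    async def register_benchmark(
        self,
        dataset_id: str,
        grader_ids: List[str],
        benchmark_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # Same shape as the method added in this commit: log-and-forward, no logic of its own.
        return await self.routing_table.register_benchmark(
            benchmark_id=benchmark_id,
            dataset_id=dataset_id,
            grader_ids=grader_ids,
            metadata=metadata,
        )


async def main() -> None:
    router = EvaluationRouter(FakeRoutingTable())
    benchmark = await router.register_benchmark("my-dataset", ["accuracy-grader"])
    print(benchmark)


asyncio.run(main())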