forked from phoenix-oss/llama-stack-mirror
		
	* wip * scoring fn api * eval api * eval task * evaluate api update * pre commit * unwrap context -> config * config field doc * typo * naming fix * separate benchmark / app eval * api name * rename * wip tests * wip * datasetio test * delete unused * fixture * scoring resolve * fix scoring register * scoring test pass * score batch * scoring fix * fix eval * test eval works * remove type ignore * api refactor * add default task_eval_id for routing * add eval_id for jobs * remove type ignore * only keep 1 run_eval * fix optional * register task required * register task required * delete old tests * delete old tests * fixture return impl
		
			
				
	
	
		
			69 lines
		
	
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			69 lines
		
	
	
	
		
			1.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| from typing import Any
 | |
| 
 | |
| from llama_stack.distribution.datatypes import *  # noqa: F403
 | |
| 
 | |
| from llama_stack.distribution.store import DistributionRegistry
 | |
| 
 | |
| from .routing_tables import (
 | |
|     DatasetsRoutingTable,
 | |
|     EvalTasksRoutingTable,
 | |
|     MemoryBanksRoutingTable,
 | |
|     ModelsRoutingTable,
 | |
|     ScoringFunctionsRoutingTable,
 | |
|     ShieldsRoutingTable,
 | |
| )
 | |
| 
 | |
| 
 | |
async def get_routing_table_impl(
    api: Api,
    impls_by_provider_id: Dict[str, RoutedProtocol],
    _deps,
    dist_registry: DistributionRegistry,
) -> Any:
    """Create and initialize the routing table mapped to ``api``.

    Args:
        api: API whose ``value`` is used as the lookup key for the routing
            table class.
        impls_by_provider_id: Mapping handed unchanged to the routing table
            constructor as its first argument.
        _deps: Not used by this function.
        dist_registry: Registry handed unchanged to the routing table
            constructor as its second argument.

    Returns:
        The routing table instance, after its ``initialize()`` coroutine has
        been awaited.

    Raises:
        ValueError: If ``api.value`` has no routing table class mapped to it.
    """
    table_classes = {
        "memory_banks": MemoryBanksRoutingTable,
        "models": ModelsRoutingTable,
        "shields": ShieldsRoutingTable,
        "datasets": DatasetsRoutingTable,
        "scoring_functions": ScoringFunctionsRoutingTable,
        "eval_tasks": EvalTasksRoutingTable,
    }

    # Single lookup; every mapped value is a class, so None means "unmapped".
    table_cls = table_classes.get(api.value)
    if table_cls is None:
        raise ValueError(f"API {api.value} not found in router map")

    routing_table = table_cls(impls_by_provider_id, dist_registry)
    await routing_table.initialize()
    return routing_table
 | |
| 
 | |
| 
 | |
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
    """Create and initialize the router mapped to ``api``.

    Args:
        api: API whose ``value`` is used as the lookup key for the router
            class.
        routing_table: Routing table handed unchanged to the router
            constructor as its only argument.
        _deps: Not used by this function.

    Returns:
        The router instance, after its ``initialize()`` coroutine has been
        awaited.

    Raises:
        ValueError: If ``api.value`` has no router class mapped to it.
    """
    # Imported at call time rather than at module import time.
    # NOTE(review): presumably to avoid a circular import with .routers - confirm.
    from .routers import (
        DatasetIORouter,
        EvalRouter,
        InferenceRouter,
        MemoryRouter,
        SafetyRouter,
        ScoringRouter,
    )

    router_classes = {
        "memory": MemoryRouter,
        "inference": InferenceRouter,
        "safety": SafetyRouter,
        "datasetio": DatasetIORouter,
        "scoring": ScoringRouter,
        "eval": EvalRouter,
    }

    # Single lookup; every mapped value is a class, so None means "unmapped".
    router_cls = router_classes.get(api.value)
    if router_cls is None:
        raise ValueError(f"API {api.value} not found in router map")

    router = router_cls(routing_table)
    await router.initialize()
    return router
 |