mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-17 02:18:13 +00:00
feat: associated models API with post_training
there are likely scenarios where admins of a stack only want to allow clients to fine-tune certain models, register certain models to be fine-tuned. etc introduce the post_training router and post_training_models as the associated type. A different model type needs to be used for inference vs post_training due to the structure of the router currently. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
63a9f08c9e
commit
71caa271ad
11 changed files with 393 additions and 23 deletions
|
@ -21,7 +21,8 @@ async def get_routing_table_impl(
|
|||
) -> Any:
|
||||
from ..routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from ..routing_tables.datasets import DatasetsRoutingTable
|
||||
from ..routing_tables.models import ModelsRoutingTable
|
||||
from ..routing_tables.models import InferenceModelsRoutingTable
|
||||
from ..routing_tables.post_training_models import PostTrainingModelsRoutingTable
|
||||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from ..routing_tables.shields import ShieldsRoutingTable
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
|
@ -29,7 +30,8 @@ async def get_routing_table_impl(
|
|||
|
||||
api_to_tables = {
|
||||
"vector_dbs": VectorDBsRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"models": InferenceModelsRoutingTable,
|
||||
"post_training_models": PostTrainingModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
"datasets": DatasetsRoutingTable,
|
||||
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||
|
@ -40,7 +42,12 @@ async def get_routing_table_impl(
|
|||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||
# For post-training API, we want to use the post-training models routing table
|
||||
if api == Api.post_training:
|
||||
impl = PostTrainingModelsRoutingTable(impls_by_provider_id, dist_registry)
|
||||
else:
|
||||
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
@ -51,6 +58,7 @@ async def get_auto_router_impl(
|
|||
from .datasets import DatasetIORouter
|
||||
from .eval_scoring import EvalRouter, ScoringRouter
|
||||
from .inference import InferenceRouter
|
||||
from .post_training import PostTrainingRouter
|
||||
from .safety import SafetyRouter
|
||||
from .tool_runtime import ToolRuntimeRouter
|
||||
from .vector_io import VectorIORouter
|
||||
|
@ -63,6 +71,7 @@ async def get_auto_router_impl(
|
|||
"scoring": ScoringRouter,
|
||||
"eval": EvalRouter,
|
||||
"tool_runtime": ToolRuntimeRouter,
|
||||
"post_training": PostTrainingRouter,
|
||||
}
|
||||
api_to_deps = {
|
||||
"inference": {"telemetry": Api.telemetry},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue