feat: associated models API with post_training

there are likely scenarios where admins of a stack only want to allow clients to fine-tune certain models, register certain models to be fine-tuned. etc
introduce the post_training router and post_training_models as the associated type. A different model type needs to be used for inference vs post_training due to the structure of the router currently.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-05-30 12:05:33 -04:00
parent 63a9f08c9e
commit 71caa271ad
11 changed files with 393 additions and 23 deletions

View file

@ -21,7 +21,8 @@ async def get_routing_table_impl(
) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable
from ..routing_tables.models import ModelsRoutingTable
from ..routing_tables.models import InferenceModelsRoutingTable
from ..routing_tables.post_training_models import PostTrainingModelsRoutingTable
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
@ -29,7 +30,8 @@ async def get_routing_table_impl(
api_to_tables = {
"vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable,
"models": InferenceModelsRoutingTable,
"post_training_models": PostTrainingModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
@ -40,7 +42,12 @@ async def get_routing_table_impl(
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
# For post-training API, we want to use the post-training models routing table
if api == Api.post_training:
impl = PostTrainingModelsRoutingTable(impls_by_provider_id, dist_registry)
else:
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
await impl.initialize()
return impl
@ -51,6 +58,7 @@ async def get_auto_router_impl(
from .datasets import DatasetIORouter
from .eval_scoring import EvalRouter, ScoringRouter
from .inference import InferenceRouter
from .post_training import PostTrainingRouter
from .safety import SafetyRouter
from .tool_runtime import ToolRuntimeRouter
from .vector_io import VectorIORouter
@ -63,6 +71,7 @@ async def get_auto_router_impl(
"scoring": ScoringRouter,
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
"post_training": PostTrainingRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},