feat: associated models API with post_training

There are likely scenarios where admins of a stack only want to allow clients to fine-tune certain models, register certain models to be fine-tuned, etc.
Introduce the post_training router, with post_training_models as its associated type. A different model type needs to be used for inference vs. post_training due to the current structure of the router.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-05-30 12:05:33 -04:00
parent 63a9f08c9e
commit 71caa271ad
11 changed files with 393 additions and 23 deletions

View file

@ -33,7 +33,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
if api == Api.inference or api == Api.post_training:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
@ -55,7 +55,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
elif api == Api.inference or api == Api.post_training:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
@ -89,11 +89,18 @@ class CommonRoutingTableImpl(RoutingTable):
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Import routing table classes here to avoid circular imports
from .models import InferenceModelsRoutingTable
from .post_training_models import PostTrainingModelsRoutingTable
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
if api == Api.inference or api == Api.post_training:
# For models, we need to handle both inference and post-training providers
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
# Set the model store for both types of providers
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
@ -116,15 +123,16 @@ class CommonRoutingTableImpl(RoutingTable):
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
from .models import ModelsRoutingTable
from .models import InferenceModelsRoutingTable
from .post_training_models import PostTrainingModelsRoutingTable
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
return ("Models", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
@ -155,7 +163,25 @@ class CommonRoutingTableImpl(RoutingTable):
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
provider = self.impls_by_provider_id[obj.provider_id]
# Check if the provider supports the requested API
if not hasattr(provider, "__provider_spec__"):
return provider
api = provider.__provider_spec__.api
# Only check API compatibility for model routing tables
if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
if api not in [Api.inference, Api.post_training]:
raise ValueError(f"Provider {obj.provider_id} does not support the requested API")
# If we have both inference and post-training providers, prefer inference for model registration
if api == Api.post_training and Api.inference in [
p.__provider_spec__.api for p in self.impls_by_provider_id.values()
]:
# Try to find an inference provider first
for _, p in self.impls_by_provider_id.items():
if hasattr(p, "__provider_spec__") and p.__provider_spec__.api == Api.inference:
return p
return provider
raise ValueError(f"Provider not found for `{routing_key}`")
@ -198,7 +224,6 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj