Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
feat: associated models API with post_training
There are likely scenarios where admins of a stack only want to allow clients to fine-tune certain models, register certain models to be fine-tuned, etc. Introduce the post_training router and post_training_models as the associated type. A different model type needs to be used for inference vs. post_training due to the current structure of the router. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent 63a9f08c9e
commit 71caa271ad
11 changed files with 393 additions and 23 deletions
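Before the diff, a rough sketch of the split this change implies. Only the common routing table is shown in the hunks below, so the shape of the new InferenceModelsRoutingTable and PostTrainingModelsRoutingTable is assumed here; everything other than those two class names and the CommonRoutingTableImpl base is illustrative, not the actual file contents.

# Hypothetical sketch only: the real classes are defined elsewhere in the
# routing_tables package and are not part of this diff. The stand-in base
# class and docstrings are illustrative.

class CommonRoutingTableImpl:
    """Stand-in for the shared routing-table base class modified below."""


class InferenceModelsRoutingTable(CommonRoutingTableImpl):
    """Routes model objects that clients may use for inference."""


class PostTrainingModelsRoutingTable(CommonRoutingTableImpl):
    """Routes model objects that clients may fine-tune via post_training.

    Keeping a separate table (with its own associated model type) lets a
    stack admin register a model for fine-tuning without also exposing it
    for inference, which is the motivation stated in the commit message.
    """

Which providers each table may hand a model to is what the routing changes below enforce.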
@@ -33,7 +33,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
 
     assert obj.provider_id != "remote", "Remote provider should not be registered"
 
-    if api == Api.inference:
+    if api == Api.inference or api == Api.post_training:
         return await p.register_model(obj)
     elif api == Api.safety:
         return await p.register_shield(obj)
@@ -55,7 +55,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
     if api == Api.vector_io:
         return await p.unregister_vector_db(obj.identifier)
-    elif api == Api.inference:
+    elif api == Api.inference or api == Api.post_training:
         return await p.unregister_model(obj.identifier)
     elif api == Api.datasetio:
         return await p.unregister_dataset(obj.identifier)
@@ -89,11 +89,18 @@ class CommonRoutingTableImpl(RoutingTable):
                     obj = cls(**model_data)
                 await self.dist_registry.register(obj)
 
+        # Import routing table classes here to avoid circular imports
+        from .models import InferenceModelsRoutingTable
+        from .post_training_models import PostTrainingModelsRoutingTable
+
         # Register all objects from providers
         for pid, p in self.impls_by_provider_id.items():
             api = get_impl_api(p)
-            if api == Api.inference:
-                p.model_store = self
+            if api == Api.inference or api == Api.post_training:
+                # For models, we need to handle both inference and post-training providers
+                if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
+                    # Set the model store for both types of providers
+                    p.model_store = self
             elif api == Api.safety:
                 p.shield_store = self
             elif api == Api.vector_io:
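The net effect of this hunk is that a post_training provider now gets the same model_store handle that inference providers already received. A minimal sketch of the provider-side usage this enables is below; the ExamplePostTrainingProvider class, its supervised_fine_tune signature, and the get_model lookup are assumptions for illustration, mirroring how inference providers resolve models through their model_store.

# Minimal sketch (assumed provider-side usage, not code from this commit).
# After the routing table sets `p.model_store = self`, a post_training
# provider can resolve the model a client asked to fine-tune against the
# models that were actually registered for post_training.

class ExamplePostTrainingProvider:
    model_store = None  # assigned by the routing table during initialization

    async def supervised_fine_tune(self, model_id: str, **kwargs):
        # Fails if the admin never registered `model_id` with the
        # post_training models table.
        model = await self.model_store.get_model(model_id)
        return model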
@@ -116,15 +123,16 @@ class CommonRoutingTableImpl(RoutingTable):
     def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
         from .benchmarks import BenchmarksRoutingTable
         from .datasets import DatasetsRoutingTable
-        from .models import ModelsRoutingTable
+        from .models import InferenceModelsRoutingTable
+        from .post_training_models import PostTrainingModelsRoutingTable
         from .scoring_functions import ScoringFunctionsRoutingTable
         from .shields import ShieldsRoutingTable
         from .toolgroups import ToolGroupsRoutingTable
         from .vector_dbs import VectorDBsRoutingTable
 
         def apiname_object():
-            if isinstance(self, ModelsRoutingTable):
-                return ("Inference", "model")
+            if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
+                return ("Models", "model")
             elif isinstance(self, ShieldsRoutingTable):
                 return ("Safety", "shield")
             elif isinstance(self, VectorDBsRoutingTable):
@@ -155,7 +163,25 @@ class CommonRoutingTableImpl(RoutingTable):
             )
 
         if not provider_id or provider_id == obj.provider_id:
-            return self.impls_by_provider_id[obj.provider_id]
+            provider = self.impls_by_provider_id[obj.provider_id]
+            # Check if the provider supports the requested API
+            if not hasattr(provider, "__provider_spec__"):
+                return provider
+            api = provider.__provider_spec__.api
+
+            # Only check API compatibility for model routing tables
+            if isinstance(self, InferenceModelsRoutingTable | PostTrainingModelsRoutingTable):
+                if api not in [Api.inference, Api.post_training]:
+                    raise ValueError(f"Provider {obj.provider_id} does not support the requested API")
+                # If we have both inference and post-training providers, prefer inference for model registration
+                if api == Api.post_training and Api.inference in [
+                    p.__provider_spec__.api for p in self.impls_by_provider_id.values()
+                ]:
+                    # Try to find an inference provider first
+                    for _, p in self.impls_by_provider_id.items():
+                        if hasattr(p, "__provider_spec__") and p.__provider_spec__.api == Api.inference:
+                            return p
+            return provider
 
         raise ValueError(f"Provider not found for `{routing_key}`")
 
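The preference rule added above is easier to see in isolation. The sketch below reproduces it with simplified stand-ins: a post_training provider is accepted, but if an inference provider is also registered, model registration is routed to the inference provider instead, and any other API is rejected. Api, FakeProvider, pick_provider, and the provider ids are simplified names for this example only.

# Stand-alone sketch of the preference rule in the hunk above (simplified types).
from dataclasses import dataclass
from enum import Enum


class Api(Enum):
    inference = "inference"
    post_training = "post_training"
    safety = "safety"


@dataclass
class FakeProvider:
    api: Api

    @property
    def __provider_spec__(self):
        return self  # mimic provider.__provider_spec__.api returning the provider's API


def pick_provider(requested, impls_by_provider_id):
    """Prefer an inference provider for model registration, mirroring the diff."""
    api = requested.__provider_spec__.api
    if api not in [Api.inference, Api.post_training]:
        raise ValueError("Provider does not support the requested API")
    if api == Api.post_training and Api.inference in [
        p.__provider_spec__.api for p in impls_by_provider_id.values()
    ]:
        # An inference provider exists, so route model registration to it.
        for p in impls_by_provider_id.values():
            if p.__provider_spec__.api == Api.inference:
                return p
    return requested


impls = {"inference-provider": FakeProvider(Api.inference), "tuning-provider": FakeProvider(Api.post_training)}
assert pick_provider(impls["tuning-provider"], impls) is impls["inference-provider"]

Note that in the real change this check only runs when self is one of the two model routing tables; the sketch drops that guard for brevity.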
@@ -198,7 +224,6 @@ class CommonRoutingTableImpl(RoutingTable):
         if obj.type == ResourceType.model.value:
             await self.dist_registry.register(registered_obj)
             return registered_obj
-
         else:
             await self.dist_registry.register(obj)
             return obj