mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-30 11:50:14 +00:00
feat: associated models API with post_training
there are likely scenarios where admins of a stack only want to allow clients to fine-tune certain models, register certain models to be fine-tuned. etc introduce the post_training router and post_training_models as the associated type. A different model type needs to be used for inference vs post_training due to the structure of the router currently. Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
63a9f08c9e
commit
71caa271ad
11 changed files with 393 additions and 23 deletions
|
@ -67,6 +67,7 @@ def api_protocol_map() -> dict[Api, Any]:
|
|||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
Api.post_training_models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
Api.telemetry: Telemetry,
|
||||
|
@ -93,6 +94,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
|||
def additional_protocols_map() -> dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.post_training: (ModelsProtocolPrivate, Models, Api.post_training_models),
|
||||
Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
|
@ -251,6 +253,8 @@ async def instantiate_providers(
|
|||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
|
||||
# First pass: instantiate all providers
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
|
@ -269,6 +273,10 @@ async def instantiate_providers(
|
|||
api = Api(api_str)
|
||||
impls[api] = impl
|
||||
|
||||
# Second pass: connect routing tables
|
||||
if Api.models in impls and Api.post_training_models in impls:
|
||||
impls[Api.models].post_training_models_table = impls[Api.post_training_models]
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue