mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
fix: clear model cache when run.yaml model list changes
This commit is contained in:
parent
521865c388
commit
9e79e917f6
5 changed files with 99 additions and 3 deletions
|
@@ -101,6 +101,15 @@ TEST_RECORDING_CONTEXT = None
|
|||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
# Set the run config on the models routing table for generating from_config models
|
||||
if Api.models in impls:
|
||||
models_impl = impls[Api.models]
|
||||
models_impl.current_run_config = run_config
|
||||
# Clean up models from disabled providers
|
||||
await models_impl.cleanup_disabled_provider_models()
|
||||
# Register from_config models
|
||||
await models_impl.register_from_config_models()
|
||||
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
if api not in impls:
|
||||
|
@@ -118,7 +127,16 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
# we want to maintain the type information in arguments to method.
|
||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
|
||||
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
|
||||
kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()}
|
||||
|
||||
# Skip registering from_config models since they are registered through the routing table's set_run_config
|
||||
if rsrc == "models":
|
||||
logger.debug(
|
||||
f"Skipping registration of from_config model {obj.model_id} - will be registered through routing table"
|
||||
)
|
||||
continue
|
||||
|
||||
await method(**kwargs)
|
||||
|
||||
method = getattr(impls[api], list_method)
|
||||
response = await method()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue