diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py
index ab1f31de6..672b9fc78 100644
--- a/llama_stack/distribution/configure.py
+++ b/llama_stack/distribution/configure.py
@@ -9,12 +9,11 @@ from typing import Any
 
 from pydantic import BaseModel
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
-from termcolor import cprint
-
 from llama_stack.distribution.distribution import api_providers, stack_apis
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
+from termcolor import cprint
 
 
 def make_routing_entry_type(config_class: Any):
@@ -25,6 +24,69 @@ def make_routing_entry_type(config_class: Any):
     return BaseModelWithConfig
 
 
+def configure_models_api(
+    config: StackRunConfig, spec: DistributionSpec
+) -> StackRunConfig:
+    from llama_stack.providers.impls.builtin.models.config import (
+        ModelConfigProviderEntry,
+    )
+    from prompt_toolkit import prompt
+
+    cprint(f"Configuring API `models`...\n", "white", attrs=["bold"])
+    # models do not need prompting, we can use the pre-existing configs to populate the models_config
+    provider = spec.providers["models"]
+    models_config_list = []
+
+    # TODO (xiyan): we need to clean up configure with models & routers
+    # check inference api
+    if "inference" in config.apis_to_serve and "inference" in config.provider_map:
+        inference_provider_id = config.provider_map["inference"].provider_id
+        inference_provider_config = config.provider_map["inference"].config
+
+        if inference_provider_id == "meta-reference":
+            core_model_id = inference_provider_config["model"]
+        else:
+            core_model_id = prompt(
+                "Enter model_id your inference is serving",
+                default="Meta-Llama3.1-8B-Instruct",
+            )
+        models_config_list.append(
+            ModelConfigProviderEntry(
+                api="inference",
+                core_model_id=core_model_id,
+                provider_id=inference_provider_id,
+                config=inference_provider_config,
+            )
+        )
+
+    # check safety api for models
+    if "safety" in config.apis_to_serve and "safety" in config.provider_map:
+        safety_provider_id = config.provider_map["safety"].provider_id
+        safety_provider_config = config.provider_map["safety"].config
+
+        if safety_provider_id == "meta-reference":
+            for model_type in ["llama_guard_shield", "prompt_guard_shield"]:
+                if model_type not in safety_provider_config:
+                    continue
+
+                core_model_id = safety_provider_config[model_type]["model"]
+                models_config_list.append(
+                    ModelConfigProviderEntry(
+                        api="safety",
+                        core_model_id=core_model_id,
+                        provider_id=safety_provider_id,
+                        config=safety_provider_config,
+                    )
+                )
+
+    config.provider_map["models"] = GenericProviderConfig(
+        provider_id=spec.providers["models"],
+        config={"models_config": models_config_list},
+    )
+
+    return config
+
+
 # TODO: make sure we can deal with existing configuration values correctly
 # instead of just overwriting them
 def configure_api_providers(
@@ -40,6 +102,10 @@ def configure_api_providers(
         if api_str not in apis:
             raise ValueError(f"Unknown API `{api_str}`")
 
+        # configure models builtin api last based on existing configs
+        if api_str == "models":
+            continue
+
         cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
         api = Api(api_str)
 
@@ -92,4 +158,7 @@
             config=cfg.dict(),
         )
 
+    if "models" in config.apis_to_serve:
+        config = configure_models_api(config, spec)
+
     return config
diff --git a/llama_stack/providers/impls/builtin/models/config.py b/llama_stack/providers/impls/builtin/models/config.py
index d153d6075..a2db5efbd 100644
--- a/llama_stack/providers/impls/builtin/models/config.py
+++ b/llama_stack/providers/impls/builtin/models/config.py
@@ -21,4 +21,7 @@ class ModelConfigProviderEntry(GenericProviderConfig):
 
 @json_schema_type
 class BuiltinImplConfig(BaseModel):
-    models_config: List[ModelConfigProviderEntry]
+    models_config: List[ModelConfigProviderEntry] = Field(
+        default_factory=list,
+        description="list of model config entries for each model",
+    )
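
For context, a minimal sketch of what the new `configure_models_api` pass derives from a pre-existing run config, assuming a meta-reference inference provider is already configured. The plain dicts below are hypothetical stand-ins for the `StackRunConfig` and `ModelConfigProviderEntry` pydantic models, so the snippet runs without llama_stack installed:

```python
# Hypothetical stand-ins for the real pydantic models in the patch above;
# plain dicts keep the sketch self-contained.
apis_to_serve = ["inference", "models"]
provider_map = {
    "inference": {
        "provider_id": "meta-reference",
        "config": {"model": "Meta-Llama3.1-8B-Instruct"},
    }
}

models_config_list = []
if "inference" in apis_to_serve and "inference" in provider_map:
    entry = provider_map["inference"]
    # meta-reference already names its model in its own config,
    # so no interactive prompting is needed
    core_model_id = entry["config"]["model"]
    models_config_list.append(
        {
            "api": "inference",
            "core_model_id": core_model_id,
            "provider_id": entry["provider_id"],
            "config": entry["config"],
        }
    )

# -> [{'api': 'inference', 'core_model_id': 'Meta-Llama3.1-8B-Instruct', ...}]
print(models_config_list)
```

The safety branch follows the same pattern, emitting one entry per shield model found in the safety provider's config.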