mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
models api configure prompts
This commit is contained in:
parent
e2c7a3cea9
commit
4647cc3e08
2 changed files with 75 additions and 3 deletions
|
@ -9,12 +9,11 @@ from typing import Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.distribution import api_providers, stack_apis
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
|
@ -25,6 +24,69 @@ def make_routing_entry_type(config_class: Any):
|
|||
return BaseModelWithConfig
|
||||
|
||||
|
||||
def configure_models_api(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
from llama_stack.providers.impls.builtin.models.config import (
|
||||
ModelConfigProviderEntry,
|
||||
)
|
||||
from prompt_toolkit import prompt
|
||||
|
||||
cprint(f"Configuring API `models`...\n", "white", attrs=["bold"])
|
||||
# models do not need prompting, we can use the pre-existing configs to populate the models_config
|
||||
provider = spec.providers["models"]
|
||||
models_config_list = []
|
||||
|
||||
# TODO (xiyan): we need to clean up configure with models & routers
|
||||
# check inference api
|
||||
if "inference" in config.apis_to_serve and "inference" in config.provider_map:
|
||||
inference_provider_id = config.provider_map["inference"].provider_id
|
||||
inference_provider_config = config.provider_map["inference"].config
|
||||
|
||||
if inference_provider_id == "meta-reference":
|
||||
core_model_id = inference_provider_config["model"]
|
||||
else:
|
||||
core_model_id = prompt(
|
||||
"Enter model_id your inference is serving",
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
)
|
||||
models_config_list.append(
|
||||
ModelConfigProviderEntry(
|
||||
api="inference",
|
||||
core_model_id=core_model_id,
|
||||
provider_id=inference_provider_id,
|
||||
config=inference_provider_config,
|
||||
)
|
||||
)
|
||||
|
||||
# check safety api for models
|
||||
if "safety" in config.apis_to_serve and "safety" in config.provider_map:
|
||||
safety_provider_id = config.provider_map["safety"].provider_id
|
||||
safety_provider_config = config.provider_map["safety"].config
|
||||
|
||||
if safety_provider_id == "meta-reference":
|
||||
for model_type in ["llama_guard_shield", "prompt_guard_shield"]:
|
||||
if model_type not in safety_provider_config:
|
||||
continue
|
||||
|
||||
core_model_id = safety_provider_config[model_type]["model"]
|
||||
models_config_list.append(
|
||||
ModelConfigProviderEntry(
|
||||
api="safety",
|
||||
core_model_id=core_model_id,
|
||||
provider_id=safety_provider_id,
|
||||
config=safety_provider_config,
|
||||
)
|
||||
)
|
||||
|
||||
config.provider_map["models"] = GenericProviderConfig(
|
||||
provider_id=spec.providers["models"],
|
||||
config={"models_config": models_config_list},
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
|
@ -40,6 +102,10 @@ def configure_api_providers(
|
|||
if api_str not in apis:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
# configure models builtin api last based on existing configs
|
||||
if api_str == "models":
|
||||
continue
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
|
@ -92,4 +158,7 @@ def configure_api_providers(
|
|||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
if "models" in config.apis_to_serve:
|
||||
config = configure_models_api(config, spec)
|
||||
|
||||
return config
|
||||
|
|
|
@ -21,4 +21,7 @@ class ModelConfigProviderEntry(GenericProviderConfig):
|
|||
|
||||
@json_schema_type
|
||||
class BuiltinImplConfig(BaseModel):
|
||||
models_config: List[ModelConfigProviderEntry]
|
||||
models_config: List[ModelConfigProviderEntry] = Field(
|
||||
default_factory=list,
|
||||
description="list of model config entries for each model",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue