models api configure prompts

This commit is contained in:
Xi Yan 2024-09-19 23:22:43 -07:00
parent e2c7a3cea9
commit 4647cc3e08
2 changed files with 75 additions and 3 deletions

View file

@ -9,12 +9,11 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.distribution import api_providers, stack_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from termcolor import cprint
def make_routing_entry_type(config_class: Any):
@ -25,6 +24,69 @@ def make_routing_entry_type(config_class: Any):
return BaseModelWithConfig
def configure_models_api(
    config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig:
    """Populate the builtin `models` API config from already-configured providers.

    Unlike the other APIs, `models` is not configured by interactive prompting:
    its entries are derived from the pre-existing inference/safety provider
    configs, prompting only when the served model id cannot be inferred.

    Args:
        config: Run config whose `provider_map` is read and updated in place.
        spec: Distribution spec supplying the `models` provider id.

    Returns:
        The same ``StackRunConfig`` with ``provider_map["models"]`` filled in.
    """
    from llama_stack.providers.impls.builtin.models.config import (
        ModelConfigProviderEntry,
    )
    from prompt_toolkit import prompt

    # Plain string: the original f-string had no placeholders.
    cprint("Configuring API `models`...\n", "white", attrs=["bold"])

    models_config_list = []

    # TODO (xiyan): we need to clean up configure with models & routers

    # Inference: meta-reference declares its model in its own config; any
    # other provider is remote, so ask which model the endpoint is serving.
    if "inference" in config.apis_to_serve and "inference" in config.provider_map:
        inference_provider_id = config.provider_map["inference"].provider_id
        inference_provider_config = config.provider_map["inference"].config
        if inference_provider_id == "meta-reference":
            core_model_id = inference_provider_config["model"]
        else:
            core_model_id = prompt(
                "Enter model_id your inference is serving",
                default="Meta-Llama3.1-8B-Instruct",
            )
        models_config_list.append(
            ModelConfigProviderEntry(
                api="inference",
                core_model_id=core_model_id,
                provider_id=inference_provider_id,
                config=inference_provider_config,
            )
        )

    # Safety: the meta-reference provider may carry up to two shield models;
    # add one entry per shield actually present in its config. Non-meta
    # safety providers contribute no model entries (original behavior).
    if "safety" in config.apis_to_serve and "safety" in config.provider_map:
        safety_provider_id = config.provider_map["safety"].provider_id
        safety_provider_config = config.provider_map["safety"].config
        if safety_provider_id == "meta-reference":
            for model_type in ["llama_guard_shield", "prompt_guard_shield"]:
                if model_type not in safety_provider_config:
                    continue
                core_model_id = safety_provider_config[model_type]["model"]
                models_config_list.append(
                    ModelConfigProviderEntry(
                        api="safety",
                        core_model_id=core_model_id,
                        provider_id=safety_provider_id,
                        config=safety_provider_config,
                    )
                )

    # NOTE(review): presumably spec.providers maps api name -> provider id
    # string here; the removed unused local `provider` bound the same value.
    config.provider_map["models"] = GenericProviderConfig(
        provider_id=spec.providers["models"],
        config={"models_config": models_config_list},
    )

    return config
# TODO: make sure we can deal with existing configuration values correctly
# instead of just overwriting them
def configure_api_providers(
@ -40,6 +102,10 @@ def configure_api_providers(
if api_str not in apis:
raise ValueError(f"Unknown API `{api_str}`")
# configure models builtin api last based on existing configs
if api_str == "models":
continue
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
api = Api(api_str)
@ -92,4 +158,7 @@ def configure_api_providers(
config=cfg.dict(),
)
if "models" in config.apis_to_serve:
config = configure_models_api(config, spec)
return config

View file

@ -21,4 +21,7 @@ class ModelConfigProviderEntry(GenericProviderConfig):
@json_schema_type
class BuiltinImplConfig(BaseModel):
models_config: List[ModelConfigProviderEntry]
models_config: List[ModelConfigProviderEntry] = Field(
default_factory=list,
description="list of model config entries for each model",
)