mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
models api configure prompts
This commit is contained in:
parent
e2c7a3cea9
commit
4647cc3e08
2 changed files with 75 additions and 3 deletions
|
@ -9,12 +9,11 @@ from typing import Any
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from termcolor import cprint
|
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import api_providers, stack_apis
|
from llama_stack.distribution.distribution import api_providers, stack_apis
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
|
||||||
def make_routing_entry_type(config_class: Any):
|
def make_routing_entry_type(config_class: Any):
|
||||||
|
@ -25,6 +24,69 @@ def make_routing_entry_type(config_class: Any):
|
||||||
return BaseModelWithConfig
|
return BaseModelWithConfig
|
||||||
|
|
||||||
|
|
||||||
|
def configure_models_api(
|
||||||
|
config: StackRunConfig, spec: DistributionSpec
|
||||||
|
) -> StackRunConfig:
|
||||||
|
from llama_stack.providers.impls.builtin.models.config import (
|
||||||
|
ModelConfigProviderEntry,
|
||||||
|
)
|
||||||
|
from prompt_toolkit import prompt
|
||||||
|
|
||||||
|
cprint(f"Configuring API `models`...\n", "white", attrs=["bold"])
|
||||||
|
# models do not need prompting, we can use the pre-existing configs to populate the models_config
|
||||||
|
provider = spec.providers["models"]
|
||||||
|
models_config_list = []
|
||||||
|
|
||||||
|
# TODO (xiyan): we need to clean up configure with models & routers
|
||||||
|
# check inference api
|
||||||
|
if "inference" in config.apis_to_serve and "inference" in config.provider_map:
|
||||||
|
inference_provider_id = config.provider_map["inference"].provider_id
|
||||||
|
inference_provider_config = config.provider_map["inference"].config
|
||||||
|
|
||||||
|
if inference_provider_id == "meta-reference":
|
||||||
|
core_model_id = inference_provider_config["model"]
|
||||||
|
else:
|
||||||
|
core_model_id = prompt(
|
||||||
|
"Enter model_id your inference is serving",
|
||||||
|
default="Meta-Llama3.1-8B-Instruct",
|
||||||
|
)
|
||||||
|
models_config_list.append(
|
||||||
|
ModelConfigProviderEntry(
|
||||||
|
api="inference",
|
||||||
|
core_model_id=core_model_id,
|
||||||
|
provider_id=inference_provider_id,
|
||||||
|
config=inference_provider_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# check safety api for models
|
||||||
|
if "safety" in config.apis_to_serve and "safety" in config.provider_map:
|
||||||
|
safety_provider_id = config.provider_map["safety"].provider_id
|
||||||
|
safety_provider_config = config.provider_map["safety"].config
|
||||||
|
|
||||||
|
if safety_provider_id == "meta-reference":
|
||||||
|
for model_type in ["llama_guard_shield", "prompt_guard_shield"]:
|
||||||
|
if model_type not in safety_provider_config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
core_model_id = safety_provider_config[model_type]["model"]
|
||||||
|
models_config_list.append(
|
||||||
|
ModelConfigProviderEntry(
|
||||||
|
api="safety",
|
||||||
|
core_model_id=core_model_id,
|
||||||
|
provider_id=safety_provider_id,
|
||||||
|
config=safety_provider_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
config.provider_map["models"] = GenericProviderConfig(
|
||||||
|
provider_id=spec.providers["models"],
|
||||||
|
config={"models_config": models_config_list},
|
||||||
|
)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
# TODO: make sure we can deal with existing configuration values correctly
|
# TODO: make sure we can deal with existing configuration values correctly
|
||||||
# instead of just overwriting them
|
# instead of just overwriting them
|
||||||
def configure_api_providers(
|
def configure_api_providers(
|
||||||
|
@ -40,6 +102,10 @@ def configure_api_providers(
|
||||||
if api_str not in apis:
|
if api_str not in apis:
|
||||||
raise ValueError(f"Unknown API `{api_str}`")
|
raise ValueError(f"Unknown API `{api_str}`")
|
||||||
|
|
||||||
|
# configure models builtin api last based on existing configs
|
||||||
|
if api_str == "models":
|
||||||
|
continue
|
||||||
|
|
||||||
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
|
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
|
||||||
|
@ -92,4 +158,7 @@ def configure_api_providers(
|
||||||
config=cfg.dict(),
|
config=cfg.dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "models" in config.apis_to_serve:
|
||||||
|
config = configure_models_api(config, spec)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
|
@ -21,4 +21,7 @@ class ModelConfigProviderEntry(GenericProviderConfig):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class BuiltinImplConfig(BaseModel):
|
class BuiltinImplConfig(BaseModel):
|
||||||
models_config: List[ModelConfigProviderEntry]
|
models_config: List[ModelConfigProviderEntry] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="list of model config entries for each model",
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue