Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fix and ignore some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
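
For context, the move to ruff centralizes formatter and linter settings in a `ruff.toml` at the repository root. Below is a minimal sketch of what such a file can look like; the line length and the specific rules selected or ignored are illustrative assumptions, not the exact configuration adopted in this PR.

```toml
# Illustrative ruff.toml sketch -- the values here are assumptions,
# not the exact settings adopted in this repository.
line-length = 120          # a longer limit lets the formatter collapse wrapped calls

[lint]
select = ["E", "F", "I"]   # pycodestyle errors, pyflakes, import sorting
ignore = ["E501"]          # example: leave long-line enforcement to the formatter

[format]
quote-style = "double"
```

With a larger line length, `ruff format` collapses many previously wrapped expressions onto single lines, which is what the bulk of the diff below shows.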


@ -58,22 +58,14 @@ def get_provider_dependencies(
for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]
providers = (
provider_or_providers
if isinstance(provider_or_providers, list)
else [provider_or_providers]
)
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different - not great
provider_type = (
provider if isinstance(provider, str) else provider.provider_type
)
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
provider_spec = providers_for_api[provider_type]
deps.extend(provider_spec.pip_packages)
@ -109,19 +101,13 @@ def build_image(
image_name: str,
template_or_config: str,
):
container_base = (
build_config.distribution_spec.container_image or "python:3.10-slim"
)
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.container.value:
script = str(
importlib.resources.files("llama_stack") / "distribution/build_container.sh"
)
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [
script,
template_or_config,
@ -132,9 +118,7 @@ def build_image(
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:
script = str(
importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
)
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [
script,
str(image_name),
@ -142,9 +126,7 @@ def build_image(
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.venv.value:
script = str(
importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
)
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [
script,
str(image_name),


@ -68,9 +68,7 @@ def create_api_client_class(protocol) -> Type:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert return_type, (
f"Could not extract return type for {sig.return_annotation}"
)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
@ -87,9 +85,7 @@ def create_api_client_class(protocol) -> Type:
webmethod, sig = self.routes[method_name]
return_type = extract_async_iterator_type(sig.return_annotation)
assert return_type, (
f"Could not extract return type for {sig.return_annotation}"
)
assert return_type, f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(method_name, *args, **kwargs)
@ -204,9 +200,7 @@ async def example(model: str = None):
if not model:
model = "Llama3.2-3B-Instruct"
message = UserMessage(
content="hello world, write me a 2 sentence poem about the moon"
)
message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
cprint(f"User>{message.content}", "green")
stream = True


@ -26,9 +26,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
def configure_single_provider(
registry: Dict[str, ProviderSpec], provider: Provider
) -> Provider:
def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
@ -47,9 +45,7 @@ def configure_single_provider(
)
def configure_api_providers(
config: StackRunConfig, build_spec: DistributionSpec
) -> StackRunConfig:
def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec) -> StackRunConfig:
is_nux = len(config.providers) == 0
if is_nux:
@ -87,9 +83,7 @@ def configure_api_providers(
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
updated_providers.append(
configure_single_provider(provider_registry[api], p)
)
updated_providers.append(configure_single_provider(provider_registry[api], p))
logger.info("")
else:
# we are newly configuring this API
@ -114,11 +108,7 @@ def configure_api_providers(
configure_single_provider(
provider_registry[api],
Provider(
provider_id=(
f"{provider_type}-{i:02d}"
if len(plist) > 1
else provider_type
),
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type),
provider_type=provider_type,
config={},
),
@ -137,11 +127,7 @@ def upgrade_from_routing_table(
def get_providers(entries):
return [
Provider(
provider_id=(
f"{entry['provider_type']}-{i:02d}"
if len(entries) > 1
else entry["provider_type"]
),
provider_id=(f"{entry['provider_type']}-{i:02d}" if len(entries) > 1 else entry["provider_type"]),
provider_type=entry["provider_type"],
config=entry["config"],
)


@ -163,9 +163,7 @@ a default SQLite store will be used.""",
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
distribution_spec: DistributionSpec = Field(
description="The distribution spec to build including API providers. "
)
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(
default="conda",
description="Type of package to build (conda | container | venv)",


@ -55,9 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]


@ -154,9 +154,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
def sync_generator():
try:
async_stream = loop.run_until_complete(
self.async_client.request(*args, **kwargs)
)
async_stream = loop.run_until_complete(self.async_client.request(*args, **kwargs))
while True:
chunk = loop.run_until_complete(async_stream.__anext__())
yield chunk
@ -181,9 +179,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# when using the library client, we should not log to console since many
# of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(
sink for sink in current_sinks if sink != "console"
)
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if config_path_or_template_name.endswith(".yaml"):
config_path = Path(config_path_or_template_name)
@ -202,9 +198,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
async def initialize(self):
try:
self.impls = await construct_stack(
self.config, self.custom_provider_registry
)
self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e:
cprint(_e.msg, "red")
cprint(
@ -247,9 +241,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][
_convert_path_to_regex(endpoint.route)
] = func
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
self.endpoint_impls = endpoint_impls
return True
@ -266,9 +258,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized")
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
if stream:
response = await self._call_streaming(
@ -408,9 +398,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(
self, path: str, method: str, body: Optional[dict] = None
) -> dict:
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
if not body:
return {}
@ -425,7 +413,5 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = convert_to_pydantic(
param.annotation, value
)
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
return converted_body


@ -115,9 +115,7 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
providers_with_specs = {}
@ -125,16 +123,12 @@ async def resolve_impls(
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if provider.provider_type not in provider_registry[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
@ -145,9 +139,7 @@ async def resolve_impls(
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies] + [
a.value for a in p.optional_api_dependencies
]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
@ -158,9 +150,7 @@ async def resolve_impls(
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
)
for info in builtin_automatically_routed_apis():
@ -197,9 +187,7 @@ async def resolve_impls(
)
}
sorted_providers = topological_sort(
{k: v.values() for k, v in providers_with_specs.items()}
)
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
@ -237,9 +225,7 @@ async def resolve_impls(
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(
provider,
@ -336,10 +322,7 @@ async def instantiate_provider(
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
@ -367,19 +350,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
log.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:


@ -85,9 +85,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
vector_db_id, chunks, ttl_seconds
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
@ -95,9 +93,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
vector_db_id, query, params
)
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
class InferenceRouter(Inference):
@ -123,9 +119,7 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
async def chat_completion(
self,
@ -143,9 +137,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
params = dict(
model_id=model_id,
messages=messages,
@ -176,9 +168,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model_id=model_id,
@ -202,9 +192,7 @@ class InferenceRouter(Inference):
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
@ -231,9 +219,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
return await self.routing_table.register_shield(
shield_id, provider_shield_id, provider_id, params
)
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
@ -268,9 +254,7 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
return await self.routing_table.get_provider_impl(
dataset_id
).get_rows_paginated(
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
@ -305,9 +289,7 @@ class ScoringRouter(Scoring):
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -328,9 +310,7 @@ class ScoringRouter(Scoring):
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
@ -381,9 +361,7 @@ class EvalRouter(Eval):
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(task_id).job_status(
task_id, job_id
)
return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)
async def job_cancel(
self,
@ -420,9 +398,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
return await self.routing_table.get_provider_impl(
"query_from_memory"
).query(content, vector_db_ids, query_config)
return await self.routing_table.get_provider_impl("query_from_memory").query(
content, vector_db_ids, query_config
)
async def insert(
self,
@ -430,9 +408,9 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
return await self.routing_table.get_provider_impl(
"insert_into_memory"
).insert(documents, vector_db_id, chunk_size_in_tokens)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
def __init__(
self,
@ -460,6 +438,4 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)


@ -94,9 +94,7 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
@ -131,9 +129,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
@ -171,9 +167,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
@ -183,13 +177,9 @@ class CommonRoutingTableImpl(RoutingTable):
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@ -244,9 +234,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = Model(
identifier=model_id,
provider_resource_id=provider_model_id,
@ -266,9 +254,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(
data=await self.get_all_with_type(ResourceType.shield.value)
)
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier("shield", identifier)
@ -340,9 +326,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {embedding_model} does not have an embedding dimension"
)
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
@ -364,9 +348,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(
data=await self.get_all_with_type(ResourceType.dataset.value)
)
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier("dataset", dataset_id)
@ -411,9 +393,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
@ -510,12 +490,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs:
tools.append(


@ -43,9 +43,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
if api == Api.tool_runtime:
for tool_group in SpecialToolGroup:
sub_protocol = toolgroup_protocols[tool_group]
sub_protocol_methods = inspect.getmembers(
sub_protocol, predicate=inspect.isfunction
)
sub_protocol_methods = inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
for name, method in sub_protocol_methods:
if not hasattr(method, "__webmethod__"):
continue


@ -76,9 +76,7 @@ async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
@ -178,9 +176,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
@ -190,11 +186,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
]
new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
new_params.extend(sig.parameters.values())
path_params = extract_path_params(route)
@ -202,15 +194,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
# Annotate parameters that are in the path with Path(...) and others with Body(...)
new_params = [new_params[0]] + [
(
param.replace(
annotation=Annotated[
param.annotation, FastapiPath(..., title=param.name)
]
)
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
if param.name in path_params
else param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
)
for param in new_params[1:]
]
@ -244,12 +230,8 @@ class ClientVersionMiddleware:
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(
map(int, client_version.split(".")[:2])
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
client_version_parts = tuple(map(int, client_version.split(".")[:2]))
server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
if client_version_parts != server_version_parts:
async def send_version_error(send):
@ -267,9 +249,7 @@ class ClientVersionMiddleware:
}
}
).encode()
await send(
{"type": "http.response.body", "body": error_msg}
)
await send({"type": "http.response.body", "body": error_msg})
return await send_version_error(send)
except (ValueError, IndexError):
@ -296,9 +276,7 @@ def main():
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
parser.add_argument(
"--env",
action="append",
@ -323,9 +301,7 @@ def main():
raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}")
elif args.template:
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
)
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}")
@ -383,9 +359,7 @@ def main():
impl_method = getattr(impl, endpoint.name)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=UserWarning, module="pydantic._internal._fields"
)
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
@ -416,9 +390,7 @@ def main():
def extract_path_params(route: str) -> List[str]:
segments = route.split("/")
params = [
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
]
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
return params


@ -110,9 +110,7 @@ class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@ -187,9 +185,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(
f"Key must contain only alphanumeric characters and underscores: {key}"
)
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
return key, value
except ValueError as e:
raise ValueError(
@ -202,20 +198,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(), dist_registry
)
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
await register_resources(run_config, impls)
return impls
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = (
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
)
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
with importlib.resources.as_file(template_path) as path:
if not path.exists():


@ -25,9 +25,7 @@ class DistributionRegistry(Protocol):
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
async def update(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider: ...
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
@ -61,9 +59,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None:
pass
def get_cached(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
# Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache")
@ -72,12 +68,8 @@ class DiskDistributionRegistry(DistributionRegistry):
values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values)
async def get(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
KEY_FORMAT.format(type=type, identifier=identifier)
)
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
if not json_str:
return None
@ -143,9 +135,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def initialize(self) -> None:
await self._ensure_initialized()
def get_cached(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), None)
async def get_all(self) -> List[RoutableObjectWithProvider]:
@ -153,9 +143,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async with self._locked_cache() as cache:
return list(cache.values())
async def get(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
await self._ensure_initialized()
cache_key = (type, identifier)
@ -197,9 +185,7 @@ async def create_dist_registry(
dist_kvstore = await kvstore_impl(metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
)
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
await dist_registry.initialize()


@ -161,9 +161,7 @@ async def test_duplicate_provider_registration(config):
result = await cached_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert (
result.embedding_model == original_vector_db.embedding_model
) # Original values preserved
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
@ -193,14 +191,9 @@ async def test_get_all_objects(config):
# Verify each vector_db was stored correctly
for original_vector_db in test_vector_dbs:
matching_vector_dbs = [
v for v in all_results if v.identifier == original_vector_db.identifier
]
matching_vector_dbs = [v for v in all_results if v.identifier == original_vector_db.identifier]
assert len(matching_vector_dbs) == 1
stored_vector_db = matching_vector_dbs[0]
assert stored_vector_db.embedding_model == original_vector_db.embedding_model
assert stored_vector_db.provider_id == original_vector_db.provider_id
assert (
stored_vector_db.embedding_dimension
== original_vector_db.embedding_dimension
)
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension


@ -22,15 +22,11 @@ def main():
)
# Playground pages
chat_page = st.Page(
"page/playground/chat.py", title="Chat", icon="💬", default=True
)
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
# Distribution pages
resources_page = st.Page(
"page/distribution/resources.py", title="Resources", icon="🔍", default=False
)
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
provider_page = st.Page(
"page/distribution/providers.py",
title="API Providers",


@ -23,15 +23,11 @@ class LlamaStackApi:
},
)
def run_scoring(
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
):
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
return self.client.scoring.score(
input_rows=[row], scoring_functions=scoring_params
)
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)
llama_stack_api = LlamaStackApi()


@ -11,9 +11,7 @@ from modules.api import llama_stack_api
def datasets():
st.header("Datasets")
datasets_info = {
d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
}
datasets_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()}
if len(datasets_info) > 0:
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
st.json(datasets_info[selected_dataset], expanded=True)


@ -12,12 +12,8 @@ def eval_tasks():
# Eval Tasks Section
st.header("Eval Tasks")
eval_tasks_info = {
d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
}
eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}
if len(eval_tasks_info) > 0:
selected_eval_task = st.selectbox(
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
)
selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
st.json(eval_tasks_info[selected_eval_task], expanded=True)


@ -11,9 +11,7 @@ from modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {
m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()
}
models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])


@ -11,12 +11,7 @@ from modules.api import llama_stack_api
def scoring_functions():
st.header("Scoring Functions")
scoring_functions_info = {
s.identifier: s.to_dict()
for s in llama_stack_api.client.scoring_functions.list()
}
scoring_functions_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.scoring_functions.list()}
selected_scoring_function = st.selectbox(
"Select a scoring function", list(scoring_functions_info.keys())
)
selected_scoring_function = st.selectbox("Select a scoring function", list(scoring_functions_info.keys()))
st.json(scoring_functions_info[selected_scoring_function], expanded=True)


@ -12,9 +12,7 @@ def shields():
# Shields Section
st.header("Shields")
shields_info = {
s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()
}
shields_info = {s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()}
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
st.json(shields_info[selected_shield])


@ -10,14 +10,10 @@ from modules.api import llama_stack_api
def vector_dbs():
st.header("Vector Databases")
vector_dbs_info = {
v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()
}
vector_dbs_info = {v.identifier: v.to_dict() for v in llama_stack_api.client.vector_dbs.list()}
if len(vector_dbs_info) > 0:
selected_vector_db = st.selectbox(
"Select a vector database", list(vector_dbs_info.keys())
)
selected_vector_db = st.selectbox("Select a vector database", list(vector_dbs_info.keys()))
st.json(vector_dbs_info[selected_vector_db])
else:
st.info("No vector databases found")


@ -14,7 +14,6 @@ from modules.utils import process_dataset
def application_evaluation_page():
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Scoring)")
@ -83,9 +82,7 @@ def application_evaluation_page():
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(
f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
)
st.error(f"Invalid JSON for **{param_name}** in {scoring_fn_id}")
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
@ -128,9 +125,7 @@ def application_evaluation_page():
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(
score_res.to_json(),
expanded=2,


@ -195,7 +195,6 @@ def run_evaluation_3():
# Add run button and handle evaluation
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.rows
@ -233,9 +232,7 @@ def run_evaluation_3():
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write(
f"Expand to see current processed result ({i + 1} / {len(rows)})"
)
progress_text_container.write(f"Expand to see current processed result ({i + 1} / {len(rows)})")
results_container.json(eval_res, expanded=2)
progress_bar.progress(1.0, text="Evaluation complete!")
@ -247,7 +244,6 @@ def run_evaluation_3():
def native_evaluation_page():
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Generation + Scoring)")


@ -11,9 +11,7 @@ from modules.api import llama_stack_api
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [
model.identifier for model in available_models if model.model_type == "llm"
]
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
@ -128,6 +126,4 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
full_response = response
message_placeholder.markdown(full_response.completion_message.content)
st.session_state.messages.append(
{"role": "assistant", "content": full_response}
)
st.session_state.messages.append({"role": "assistant", "content": full_response})


@ -74,9 +74,7 @@ def rag_chat_page():
)
available_models = llama_stack_api.client.models.list()
available_models = [
model.identifier for model in available_models if model.model_type == "llm"
]
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
@ -137,9 +135,7 @@ def rag_chat_page():
dict(
name="builtin::rag",
args={
"vector_db_ids": [
vector_db_id for vector_db_id in selected_vector_dbs
],
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
},
)
],
@ -186,9 +182,7 @@ def rag_chat_page():
message_placeholder.markdown(full_response + "▌")
message_placeholder.markdown(full_response)
st.session_state.messages.append(
{"role": "assistant", "content": full_response}
)
st.session_state.messages.append({"role": "assistant", "content": full_response})
rag_chat_page()


@ -8,9 +8,7 @@ import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(
os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/"))
)
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"


@ -31,15 +31,11 @@ def is_list_of_primitives(field_type):
def is_basemodel_without_fields(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
)
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) == 0
def can_recurse(typ):
return (
inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
)
return inspect.isclass(typ) and issubclass(typ, BaseModel) and len(typ.__fields__) > 0
def get_literal_values(field):
@ -72,7 +68,7 @@ def is_discriminated_union(typ) -> bool:
if isinstance(typ, FieldInfo):
return typ.discriminator
else:
if not (get_origin(typ) is Annotated):
if get_origin(typ) is not Annotated:
return False
args = get_args(typ)
return len(args) >= 2 and args[1].discriminator
@ -116,9 +112,7 @@ def prompt_for_discriminated_union(
chosen_type = type_map[discriminator_value]
log.info(f"\nConfiguring {chosen_type.__name__}:")
if existing_value and (
getattr(existing_value, discriminator) != discriminator_value
):
if existing_value and (getattr(existing_value, discriminator) != discriminator_value):
existing_value = None
sub_config = prompt_for_config(chosen_type, existing_value)
@ -134,9 +128,7 @@ def prompt_for_discriminated_union(
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
@ -150,17 +142,11 @@ def prompt_for_config(
for field_name, field in config_type.__fields__.items():
field_type = field.annotation
existing_value = (
getattr(existing_config, field_name) if existing_config else None
)
existing_value = getattr(existing_config, field_name) if existing_config else None
if existing_value:
default_value = existing_value
else:
default_value = (
field.default
if not isinstance(field.default, PydanticUndefinedType)
else None
)
default_value = field.default if not isinstance(field.default, PydanticUndefinedType) else None
is_required = field.is_required
# Skip fields with Literal type
@ -183,15 +169,11 @@ def prompt_for_config(
config_data[field_name] = validated_value
break
except KeyError:
log.error(
f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}"
)
log.error(f"Invalid choice. Please choose from: {', '.join(e.name for e in field_type)}")
continue
if is_discriminated_union(field):
config_data[field_name] = prompt_for_discriminated_union(
field_name, field, existing_value
)
config_data[field_name] = prompt_for_discriminated_union(field_name, field, existing_value)
continue
if is_optional(field_type) and can_recurse(get_non_none_type(field_type)):
@ -202,9 +184,7 @@ def prompt_for_config(
nested_type = get_non_none_type(field_type)
log.info(f"Entering sub-configuration for {field_name}:")
config_data[field_name] = prompt_for_config(nested_type, existing_value)
elif is_optional(field_type) and is_discriminated_union(
get_non_none_type(field_type)
):
elif is_optional(field_type) and is_discriminated_union(get_non_none_type(field_type)):
prompt = f"Do you want to configure {field_name}? (y/n): "
if input(prompt).lower() == "n":
config_data[field_name] = None
@ -260,16 +240,12 @@ def prompt_for_config(
try:
value = json.loads(user_input)
if not isinstance(value, list):
raise ValueError(
"Input must be a JSON-encoded list"
)
raise ValueError("Input must be a JSON-encoded list")
element_type = get_args(field_type)[0]
value = [element_type(item) for item in value]
except json.JSONDecodeError:
log.error(
'Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]'
)
log.error('Invalid JSON. Please enter a valid JSON-encoded list e.g., ["foo","bar"]')
continue
except ValueError as e:
log.error(f"{str(e)}")
@ -279,20 +255,14 @@ def prompt_for_config(
try:
value = json.loads(user_input)
if not isinstance(value, dict):
raise ValueError(
"Input must be a JSON-encoded dictionary"
)
raise ValueError("Input must be a JSON-encoded dictionary")
except json.JSONDecodeError:
log.error(
"Invalid JSON. Please enter a valid JSON-encoded dict."
)
log.error("Invalid JSON. Please enter a valid JSON-encoded dict.")
continue
# Convert the input to the correct type
elif inspect.isclass(field_type) and issubclass(
field_type, BaseModel
):
elif inspect.isclass(field_type) and issubclass(field_type, BaseModel):
# For nested BaseModels, we assume a dictionary-like string input
import ast
@ -301,16 +271,12 @@ def prompt_for_config(
value = field_type(user_input)
except ValueError:
log.error(
f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}"
)
log.error(f"Invalid input. Expected type: {getattr(field_type, '__name__', str(field_type))}")
continue
try:
# Validate the field using our manual validation function
validated_value = manually_validate_field(
config_type, field_name, value
)
validated_value = manually_validate_field(config_type, field_name, value)
config_data[field_name] = validated_value
break
except ValueError as e: