chore: more mypy fixes (#2029)

# What does this PR do?

Mainly tried to cover the entire llama_stack/apis directory, we only
have one left. Some excludes were just noop.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-06 18:52:31 +02:00 committed by GitHub
parent feb9eb8b0d
commit 1a529705da
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 581 additions and 166 deletions

View file

@ -73,11 +73,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
@ -91,7 +87,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider_type in enumerate(plist):
if i >= 1:

View file

@ -30,7 +30,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
@ -216,7 +216,18 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
"yellow",
)
if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers)
# Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=provider_types,
),
)
print_pip_install_help(build_config)
else:
prefix = "!" if in_notebook() else ""
cprint(

View file

@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager):
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}"
if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type
validator_class = spec.provider_data_validator

View file

@ -124,7 +124,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
else:
full_response = response
message_placeholder.markdown(full_response.completion_message.content)
full_response = response.completion_message.content
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})