small fix

This commit is contained in:
Ashwin Bharambe 2024-11-07 16:01:36 -08:00
parent 984ba074e1
commit a7f728e41c
3 changed files with 15 additions and 11 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import importlib import importlib
import inspect import inspect
import sys
from typing import Any, Dict, List, Set from typing import Any, Dict, List, Set
@ -34,6 +33,10 @@ from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
class InvalidProviderError(Exception):
pass
def api_protocol_map() -> Dict[Api, Any]: def api_protocol_map() -> Dict[Api, Any]:
return { return {
Api.agents: Agents, Api.agents: Agents,
@ -105,7 +108,7 @@ async def resolve_impls(
p = provider_registry[api][provider.provider_type] p = provider_registry[api][provider.provider_type]
if p.deprecation_error: if p.deprecation_error:
cprint(p.deprecation_error, "red", attrs=["bold"]) cprint(p.deprecation_error, "red", attrs=["bold"])
sys.exit(1) raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning: elif p.deprecation_warning:
cprint( cprint(
@ -116,7 +119,7 @@ async def resolve_impls(
p.deps__ = [a.value for a in p.api_dependencies] p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec( spec = ProviderWithSpec(
spec=p, spec=p,
**(provider.dict()), **(provider.model_dump()),
) )
specs[provider.provider_id] = spec specs[provider.provider_id] = spec

View file

@ -9,6 +9,7 @@ import functools
import inspect import inspect
import json import json
import signal import signal
import sys
import traceback import traceback
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -41,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
) )
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
from .endpoints import get_all_api_endpoints from .endpoints import get_all_api_endpoints
@ -282,7 +283,13 @@ def main(
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) try:
impls = asyncio.run(
resolve_impls(config, get_provider_registry(), dist_registry)
)
except InvalidProviderError:
sys.exit(1)
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])

View file

@ -57,9 +57,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S
], ],
module="llama_stack.providers.inline.safety.prompt_guard", module="llama_stack.providers.inline.safety.prompt_guard",
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig", config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
api_dependencies=[
Api.inference,
],
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.safety, api=Api.safety,
@ -69,9 +66,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S
], ],
module="llama_stack.providers.inline.safety.code_scanner", module="llama_stack.providers.inline.safety.code_scanner",
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
api_dependencies=[
Api.inference,
],
), ),
remote_provider_spec( remote_provider_spec(
api=Api.safety, api=Api.safety,