small fix

This commit is contained in:
Ashwin Bharambe 2024-11-07 16:01:36 -08:00
parent 984ba074e1
commit a7f728e41c
3 changed files with 15 additions and 11 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import importlib
import inspect
import sys
from typing import Any, Dict, List, Set
@ -34,6 +33,10 @@ from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
class InvalidProviderError(Exception):
pass
def api_protocol_map() -> Dict[Api, Any]:
return {
Api.agents: Agents,
@ -105,7 +108,7 @@ async def resolve_impls(
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
cprint(p.deprecation_error, "red", attrs=["bold"])
sys.exit(1)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
cprint(
@ -116,7 +119,7 @@ async def resolve_impls(
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.dict()),
**(provider.model_dump()),
)
specs[provider.provider_id] = spec

View file

@ -9,6 +9,7 @@ import functools
import inspect
import json
import signal
import sys
import traceback
from contextlib import asynccontextmanager
@ -41,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
from .endpoints import get_all_api_endpoints
@ -282,7 +283,13 @@ def main(
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
try:
impls = asyncio.run(
resolve_impls(config, get_provider_registry(), dist_registry)
)
except InvalidProviderError:
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])

View file

@ -57,9 +57,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S
],
module="llama_stack.providers.inline.safety.prompt_guard",
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
api_dependencies=[
Api.inference,
],
),
InlineProviderSpec(
api=Api.safety,
@ -69,9 +66,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S
],
module="llama_stack.providers.inline.safety.code_scanner",
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
api_dependencies=[
Api.inference,
],
),
remote_provider_spec(
api=Api.safety,