From a7f728e41c018fe62f84aae5c3711c450022fbb0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 7 Nov 2024 16:01:36 -0800 Subject: [PATCH] small fix --- llama_stack/distribution/resolver.py | 9 ++++++--- llama_stack/distribution/server/server.py | 11 +++++++++-- llama_stack/providers/registry/safety.py | 6 ------ 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index b689b00c9..4e7fa0102 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import importlib import inspect -import sys from typing import Any, Dict, List, Set @@ -34,6 +33,10 @@ from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type +class InvalidProviderError(Exception): + pass + + def api_protocol_map() -> Dict[Api, Any]: return { Api.agents: Agents, @@ -105,7 +108,7 @@ async def resolve_impls( p = provider_registry[api][provider.provider_type] if p.deprecation_error: cprint(p.deprecation_error, "red", attrs=["bold"]) - sys.exit(1) + raise InvalidProviderError(p.deprecation_error) elif p.deprecation_warning: cprint( @@ -116,7 +119,7 @@ async def resolve_impls( p.deps__ = [a.value for a in p.api_dependencies] spec = ProviderWithSpec( spec=p, - **(provider.dict()), + **(provider.model_dump()), ) specs[provider.provider_id] = spec diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 143813780..9193583e1 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -9,6 +9,7 @@ import functools import inspect import json import signal +import sys import traceback from contextlib import asynccontextmanager @@ -41,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( ) from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls from .endpoints import get_all_api_endpoints @@ -282,7 +283,13 @@ def main( dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) - impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) + try: + impls = asyncio.run( + resolve_impls(config, get_provider_registry(), dist_registry) + ) + except InvalidProviderError: + sys.exit(1) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 3479671b2..63676c4f1 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -57,9 +57,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S ], module="llama_stack.providers.inline.safety.prompt_guard", config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig", - api_dependencies=[ - Api.inference, - ], ), InlineProviderSpec( api=Api.safety, @@ -69,9 +66,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S ], module="llama_stack.providers.inline.safety.code_scanner", config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", - api_dependencies=[ - Api.inference, - ], ), remote_provider_spec( api=Api.safety,