mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
small fix
This commit is contained in:
parent
984ba074e1
commit
a7f728e41c
3 changed files with 15 additions and 11 deletions
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
|
@ -34,6 +33,10 @@ from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidProviderError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def api_protocol_map() -> Dict[Api, Any]:
|
def api_protocol_map() -> Dict[Api, Any]:
|
||||||
return {
|
return {
|
||||||
Api.agents: Agents,
|
Api.agents: Agents,
|
||||||
|
@ -105,7 +108,7 @@ async def resolve_impls(
|
||||||
p = provider_registry[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
if p.deprecation_error:
|
if p.deprecation_error:
|
||||||
cprint(p.deprecation_error, "red", attrs=["bold"])
|
cprint(p.deprecation_error, "red", attrs=["bold"])
|
||||||
sys.exit(1)
|
raise InvalidProviderError(p.deprecation_error)
|
||||||
|
|
||||||
elif p.deprecation_warning:
|
elif p.deprecation_warning:
|
||||||
cprint(
|
cprint(
|
||||||
|
@ -116,7 +119,7 @@ async def resolve_impls(
|
||||||
p.deps__ = [a.value for a in p.api_dependencies]
|
p.deps__ = [a.value for a in p.api_dependencies]
|
||||||
spec = ProviderWithSpec(
|
spec = ProviderWithSpec(
|
||||||
spec=p,
|
spec=p,
|
||||||
**(provider.dict()),
|
**(provider.model_dump()),
|
||||||
)
|
)
|
||||||
specs[provider.provider_id] = spec
|
specs[provider.provider_id] = spec
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import signal
|
import signal
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
@ -41,7 +42,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
|
||||||
|
|
||||||
from .endpoints import get_all_api_endpoints
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
|
@ -282,7 +283,13 @@ def main(
|
||||||
|
|
||||||
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
|
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
|
||||||
|
|
||||||
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
|
try:
|
||||||
|
impls = asyncio.run(
|
||||||
|
resolve_impls(config, get_provider_registry(), dist_registry)
|
||||||
|
)
|
||||||
|
except InvalidProviderError:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
|
|
|
@ -57,9 +57,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S
|
||||||
],
|
],
|
||||||
module="llama_stack.providers.inline.safety.prompt_guard",
|
module="llama_stack.providers.inline.safety.prompt_guard",
|
||||||
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
|
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
|
||||||
api_dependencies=[
|
|
||||||
Api.inference,
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
InlineProviderSpec(
|
InlineProviderSpec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
|
@ -69,9 +66,6 @@ Provider `meta-reference` for API `safety` does not work with the latest Llama S
|
||||||
],
|
],
|
||||||
module="llama_stack.providers.inline.safety.code_scanner",
|
module="llama_stack.providers.inline.safety.code_scanner",
|
||||||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||||
api_dependencies=[
|
|
||||||
Api.inference,
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue