diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index 4048e6da3..b1bfd03b6 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -194,11 +194,6 @@ async def run_rag(host: str, port: int):
         MemoryToolDefinition(
             max_tokens_in_context=2048,
             memory_bank_configs=[],
-            # memory_bank_configs=[
-            #     AgenticSystemVectorMemoryBankConfig(
-            #         bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
-            #     )
-            # ],
         ),
     ]
 
@@ -210,8 +205,9 @@ async def run_rag(host: str, port: int):
     await _run_agent(api, tool_definitions, user_prompts, attachments)
 
 
-def main(host: str, port: int):
-    asyncio.run(run_rag(host, port))
+def main(host: str, port: int, rag: bool = False):
+    fn = run_rag if rag else run_main
+    asyncio.run(fn(host, port))
 
 
 if __name__ == "__main__":
diff --git a/llama_toolchain/agentic_system/meta_reference/__init__.py b/llama_toolchain/agentic_system/meta_reference/__init__.py
index 11dc98333..acc1dcf0b 100644
--- a/llama_toolchain/agentic_system/meta_reference/__init__.py
+++ b/llama_toolchain/agentic_system/meta_reference/__init__.py
@@ -4,5 +4,27 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .agentic_system import get_provider_impl  # noqa
-from .config import MetaReferenceImplConfig  # noqa
+from typing import Dict
+
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceImplConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .agentic_system import MetaReferenceAgenticSystemImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceAgenticSystemImpl(
+        config,
+        deps[Api.inference],
+        deps[Api.memory],
+        deps[Api.safety],
+    )
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
index 52ebd1ec7..4fa2aa584 100644
--- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py
+++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
@@ -8,9 +8,8 @@ import logging
 import os
 import uuid
 
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
@@ -31,23 +30,6 @@ logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
 
-async def get_provider_impl(
-    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceAgenticSystemImpl(
-        config,
-        deps[Api.inference],
-        deps[Api.memory],
-        deps[Api.safety],
-    )
-    await impl.initialize()
-    return impl
-
-
 AGENT_INSTANCES_BY_ID = {}
diff --git a/llama_toolchain/cli/api/build.py b/llama_toolchain/cli/api/build.py
index 0f07f3a62..ae180fc20 100644
--- a/llama_toolchain/cli/api/build.py
+++ b/llama_toolchain/cli/api/build.py
@@ -10,7 +10,7 @@ import os
 from pydantic import BaseModel
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 import pkg_resources
 import yaml
@@ -37,26 +37,17 @@ def get_dependencies(
 ) -> Dependencies:
     from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
 
-    def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
-        if isinstance(provider, InlineProviderSpec):
-            return provider.pip_packages, provider.docker_image
-        else:
-            if provider.adapter:
-                return provider.adapter.pip_packages, None
-            return [], None
-
-    pip_packages, docker_image = _deps(provider)
+    pip_packages = provider.pip_packages
     for dep in dependencies.values():
-        dep_pip_packages, dep_docker_image = _deps(dep)
-        if docker_image and dep_docker_image:
+        if dep.docker_image:
             raise ValueError(
                 "You can only have the root provider specify a docker image"
             )
-
-        pip_packages.extend(dep_pip_packages)
+        pip_packages.extend(dep.pip_packages)
 
     return Dependencies(
-        docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
+        docker_image=provider.docker_image,
+        pip_packages=pip_packages + SERVER_DEPENDENCIES
     )
@@ -158,6 +149,7 @@ class ApiBuild(Subcommand):
         build_dir = BUILDS_BASE_DIR / args.api
         os.makedirs(build_dir, exist_ok=True)
 
+        # get these names straight. too confusing.
         provider_deps = parse_dependencies(args.dependencies or "", self.parser)
         dependencies = get_dependencies(providers[args.provider], provider_deps)
 
@@ -167,7 +159,7 @@ class ApiBuild(Subcommand):
             api.value: {
                 "provider_id": args.provider,
             },
-            **{k: {"provider_id": v} for k, v in provider_deps.items()},
+            **provider_deps,
         }
         with open(package_file, "w") as f:
             c = PackageConfig(
diff --git a/llama_toolchain/cli/api/configure.py b/llama_toolchain/cli/api/configure.py
index a3582f02e..3fb00383d 100644
--- a/llama_toolchain/cli/api/configure.py
+++ b/llama_toolchain/cli/api/configure.py
@@ -48,7 +48,10 @@ class ApiConfigure(Subcommand):
         )
 
     def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
-        config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
+        name = args.name
+        if not name.endswith(".yaml"):
+            name += ".yaml"
+        config_file = BUILDS_BASE_DIR / args.api / name
         if not config_file.exists():
             self.parser.error(
                 f"Could not find {config_file}. Please run `llama api build` first"
@@ -79,10 +82,19 @@ def configure_llama_provider(config_file: Path) -> None:
     )
 
     provider_spec = providers[provider_id]
-    cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
+    cprint(
+        f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
+    )
     config_type = instantiate_class_type(provider_spec.config_class)
+
+    try:
+        existing_provider_config = config_type(**stub_config)
+    except KeyError:
+        existing_provider_config = None
+
     provider_config = prompt_for_config(
         config_type,
+        existing_provider_config,
     )
 
     print("")
diff --git a/llama_toolchain/cli/api/start.py b/llama_toolchain/cli/api/start.py
index e10bafe5e..59b4d7dd4 100644
--- a/llama_toolchain/cli/api/start.py
+++ b/llama_toolchain/cli/api/start.py
@@ -29,10 +29,9 @@ class ApiStart(Subcommand):
 
     def _add_arguments(self):
         self.parser.add_argument(
-            "--yaml-config",
+            "yaml_config",
             type=str,
             help="Yaml config containing the API build configuration",
-            required=True,
         )
         self.parser.add_argument(
             "--port",
diff --git a/llama_toolchain/distribution/build_conda_env.sh b/llama_toolchain/distribution/build_conda_env.sh
index ecdeaba1b..e5f0e4a65 100755
--- a/llama_toolchain/distribution/build_conda_env.sh
+++ b/llama_toolchain/distribution/build_conda_env.sh
@@ -69,7 +69,7 @@ ensure_conda_env_python310() {
     conda create -n "${env_name}" python="${python_version}" -y
 
     ENVNAME="${env_name}"
-    setup_cleanup_handlers
+    # setup_cleanup_handlers
   fi
 
   eval "$(conda shell.bash hook)"
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index 51f582432..28d873bb5 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -36,16 +36,14 @@ class ProviderSpec(BaseModel):
         ...,
         description="Fully-qualified classname of the config for this provider",
     )
+    api_dependencies: List[Api] = Field(
+        default_factory=list,
+        description="Higher-level API surfaces may depend on other providers to provide their functionality",
+    )
 
 
 @json_schema_type
 class AdapterSpec(BaseModel):
-    """
-    If some code is needed to convert the remote responses into Llama Stack compatible
-    API responses, specify the adapter here. If not specified, it indicates the remote
-    as being "Llama Stack compatible"
-    """
-
     adapter_id: str = Field(
         ...,
         description="Unique identifier for this adapter",
@@ -89,11 +87,6 @@ Fully-qualified name of the module to import. The module is expected to have:
 - `get_provider_impl(config, deps)`: returns the local implementation
 """,
     )
-    api_dependencies: List[Api] = Field(
-        default_factory=list,
-        description="Higher-level API surfaces may depend on other providers to provide their functionality",
-    )
-    is_adapter: bool = False
 
 
 class RemoteProviderConfig(BaseModel):
@@ -113,34 +106,41 @@ def remote_provider_id(adapter_id: str) -> str:
 
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    provider_id: str = "remote"
-    config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
+    adapter: Optional[AdapterSpec] = Field(
+        default=None,
+        description="""
+If some code is needed to convert the remote responses into Llama Stack compatible
+API responses, specify the adapter here. If not specified, it indicates the remote
+as being "Llama Stack compatible"
+""",
+    )
 
     @property
     def module(self) -> str:
+        if self.adapter:
+            return self.adapter.module
         return f"llama_toolchain.{self.api.value}.client"
 
-
-def remote_provider_spec(api: Api) -> RemoteProviderSpec:
-    return RemoteProviderSpec(api=api)
+    @property
+    def pip_packages(self) -> List[str]:
+        if self.adapter:
+            return self.adapter.pip_packages
+        return []
 
 
-# TODO: use computed_field to avoid this wrapper
-# the @computed_field decorator
-def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
+# Can avoid this by using Pydantic computed_field
+def remote_provider_spec(
+    api: Api, adapter: Optional[AdapterSpec] = None
+) -> RemoteProviderSpec:
     config_class = (
         adapter.config_class
-        if adapter.config_class
+        if adapter and adapter.config_class
        else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
     )
+    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
 
-    return InlineProviderSpec(
-        api=api,
-        provider_id=remote_provider_id(adapter.adapter_id),
-        pip_packages=adapter.pip_packages,
-        module=adapter.module,
-        config_class=config_class,
-        is_adapter=True,
+    return RemoteProviderSpec(
+        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
     )
diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py
index c37494144..4c50189c0 100644
--- a/llama_toolchain/distribution/distribution.py
+++ b/llama_toolchain/distribution/distribution.py
@@ -22,6 +22,7 @@ from .datatypes import (
     DistributionSpec,
     InlineProviderSpec,
     ProviderSpec,
+    remote_provider_spec,
 )
 
 # These are the dependencies needed by the distribution server.
@@ -89,9 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
         a.provider_id: a for a in available_agentic_system_providers()
     }
 
-    return {
+    ret = {
         Api.inference: inference_providers_by_id,
         Api.safety: safety_providers_by_id,
         Api.agentic_system: agentic_system_providers_by_id,
         Api.memory: {a.provider_id: a for a in available_memory_providers()},
     }
+    for k, v in ret.items():
+        v["remote"] = remote_provider_spec(k)
+    return ret
diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py
index 3135e2aff..a73c03592 100644
--- a/llama_toolchain/distribution/dynamic.py
+++ b/llama_toolchain/distribution/dynamic.py
@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict
 
-from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
+from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
 
 
 def instantiate_class_type(fully_qualified_name):
@@ -26,16 +26,21 @@ def instantiate_provider(
     module = importlib.import_module(provider_spec.module)
 
     config_type = instantiate_class_type(provider_spec.config_class)
-    if isinstance(provider_spec, InlineProviderSpec):
-        if provider_spec.is_adapter:
+    if isinstance(provider_spec, RemoteProviderSpec):
+        if provider_spec.adapter:
             if not issubclass(config_type, RemoteProviderConfig):
                 raise ValueError(
                     f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
                 )
-    config = config_type(**provider_config)
-
-    if isinstance(provider_spec, InlineProviderSpec):
-        args = [config, deps]
+            method = "get_adapter_impl"
+        else:
+            method = "get_client_impl"
     else:
-        args = [config]
-    return asyncio.run(module.get_provider_impl(*args))
+        method = "get_provider_impl"
+
+    config = config_type(**provider_config)
+    fn = getattr(module, method)
+    impl = asyncio.run(fn(config, deps))
+    impl.__provider_spec__ = provider_spec
+    impl.__provider_config__ = config
+    return impl
diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py
index efd761e0e..cf290b951 100644
--- a/llama_toolchain/distribution/server.py
+++ b/llama_toolchain/distribution/server.py
@@ -38,7 +38,7 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated
 
-from .datatypes import Api, ProviderSpec, RemoteProviderSpec
+from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from .distribution import api_endpoints, api_providers
 from .dynamic import instantiate_provider
 
@@ -230,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
     def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
         visited.add(a.api)
 
-        if not isinstance(a, RemoteProviderSpec):
-            for api in a.api_dependencies:
-                if api not in visited:
-                    dfs(by_id[api], visited, stack)
+        for api in a.api_dependencies:
+            if api not in visited:
+                dfs(by_id[api], visited, stack)
 
         stack.append(a.api)
 
@@ -261,7 +260,10 @@ def resolve_impls(
                 f"Could not find provider_spec config for {api}. Please add it to the config"
             )
 
-        deps = {api: impls[api] for api in provider_spec.api_dependencies}
+        if isinstance(provider_spec, InlineProviderSpec):
+            deps = {api: impls[api] for api in provider_spec.api_dependencies}
+        else:
+            deps = {}
         provider_config = provider_configs[api.value]
         impl = instantiate_provider(provider_spec, provider_config, deps)
         impls[api] = impl
@@ -302,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
             and provider_spec.adapter is None
         ):
             for endpoint in endpoints:
-                url = impl.base_url + endpoint.route
+                url = impl.__provider_config__.url
                 getattr(app, endpoint.method)(endpoint.route)(
                     create_dynamic_passthrough(url)
                 )
diff --git a/llama_toolchain/inference/adapters/ollama/__init__.py b/llama_toolchain/inference/adapters/ollama/__init__.py
index 14bf677cc..aa4f576d3 100644
--- a/llama_toolchain/inference/adapters/ollama/__init__.py
+++ b/llama_toolchain/inference/adapters/ollama/__init__.py
@@ -4,4 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .ollama import get_provider_impl  # noqa
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
+
+async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+    from .ollama import OllamaInferenceAdapter
+
+    impl = OllamaInferenceAdapter(config.url)
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/inference/adapters/ollama/ollama.py b/llama_toolchain/inference/adapters/ollama/ollama.py
index 30decd2cd..375257ea9 100644
--- a/llama_toolchain/inference/adapters/ollama/ollama.py
+++ b/llama_toolchain/inference/adapters/ollama/ollama.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator
+from typing import AsyncGenerator
 
 import httpx
 
@@ -14,34 +14,18 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from ollama import AsyncClient
 
-from llama_toolchain.distribution.datatypes import RemoteProviderConfig
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 from llama_toolchain.inference.prepare_messages import prepare_messages
 
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
+    # "Meta-Llama3.1-8B-Instruct": "llama3.1",
     "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
 }
 
 
-async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
-    impl = OllamaInferenceAdapter(config.url)
-    await impl.initialize()
-    return impl
-
-
 class OllamaInferenceAdapter(Inference):
     def __init__(self, url: str) -> None:
         self.url = url
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 17bd07406..0ae049bbc 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -6,7 +6,7 @@
 
 import asyncio
 import json
-from typing import AsyncGenerator
+from typing import Any, AsyncGenerator
 
 import fire
 import httpx
@@ -26,7 +26,7 @@ from .api import (
 from .event_logger import EventLogger
 
 
-async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
     return InferenceClient(config.url)
 
 
diff --git a/llama_toolchain/inference/meta_reference/__init__.py b/llama_toolchain/inference/meta_reference/__init__.py
index 87a08816e..64d315e79 100644
--- a/llama_toolchain/inference/meta_reference/__init__.py
+++ b/llama_toolchain/inference/meta_reference/__init__.py
@@ -5,4 +5,15 @@
 # the root directory of this source tree.
 
 from .config import MetaReferenceImplConfig  # noqa
-from .inference import get_provider_impl  # noqa
+
+
+async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+    from .inference import MetaReferenceInferenceImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceInferenceImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py
index 72cb105ff..187d5baae 100644
--- a/llama_toolchain/inference/meta_reference/inference.py
+++ b/llama_toolchain/inference/meta_reference/inference.py
@@ -6,12 +6,11 @@
 
 import asyncio
 
-from typing import AsyncIterator, Dict, Union
+from typing import AsyncIterator, Union
 
 from llama_models.llama3.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,18 +26,6 @@ from .config import MetaReferenceImplConfig
 from .model_parallel import LlamaModelParallelGenerator
 
 
-async def get_provider_impl(
-    config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceInferenceImpl(config)
-    await impl.initialize()
-    return impl
-
-
 # there's a single model parallel process running serving the model. for now,
 # we don't support multiple concurrent requests to this process.
 SEMAPHORE = asyncio.Semaphore(1)
diff --git a/llama_toolchain/inference/providers.py b/llama_toolchain/inference/providers.py
index c9882bf98..ebf60cb5b 100644
--- a/llama_toolchain/inference/providers.py
+++ b/llama_toolchain/inference/providers.py
@@ -27,7 +27,7 @@ def available_inference_providers() -> List[ProviderSpec]:
             module="llama_toolchain.inference.meta_reference",
             config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
         ),
-        adapter_provider_spec(
+        remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
                 adapter_id="ollama",
diff --git a/llama_toolchain/memory/client.py b/llama_toolchain/memory/client.py
index abf6d2910..ecad9e46a 100644
--- a/llama_toolchain/memory/client.py
+++ b/llama_toolchain/memory/client.py
@@ -6,7 +6,7 @@
 
 import asyncio
 
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import fire
 import httpx
@@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 
 from .api import *  # noqa: F403
 
 
-async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
     return MemoryClient(config.url)
 
 
diff --git a/llama_toolchain/memory/meta_reference/faiss/__init__.py b/llama_toolchain/memory/meta_reference/faiss/__init__.py
index 69a1a06b7..16c383be3 100644
--- a/llama_toolchain/memory/meta_reference/faiss/__init__.py
+++ b/llama_toolchain/memory/meta_reference/faiss/__init__.py
@@ -4,5 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import FaissImplConfig  # noqa
-from .faiss import get_provider_impl  # noqa
+from .config import FaissImplConfig
+
+
+async def get_provider_impl(config: FaissImplConfig, _deps):
+    from .faiss import FaissMemoryImpl
+
+    assert isinstance(
+        config, FaissImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = FaissMemoryImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py
index 6a168d330..422674939 100644
--- a/llama_toolchain/memory/meta_reference/faiss/faiss.py
+++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py
@@ -15,21 +15,10 @@ import numpy as np
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.memory.api import *  # noqa: F403
 
 from .config import FaissImplConfig
 
 
-async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
-    assert isinstance(
-        config, FaissImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = FaissMemoryImpl(config)
-    await impl.initialize()
-    return impl
-
-
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
         async with httpx.AsyncClient() as client:
diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py
index 79a84eb3a..c05f59163 100644
--- a/llama_toolchain/safety/client.py
+++ b/llama_toolchain/safety/client.py
@@ -6,11 +6,12 @@
 
 import asyncio
 
+from typing import Any
+
 import fire
 import httpx
 
 from llama_models.llama3.api.datatypes import UserMessage
-
 from pydantic import BaseModel
 from termcolor import cprint
 
@@ -19,7 +20,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 
 from .api import *  # noqa: F403
 
 
-async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
+async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
     return SafetyClient(config.url)
 
 
diff --git a/llama_toolchain/safety/meta_reference/__init__.py b/llama_toolchain/safety/meta_reference/__init__.py
index f874f3dad..ad175ce46 100644
--- a/llama_toolchain/safety/meta_reference/__init__.py
+++ b/llama_toolchain/safety/meta_reference/__init__.py
@@ -4,5 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import SafetyConfig  # noqa
-from .safety import get_provider_impl  # noqa
+from .config import SafetyConfig
+
+
+async def get_provider_impl(config: SafetyConfig, _deps):
+    from .safety import MetaReferenceSafetyImpl
+
+    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceSafetyImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py
index 8f63b14f2..e71ac09a2 100644
--- a/llama_toolchain/safety/meta_reference/safety.py
+++ b/llama_toolchain/safety/meta_reference/safety.py
@@ -5,12 +5,10 @@
 # the root directory of this source tree.
 
 import asyncio
-from typing import Dict
 
 from llama_models.sku_list import resolve_model
 
 from llama_toolchain.common.model_utils import model_local_dir
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.safety.api import *  # noqa
 
 from .config import SafetyConfig
@@ -25,14 +23,6 @@ from .shields import (
 )
 
 
-async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
-    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceSafetyImpl(config)
-    await impl.initialize()
-    return impl
-
-
 def resolve_and_get_path(model_name: str) -> str:
     model = resolve_model(model_name)
    assert model is not None, f"Could not resolve model {model_name}"
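
Note: the unifying convention this diff establishes is that every provider package exposes an async factory at its top level — `get_provider_impl(config, deps)` for inline providers, `get_adapter_impl` for remote providers with an adapter, and `get_client_impl` for plain remote clients — with `instantiate_provider` dispatching on the spec type. A minimal sketch of a new inline provider package under this convention; the `foo` package, `FooImplConfig`, and `FooImpl` names are hypothetical illustrations, not part of this diff:

    # foo/config.py  (hypothetical provider package)
    from pydantic import BaseModel

    class FooImplConfig(BaseModel):
        model: str = "Meta-Llama3.1-8B-Instruct"

    # foo/__init__.py  (hypothetical)
    from typing import Any, Dict

    from .config import FooImplConfig  # noqa

    async def get_provider_impl(config: FooImplConfig, _deps: Dict[Any, Any]):
        # Import the implementation lazily so that loading the package just to
        # read its config class (e.g. during `llama api build`) does not pull
        # in heavy dependencies -- the same pattern the meta_reference
        # providers above use.
        from .foo import FooImpl

        assert isinstance(
            config, FooImplConfig
        ), f"Unexpected config type: {type(config)}"

        impl = FooImpl(config)
        await impl.initialize()
        return impl

The returned impl is stamped with `__provider_spec__` and `__provider_config__` by `instantiate_provider`, which is what `server.main` later reads (via `impl.__provider_config__.url`) when wiring up passthrough routes for remote providers.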