Several smaller fixes to make adapters work

Also, reorganized the `__init__` pattern inside providers so that
configuration can stay lightweight.
Ashwin Bharambe 2024-08-28 09:42:08 -07:00
parent 2a1552a5eb
commit 45987996c4
23 changed files with 164 additions and 160 deletions
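The reorganized provider `__init__` pattern keeps config classes importable without dragging in implementation dependencies: each provider package exposes an async factory whose heavy imports happen inside the function body. A minimal sketch of the shape, taken from the safety provider below (inference, memory, and agentic_system follow the same pattern):

    from .config import SafetyConfig


    async def get_provider_impl(config: SafetyConfig, _deps):
        # deferred import: only paid when the provider is actually instantiated
        from .safety import MetaReferenceSafetyImpl

        assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

        impl = MetaReferenceSafetyImpl(config)
        await impl.initialize()
        return impl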

View file

@@ -194,11 +194,6 @@ async def run_rag(host: str, port: int):
         MemoryToolDefinition(
             max_tokens_in_context=2048,
             memory_bank_configs=[],
-            # memory_bank_configs=[
-            #     AgenticSystemVectorMemoryBankConfig(
-            #         bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
-            #     )
-            # ],
         ),
     ]

@@ -210,8 +205,9 @@ async def run_rag(host: str, port: int):
     await _run_agent(api, tool_definitions, user_prompts, attachments)


-def main(host: str, port: int):
-    asyncio.run(run_rag(host, port))
+def main(host: str, port: int, rag: bool = False):
+    fn = run_rag if rag else run_main
+    asyncio.run(fn(host, port))


 if __name__ == "__main__":
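With the new flag the same example can exercise either flow. A hedged usage sketch, assuming the script hands `main` to Fire the way the clients in this repo do (the script name below is illustrative):

    import fire

    if __name__ == "__main__":
        fire.Fire(main)

    # python rag_example.py localhost 5000        -> runs run_main
    # python rag_example.py localhost 5000 --rag  -> runs run_rag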

View file

@@ -4,5 +4,27 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .agentic_system import get_provider_impl  # noqa
-from .config import MetaReferenceImplConfig  # noqa
+from typing import Dict
+
+from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+
+from .config import MetaReferenceImplConfig
+
+
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
+    from .agentic_system import MetaReferenceAgenticSystemImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceAgenticSystemImpl(
+        config,
+        deps[Api.inference],
+        deps[Api.memory],
+        deps[Api.safety],
+    )
+    await impl.initialize()
+    return impl

View file

@@ -8,9 +8,8 @@
 import logging
 import os
 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
@@ -31,23 +30,6 @@ logger = logging.getLogger()
 logger.setLevel(logging.INFO)


-async def get_provider_impl(
-    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
-):
-    assert isinstance(
-        config, MetaReferenceImplConfig
-    ), f"Unexpected config type: {type(config)}"
-
-    impl = MetaReferenceAgenticSystemImpl(
-        config,
-        deps[Api.inference],
-        deps[Api.memory],
-        deps[Api.safety],
-    )
-    await impl.initialize()
-    return impl
-
-
 AGENT_INSTANCES_BY_ID = {}

View file

@@ -10,7 +10,7 @@ import os
 from pydantic import BaseModel
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import pkg_resources
 import yaml
@@ -37,26 +37,17 @@ def get_dependencies(
 ) -> Dependencies:
     from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES

-    def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
-        if isinstance(provider, InlineProviderSpec):
-            return provider.pip_packages, provider.docker_image
-        else:
-            if provider.adapter:
-                return provider.adapter.pip_packages, None
-            return [], None
-
-    pip_packages, docker_image = _deps(provider)
+    pip_packages = provider.pip_packages
     for dep in dependencies.values():
-        dep_pip_packages, dep_docker_image = _deps(dep)
-        if docker_image and dep_docker_image:
+        if dep.docker_image:
             raise ValueError(
                 "You can only have the root provider specify a docker image"
             )
-        pip_packages.extend(dep_pip_packages)
+        pip_packages.extend(dep.pip_packages)

     return Dependencies(
-        docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
+        docker_image=provider.docker_image,
+        pip_packages=pip_packages + SERVER_DEPENDENCIES
     )
@@ -158,6 +149,7 @@ class ApiBuild(Subcommand):
         build_dir = BUILDS_BASE_DIR / args.api
         os.makedirs(build_dir, exist_ok=True)

+        # get these names straight. too confusing.
         provider_deps = parse_dependencies(args.dependencies or "", self.parser)
         dependencies = get_dependencies(providers[args.provider], provider_deps)
@@ -167,7 +159,7 @@ class ApiBuild(Subcommand):
             api.value: {
                 "provider_id": args.provider,
             },
-            **{k: {"provider_id": v} for k, v in provider_deps.items()},
+            **provider_deps,
         }
         with open(package_file, "w") as f:
             c = PackageConfig(

View file

@@ -48,7 +48,10 @@ class ApiConfigure(Subcommand):
         )

     def _run_api_configure_cmd(self, args: argparse.Namespace) -> None:
-        config_file = BUILDS_BASE_DIR / args.api / f"{args.name}.yaml"
+        name = args.name
+        if not name.endswith(".yaml"):
+            name += ".yaml"
+        config_file = BUILDS_BASE_DIR / args.api / name
         if not config_file.exists():
             self.parser.error(
                 f"Could not find {config_file}. Please run `llama api build` first"
@@ -79,10 +82,19 @@ def configure_llama_provider(config_file: Path) -> None:
         )
         provider_spec = providers[provider_id]

-        cprint(f"Configuring API surface: {api}", "white", attrs=["bold"])
+        cprint(
+            f"Configuring API surface: {api} ({provider_id})", "white", attrs=["bold"]
+        )
+
         config_type = instantiate_class_type(provider_spec.config_class)
+        try:
+            existing_provider_config = config_type(**stub_config)
+        except KeyError:
+            existing_provider_config = None
+
         provider_config = prompt_for_config(
             config_type,
+            existing_provider_config,
         )
         print("")

View file

@@ -29,10 +29,9 @@ class ApiStart(Subcommand):
     def _add_arguments(self):
         self.parser.add_argument(
-            "--yaml-config",
+            "yaml_config",
             type=str,
             help="Yaml config containing the API build configuration",
-            required=True,
         )
         self.parser.add_argument(
             "--port",

View file

@@ -69,7 +69,7 @@ ensure_conda_env_python310() {
     conda create -n "${env_name}" python="${python_version}" -y

     ENVNAME="${env_name}"
-    setup_cleanup_handlers
+    # setup_cleanup_handlers
   fi

   eval "$(conda shell.bash hook)"

View file

@@ -36,16 +36,14 @@ class ProviderSpec(BaseModel):
         ...,
         description="Fully-qualified classname of the config for this provider",
     )
+    api_dependencies: List[Api] = Field(
+        default_factory=list,
+        description="Higher-level API surfaces may depend on other providers to provide their functionality",
+    )


 @json_schema_type
 class AdapterSpec(BaseModel):
-    """
-    If some code is needed to convert the remote responses into Llama Stack compatible
-    API responses, specify the adapter here. If not specified, it indicates the remote
-    as being "Llama Stack compatible"
-    """
-
     adapter_id: str = Field(
         ...,
         description="Unique identifier for this adapter",
@@ -89,11 +87,6 @@ Fully-qualified name of the module to import. The module is expected to have:
  - `get_provider_impl(config, deps)`: returns the local implementation
 """,
     )
-    api_dependencies: List[Api] = Field(
-        default_factory=list,
-        description="Higher-level API surfaces may depend on other providers to provide their functionality",
-    )
-    is_adapter: bool = False


 class RemoteProviderConfig(BaseModel):
@@ -113,34 +106,41 @@ def remote_provider_id(adapter_id: str) -> str:

 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    provider_id: str = "remote"
-    config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
+    adapter: Optional[AdapterSpec] = Field(
+        default=None,
+        description="""
+If some code is needed to convert the remote responses into Llama Stack compatible
+API responses, specify the adapter here. If not specified, it indicates the remote
+as being "Llama Stack compatible"
+""",
+    )

     @property
     def module(self) -> str:
+        if self.adapter:
+            return self.adapter.module
         return f"llama_toolchain.{self.api.value}.client"

+    @property
+    def pip_packages(self) -> List[str]:
+        if self.adapter:
+            return self.adapter.pip_packages
+        return []

-def remote_provider_spec(api: Api) -> RemoteProviderSpec:
-    return RemoteProviderSpec(api=api)
-
-
-# TODO: use computed_field to avoid this wrapper
-# the @computed_field decorator
-def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
+
+# Can avoid this by using Pydantic computed_field
+def remote_provider_spec(
+    api: Api, adapter: Optional[AdapterSpec] = None
+) -> RemoteProviderSpec:
     config_class = (
         adapter.config_class
-        if adapter.config_class
+        if adapter and adapter.config_class
         else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
     )
+    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"

-    return InlineProviderSpec(
-        api=api,
-        provider_id=remote_provider_id(adapter.adapter_id),
-        pip_packages=adapter.pip_packages,
-        module=adapter.module,
-        config_class=config_class,
-        is_adapter=True,
+    return RemoteProviderSpec(
+        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
     )
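The net effect is that remote_provider_spec now covers both remote flavors. A usage sketch against the new datatypes above (the AdapterSpec field values echo the ollama registration later in this commit and are otherwise illustrative):

    # plain "Llama Stack compatible" remote: stock client module, no extra deps
    spec = remote_provider_spec(Api.inference)
    assert spec.provider_id == "remote"
    assert spec.module == "llama_toolchain.inference.client"
    assert spec.pip_packages == []

    # adapter-backed remote: module and pip_packages come from the adapter
    spec = remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_id="ollama",
            pip_packages=["ollama"],
            module="llama_toolchain.inference.adapters.ollama",
        ),
    )
    assert spec.provider_id == remote_provider_id("ollama")
    assert spec.module == spec.adapter.module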

View file

@ -22,6 +22,7 @@ from .datatypes import (
DistributionSpec, DistributionSpec,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec,
) )
# These are the dependencies needed by the distribution server. # These are the dependencies needed by the distribution server.
@ -89,9 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
a.provider_id: a for a in available_agentic_system_providers() a.provider_id: a for a in available_agentic_system_providers()
} }
return { ret = {
Api.inference: inference_providers_by_id, Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id, Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id, Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_id: a for a in available_memory_providers()}, Api.memory: {a.provider_id: a for a in available_memory_providers()},
} }
for k, v in ret.items():
v["remote"] = remote_provider_spec(k)
return ret
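Every API surface therefore gains a "remote" provider automatically. A small sketch of what callers can now rely on:

    providers = api_providers()
    remote = providers[Api.inference]["remote"]
    assert isinstance(remote, RemoteProviderSpec)
    assert remote.adapter is None  # plain passthrough to the stock client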

View file

@ -8,7 +8,7 @@ import asyncio
import importlib import importlib
from typing import Any, Dict from typing import Any, Dict
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name): def instantiate_class_type(fully_qualified_name):
@ -26,16 +26,21 @@ def instantiate_provider(
module = importlib.import_module(provider_spec.module) module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
if isinstance(provider_spec, InlineProviderSpec): if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.is_adapter: if provider_spec.adapter:
if not issubclass(config_type, RemoteProviderConfig): if not issubclass(config_type, RemoteProviderConfig):
raise ValueError( raise ValueError(
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig" f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
) )
config = config_type(**provider_config) method = "get_adapter_impl"
else:
if isinstance(provider_spec, InlineProviderSpec): method = "get_client_impl"
args = [config, deps]
else: else:
args = [config] method = "get_provider_impl"
return asyncio.run(module.get_provider_impl(*args))
config = config_type(**provider_config)
fn = getattr(module, method)
impl = asyncio.run(fn(config, deps))
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl
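This dispatch pins down the contract a provider module must satisfy: exactly one async factory, selected by the spec type. A summary sketch (signatures as they appear in the modules touched by this commit):

    # inline provider (e.g. safety.meta_reference):
    #     async def get_provider_impl(config, deps): ...
    # remote with adapter (e.g. inference.adapters.ollama):
    #     async def get_adapter_impl(config, _deps): ...
    # plain remote (e.g. inference.client):
    #     async def get_client_impl(config, _deps): ...
    impl = instantiate_provider(provider_spec, provider_config, deps)
    impl.__provider_spec__    # spec the impl was built from
    impl.__provider_config__  # parsed config; server.py reads .url off this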

View file

@ -38,7 +38,7 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from typing_extensions import Annotated from typing_extensions import Annotated
from .datatypes import Api, ProviderSpec, RemoteProviderSpec from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider from .dynamic import instantiate_provider
@ -230,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]): def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api) visited.add(a.api)
if not isinstance(a, RemoteProviderSpec): for api in a.api_dependencies:
for api in a.api_dependencies: if api not in visited:
if api not in visited: dfs(by_id[api], visited, stack)
dfs(by_id[api], visited, stack)
stack.append(a.api) stack.append(a.api)
@ -261,7 +260,10 @@ def resolve_impls(
f"Could not find provider_spec config for {api}. Please add it to the config" f"Could not find provider_spec config for {api}. Please add it to the config"
) )
deps = {api: impls[api] for api in provider_spec.api_dependencies} if isinstance(provider_spec, InlineProviderSpec):
deps = {api: impls[api] for api in provider_spec.api_dependencies}
else:
deps = {}
provider_config = provider_configs[api.value] provider_config = provider_configs[api.value]
impl = instantiate_provider(provider_spec, provider_config, deps) impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl impls[api] = impl
@ -302,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
and provider_spec.adapter is None and provider_spec.adapter is None
): ):
for endpoint in endpoints: for endpoint in endpoints:
url = impl.base_url + endpoint.route url = impl.__provider_config__.url
getattr(app, endpoint.method)(endpoint.route)( getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url) create_dynamic_passthrough(url)
) )

View file

@@ -4,4 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .ollama import get_provider_impl  # noqa
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
+
+async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+    from .ollama import OllamaInferenceAdapter
+
+    impl = OllamaInferenceAdapter(config.url)
+    await impl.initialize()
+    return impl

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator from typing import AsyncGenerator
import httpx import httpx
@ -14,34 +14,18 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from ollama import AsyncClient from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import RemoteProviderConfig from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.inference.api import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_toolchain.inference.prepare_messages import prepare_messages from llama_toolchain.inference.prepare_messages import prepare_messages
# TODO: Eventually this will move to the llama cli model list command # TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models # mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = { OLLAMA_SUPPORTED_SKUS = {
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
} }
async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
impl = OllamaInferenceAdapter(config.url)
await impl.initialize()
return impl
class OllamaInferenceAdapter(Inference): class OllamaInferenceAdapter(Inference):
def __init__(self, url: str) -> None: def __init__(self, url: str) -> None:
self.url = url self.url = url

View file

@ -6,7 +6,7 @@
import asyncio import asyncio
import json import json
from typing import AsyncGenerator from typing import Any, AsyncGenerator
import fire import fire
import httpx import httpx
@ -26,7 +26,7 @@ from .api import (
from .event_logger import EventLogger from .event_logger import EventLogger
async def get_provider_impl(config: RemoteProviderConfig) -> Inference: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
return InferenceClient(config.url) return InferenceClient(config.url)

View file

@@ -5,4 +5,15 @@
 # the root directory of this source tree.

 from .config import MetaReferenceImplConfig  # noqa
-from .inference import get_provider_impl  # noqa
+
+
+async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
+    from .inference import MetaReferenceInferenceImpl
+
+    assert isinstance(
+        config, MetaReferenceImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceInferenceImpl(config)
+    await impl.initialize()
+    return impl

View file

@ -6,12 +6,11 @@
import asyncio import asyncio
from typing import AsyncIterator, Dict, Union from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import ( from llama_toolchain.inference.api import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
@ -27,18 +26,6 @@ from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
async def get_provider_impl(
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceInferenceImpl(config)
await impl.initialize()
return impl
# there's a single model parallel process running serving the model. for now, # there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process. # we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)

View file

@@ -27,7 +27,7 @@ def available_inference_providers() -> List[ProviderSpec]:
         module="llama_toolchain.inference.meta_reference",
         config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
     ),
-    adapter_provider_spec(
+    remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_id="ollama",
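For orientation, the full ollama registration presumably pairs remote_provider_spec with an AdapterSpec along these lines (pip_packages and module are assumptions based on the adapter package added elsewhere in this commit):

    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_id="ollama",
            pip_packages=["ollama"],  # assumed: the AsyncClient dependency
            module="llama_toolchain.inference.adapters.ollama",  # assumed path
        ),
    ),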

View file

@ -6,7 +6,7 @@
import asyncio import asyncio
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .api import * # noqa: F403
async def get_provider_impl(config: RemoteProviderConfig) -> Memory: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Memory:
return MemoryClient(config.url) return MemoryClient(config.url)

View file

@@ -4,5 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import FaissImplConfig  # noqa
-from .faiss import get_provider_impl  # noqa
+from .config import FaissImplConfig
+
+
+async def get_provider_impl(config: FaissImplConfig, _deps):
+    from .faiss import FaissMemoryImpl
+
+    assert isinstance(
+        config, FaissImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = FaissMemoryImpl(config)
+    await impl.initialize()
+    return impl

View file

@ -15,21 +15,10 @@ import numpy as np
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.memory.api import * # noqa: F403 from llama_toolchain.memory.api import * # noqa: F403
from .config import FaissImplConfig from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps: Dict[Api, ProviderSpec]):
assert isinstance(
config, FaissImplConfig
), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config)
await impl.initialize()
return impl
async def content_from_doc(doc: MemoryBankDocument) -> str: async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL): if isinstance(doc.content, URL):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:

View file

@ -6,11 +6,12 @@
import asyncio import asyncio
from typing import Any
import fire import fire
import httpx import httpx
from llama_models.llama3.api.datatypes import UserMessage from llama_models.llama3.api.datatypes import UserMessage
from pydantic import BaseModel from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
@ -19,7 +20,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from .api import * # noqa: F403 from .api import * # noqa: F403
async def get_provider_impl(config: RemoteProviderConfig) -> Safety: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
return SafetyClient(config.url) return SafetyClient(config.url)

View file

@@ -4,5 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import SafetyConfig  # noqa
-from .safety import get_provider_impl  # noqa
+from .config import SafetyConfig
+
+
+async def get_provider_impl(config: SafetyConfig, _deps):
+    from .safety import MetaReferenceSafetyImpl
+
+    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
+
+    impl = MetaReferenceSafetyImpl(config)
+    await impl.initialize()
+    return impl

View file

@ -5,12 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
from typing import Dict
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_toolchain.common.model_utils import model_local_dir from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.safety.api import * # noqa from llama_toolchain.safety.api import * # noqa
from .config import SafetyConfig from .config import SafetyConfig
@ -25,14 +23,6 @@ from .shields import (
) )
async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)
await impl.initialize()
return impl
def resolve_and_get_path(model_name: str) -> str: def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name) model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}" assert model is not None, f"Could not resolve model {model_name}"