diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py
index a73c03592..adb9b5dac 100644
--- a/llama_toolchain/distribution/dynamic.py
+++ b/llama_toolchain/distribution/dynamic.py
@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict
 
-from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
+from .datatypes import ProviderSpec, RemoteProviderSpec
 
 
 def instantiate_class_type(fully_qualified_name):
@@ -28,10 +28,6 @@ def instantiate_provider(
     config_type = instantiate_class_type(provider_spec.config_class)
     if isinstance(provider_spec, RemoteProviderSpec):
         if provider_spec.adapter:
-            if not issubclass(config_type, RemoteProviderConfig):
-                raise ValueError(
-                    f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
-                )
             method = "get_adapter_impl"
         else:
             method = "get_client_impl"
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index 9acc4a18c..e134fdab6 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -39,21 +39,23 @@ def available_distribution_specs() -> List[DistributionSpec]:
         },
     ),
     DistributionSpec(
-        spec_id="remote-fireworks",
+        distribution_id="local-plus-fireworks-inference",
         description="Use Fireworks.ai for running LLM inference",
-        provider_specs={
-            Api.inference: providers[Api.inference]["fireworks"],
-            Api.safety: providers[Api.safety]["meta-reference"],
-            Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+        providers={
+            Api.inference: remote_provider_id("fireworks"),
+            Api.safety: "meta-reference",
+            Api.agentic_system: "meta-reference",
+            Api.memory: "meta-reference-faiss",
         },
     ),
     DistributionSpec(
-        spec_id="remote-together",
+        distribution_id="local-plus-together-inference",
         description="Use Together.ai for running LLM inference",
-        provider_specs={
-            Api.inference: providers[Api.inference]["together"],
-            Api.safety: providers[Api.safety]["meta-reference"],
-            Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+        providers={
+            Api.inference: remote_provider_id("together"),
+            Api.safety: "meta-reference",
+            Api.agentic_system: "meta-reference",
+            Api.memory: "meta-reference-faiss",
        },
     ),
 ]
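Note: with the RemoteProviderConfig subclass check removed, instantiate_provider now only selects a method name here ("get_adapter_impl" vs. "get_client_impl"); the adapter path no longer constrains the config class. A rough sketch of the dispatch that follows this hunk (the lookup itself is outside the lines shown; the `module` field and the call signature are assumptions):

    module = importlib.import_module(provider_spec.module)  # provider module path, assumed field
    fn = getattr(module, method)                            # resolved entry point
    impl = await fn(config_type(**provider_config), deps)   # argument shapes assumed
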
diff --git a/llama_toolchain/inference/adapters/fireworks/__init__.py b/llama_toolchain/inference/adapters/fireworks/__init__.py
new file mode 100644
index 000000000..6de34833f
--- /dev/null
+++ b/llama_toolchain/inference/adapters/fireworks/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_toolchain.inference.api import Inference
+
+from .config import FireworksImplConfig
+
+
+async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
+    # Import lazily so the fireworks SDK is only required when this adapter is used.
+    from .fireworks import FireworksInferenceAdapter
+
+    assert isinstance(
+        config, FireworksImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = FireworksInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/inference/fireworks/config.py b/llama_toolchain/inference/adapters/fireworks/config.py
similarity index 100%
rename from llama_toolchain/inference/fireworks/config.py
rename to llama_toolchain/inference/adapters/fireworks/config.py
diff --git a/llama_toolchain/inference/fireworks/fireworks.py b/llama_toolchain/inference/adapters/fireworks/fireworks.py
similarity index 93%
rename from llama_toolchain/inference/fireworks/fireworks.py
rename to llama_toolchain/inference/adapters/fireworks/fireworks.py
index 2e08cc042..c9d6e38fd 100644
--- a/llama_toolchain/inference/fireworks/fireworks.py
+++ b/llama_toolchain/inference/adapters/fireworks/fireworks.py
@@ -5,9 +5,9 @@
 # the root directory of this source tree.
 
 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
 
-import httpx
+from fireworks.client import Fireworks
 
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,20 +18,8 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
-from fireworks.client import Fireworks
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 
 from .config import FireworksImplConfig
 
@@ -42,18 +30,7 @@ FIREWORKS_SUPPORTED_MODELS = {
 }
 
 
-async def get_provider_impl(
-    config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, FireworksImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = FireworksInference(config)
-    await impl.initialize()
-    return impl
-
-
-class FireworksInference(Inference):
+class FireworksInferenceAdapter(Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
         self.config = config
 
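Note: both adapters follow the same entry-point convention: a thin __init__.py exposes get_adapter_impl and defers importing the vendor SDK into the function body, so loading the package does not require the fireworks client. A hypothetical caller-side sketch (the FireworksImplConfig fields are not visible in this diff; api_key is a guess):

    from llama_toolchain.inference.adapters.fireworks import get_adapter_impl
    from llama_toolchain.inference.adapters.fireworks.config import FireworksImplConfig

    impl = await get_adapter_impl(FireworksImplConfig(api_key="..."), {})
    # impl is a FireworksInferenceAdapter, already initialize()d
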
diff --git a/llama_toolchain/inference/adapters/together/__init__.py b/llama_toolchain/inference/adapters/together/__init__.py
new file mode 100644
index 000000000..ad8bc2ac1
--- /dev/null
+++ b/llama_toolchain/inference/adapters/together/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_toolchain.inference.api import Inference
+
+from .config import TogetherImplConfig
+
+
+async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
+    # Import lazily so the together SDK is only required when this adapter is used.
+    from .together import TogetherInferenceAdapter
+
+    assert isinstance(
+        config, TogetherImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = TogetherInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_toolchain/inference/together/config.py b/llama_toolchain/inference/adapters/together/config.py
similarity index 100%
rename from llama_toolchain/inference/together/config.py
rename to llama_toolchain/inference/adapters/together/config.py
diff --git a/llama_toolchain/inference/together/together.py b/llama_toolchain/inference/adapters/together/together.py
similarity index 93%
rename from llama_toolchain/inference/together/together.py
rename to llama_toolchain/inference/adapters/together/together.py
index e7ccf623e..b8f63df65 100644
--- a/llama_toolchain/inference/together/together.py
+++ b/llama_toolchain/inference/adapters/together/together.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
 
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,18 +18,7 @@
 from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
 from together import Together
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 
 from .config import TogetherImplConfig
@@ -40,18 +29,7 @@ TOGETHER_SUPPORTED_MODELS = {
 }
 
 
-async def get_provider_impl(
-    config: TogetherImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, TogetherImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = TogetherInference(config)
-    await impl.initialize()
-    return impl
-
-
-class TogetherInference(Inference):
+class TogetherInferenceAdapter(Inference):
     def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config
 
diff --git a/llama_toolchain/inference/fireworks/__init__.py b/llama_toolchain/inference/fireworks/__init__.py
deleted file mode 100644
index baeb758ad..000000000
--- a/llama_toolchain/inference/fireworks/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import FireworksImplConfig  # noqa
-from .fireworks import get_provider_impl  # noqa
diff --git a/llama_toolchain/inference/together/__init__.py b/llama_toolchain/inference/together/__init__.py
deleted file mode 100644
index 5be75efcc..000000000
--- a/llama_toolchain/inference/together/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import TogetherImplConfig  # noqa
-from .together import get_provider_impl  # noqa
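Note: after the rename from spec_id/provider_specs to distribution_id/providers, a distribution maps each Api to a provider id string rather than a full ProviderSpec. A minimal lookup sketch (assumes remote_provider_id is imported inside registry.py; that import hunk is not shown here):

    from llama_toolchain.distribution.datatypes import Api
    from llama_toolchain.distribution.registry import available_distribution_specs

    by_id = {d.distribution_id: d for d in available_distribution_specs()}
    dist = by_id["local-plus-fireworks-inference"]
    # dist.providers[Api.inference] == remote_provider_id("fireworks")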