Make Fireworks and Together into the Adapter format

Ashwin Bharambe 2024-08-28 16:21:07 -07:00
parent a23a6ab95b
commit f1244f6d9e
10 changed files with 56 additions and 83 deletions

View file

@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict

-from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
+from .datatypes import ProviderSpec, RemoteProviderSpec


 def instantiate_class_type(fully_qualified_name):
@@ -28,10 +28,6 @@ def instantiate_provider(
     config_type = instantiate_class_type(provider_spec.config_class)
     if isinstance(provider_spec, RemoteProviderSpec):
         if provider_spec.adapter:
-            if not issubclass(config_type, RemoteProviderConfig):
-                raise ValueError(
-                    f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
-                )
             method = "get_adapter_impl"
         else:
             method = "get_client_impl"
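
For context, a minimal sketch of how this dispatch can resolve the adapter
entrypoint at runtime. The module field on the spec and the config_dict
argument are assumptions not shown in this hunk; the method names come
straight from the diff above.

import importlib

def instantiate_class_type(fully_qualified_name):
    # split "pkg.module.ClassName" into module path and attribute name
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

async def instantiate_provider_sketch(provider_spec, config_dict, deps):
    # materialize the typed config from the raw dict via its config_class
    config_type = instantiate_class_type(provider_spec.config_class)
    config = config_type(**config_dict)
    # adapters expose get_adapter_impl; direct remote clients expose get_client_impl
    method = "get_adapter_impl" if provider_spec.adapter else "get_client_impl"
    module = importlib.import_module(provider_spec.module)  # assumed field
    return await getattr(module, method)(config, deps)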

View file

@@ -39,21 +39,23 @@ def available_distribution_specs() -> List[DistributionSpec]:
         },
     ),
     DistributionSpec(
-        spec_id="remote-fireworks",
+        distribution_id="local-plus-fireworks-inference",
         description="Use Fireworks.ai for running LLM inference",
-        provider_specs={
-            Api.inference: providers[Api.inference]["fireworks"],
-            Api.safety: providers[Api.safety]["meta-reference"],
-            Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+        providers={
+            Api.inference: remote_provider_id("fireworks"),
+            Api.safety: "meta-reference",
+            Api.agentic_system: "meta-reference",
+            Api.memory: "meta-reference-faiss",
         },
     ),
     DistributionSpec(
-        spec_id="remote-together",
+        distribution_id="local-plus-together-inference",
         description="Use Together.ai for running LLM inference",
-        provider_specs={
-            Api.inference: providers[Api.inference]["together"],
-            Api.safety: providers[Api.safety]["meta-reference"],
-            Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+        providers={
+            Api.inference: remote_provider_id("together"),
+            Api.safety: "meta-reference",
+            Api.agentic_system: "meta-reference",
+            Api.memory: "meta-reference-faiss",
        },
     ),
 ]
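
The rename is consistent across both specs: spec_id becomes distribution_id,
provider_specs becomes a flat providers map from Api to provider-id strings,
and remote_provider_id() marks the inference provider as a remote adapter.
Given the available_distribution_specs() -> List[DistributionSpec] signature
in the hunk header, a hypothetical lookup by the new identifier could be
(resolve_distribution is an illustrative helper, not part of this commit):

def resolve_distribution(distribution_id: str) -> DistributionSpec:
    # linear scan of the registry for the renamed identifier
    for spec in available_distribution_specs():
        if spec.distribution_id == distribution_id:
            return spec
    raise ValueError(f"Unknown distribution: {distribution_id}")

spec = resolve_distribution("local-plus-fireworks-inference")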

View file

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import FireworksImplConfig
+
+
+async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
+    from .fireworks import FireworksInferenceAdapter
+
+    assert isinstance(
+        config, FireworksImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = FireworksInferenceAdapter(config)
+    await impl.initialize()
+    return impl
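
This new package entrypoint defers the heavy import of the adapter module
until the factory runs. Note the -> Inference annotation needs Inference in
scope, which this hunk does not show being imported. A hedged usage sketch
(the empty-dict _deps and the config fields are assumptions; the diff does
not show what FireworksImplConfig contains):

import asyncio

async def main() -> None:
    config = FireworksImplConfig()  # real fields (e.g. an API key) not shown in this diff
    impl = await get_adapter_impl(config, _deps={})  # _deps is unused by this factory
    # impl is a FireworksInferenceAdapter that has already been initialize()d

asyncio.run(main())

The Together entrypoint added below follows the identical pattern.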

View file

@@ -5,9 +5,9 @@
 # the root directory of this source tree.

 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

-import httpx
+from fireworks.client import Fireworks

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,20 +18,8 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
-from fireworks.client import Fireworks

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403

 from .config import FireworksImplConfig
@@ -42,18 +30,7 @@ FIREWORKS_SUPPORTED_MODELS = {
 }


-async def get_provider_impl(
-    config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, FireworksImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = FireworksInference(config)
-    await impl.initialize()
-    return impl
-
-
-class FireworksInference(Inference):
+class FireworksInferenceAdapter(Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
         self.config = config
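
With get_provider_impl hoisted into the package __init__, this module now
only defines the adapter class. A minimal sketch of how the renamed adapter
might wire up the SDK client on initialize (the api_key config field and the
eager construction are assumptions; Fireworks(api_key=...) is the fireworks-ai
package's documented constructor):

class FireworksInferenceAdapter(Inference):
    def __init__(self, config: FireworksImplConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        # construct the SDK client once, up front; the config field name
        # is an assumption, not shown in this diff
        self.client = Fireworks(api_key=self.config.api_key)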

View file

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import TogetherImplConfig
+
+
+async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
+    from .together import TogetherInferenceAdapter
+
+    assert isinstance(
+        config, TogetherImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = TogetherInferenceAdapter(config)
+    await impl.initialize()
+    return impl

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,18 +18,7 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
 from together import Together

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403

 from .config import TogetherImplConfig
@@ -40,18 +29,7 @@ TOGETHER_SUPPORTED_MODELS = {
 }


-async def get_provider_impl(
-    config: TogetherImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, TogetherImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = TogetherInference(config)
-    await impl.initialize()
-    return impl
-
-
-class TogetherInference(Inference):
+class TogetherInferenceAdapter(Inference):
    def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config
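
The Together module gets the same reshaping. As a rough sketch of what the
adapter's streaming path might look like against the Together SDK (the
api_key field, the method name, and the plain-text yield are all
illustrative; the real adapter maps chunks into
ChatCompletionResponseStreamChunk events with ToolCallDelta parsing):

class TogetherInferenceAdapter(Inference):
    def __init__(self, config: TogetherImplConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        self.client = Together(api_key=self.config.api_key)  # assumed field

    async def stream_chat_sketch(self, model: str, messages) -> AsyncGenerator:
        # stream=True makes the Together SDK yield incremental chunks
        response = self.client.chat.completions.create(
            model=model, messages=messages, stream=True
        )
        for chunk in response:
            yield chunk.choices[0].delta.content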

View file

@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import FireworksImplConfig  # noqa
-from .fireworks import get_provider_impl  # noqa

View file

@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import TogetherImplConfig  # noqa
-from .together import get_provider_impl  # noqa