Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-10-04 04:04:14 +00:00
Make Fireworks and Together into the Adapter format

Commit f1244f6d9e (parent a23a6ab95b)
10 changed files with 56 additions and 83 deletions
llama_toolchain/inference/adapters/fireworks/__init__.py  (new file, +18)

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import FireworksImplConfig
+
+
+async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
+    from .fireworks import FireworksInferenceAdapter
+
+    assert isinstance(
+        config, FireworksImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = FireworksInferenceAdapter(config)
+    await impl.initialize()
+    return impl
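This new __init__.py is the adapter entry point: get_adapter_impl validates the config, lazily imports the implementation, and initializes it. (Note that Inference appears in the return annotation without an import in the file as shown.) A minimal sketch of how a caller might drive this entry point — the config fields and the empty deps argument are assumptions for illustration, not code from this commit:

    # Hypothetical caller; FireworksImplConfig's fields are assumed, not
    # taken from this commit.
    import asyncio

    from llama_toolchain.inference.adapters.fireworks import (
        FireworksImplConfig,
        get_adapter_impl,
    )


    async def main():
        # url/api_key are assumed field names for this config class.
        config = FireworksImplConfig(url="https://api.fireworks.ai/inference", api_key="...")
        impl = await get_adapter_impl(config, {})  # no cross-API deps assumed
        return impl


    impl = asyncio.run(main())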
llama_toolchain/inference/adapters/fireworks/fireworks.py

@@ -5,9 +5,9 @@
 # the root directory of this source tree.
 
 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
 
 import httpx
+from fireworks.client import Fireworks
 
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,20 +18,8 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
-from fireworks.client import Fireworks
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 
 from .config import FireworksImplConfig
@@ -42,18 +30,7 @@ FIREWORKS_SUPPORTED_MODELS = {
 }
 
 
-async def get_provider_impl(
-    config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, FireworksImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = FireworksInference(config)
-    await impl.initialize()
-    return impl
-
-
-class FireworksInference(Inference):
+class FireworksInferenceAdapter(Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
         self.config = config
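Beyond the rename from FireworksInference to FireworksInferenceAdapter, these hunks drop the per-provider get_provider_impl (it now lives in the adapter's __init__.py as get_adapter_impl), replace the long explicit import list with the api package's wildcard import, and move the Fireworks client import to the top-level block. A rough skeleton of the adapter shape the hunks imply — the initialize and chat_completion bodies below are assumptions, not the commit's code:

    # Sketch only: method bodies and the chat_completion signature are
    # assumed for illustration, not taken from this commit.
    from typing import AsyncGenerator

    from fireworks.client import Fireworks

    from llama_toolchain.inference.api import *  # noqa: F403

    from .config import FireworksImplConfig


    class FireworksInferenceAdapter(Inference):
        def __init__(self, config: FireworksImplConfig) -> None:
            self.config = config

        async def initialize(self) -> None:
            # Presumably constructs the SDK client from the validated config.
            self.client = Fireworks(api_key=self.config.api_key)

        async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
            # Assumed flow: map the requested llama model to a Fireworks model
            # id via FIREWORKS_SUPPORTED_MODELS, call the client, and yield
            # ChatCompletionResponseStreamChunk events.
            ...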
llama_toolchain/inference/adapters/together/__init__.py  (new file, +18)

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import TogetherImplConfig
+
+
+async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
+    from .together import TogetherInferenceAdapter
+
+    assert isinstance(
+        config, TogetherImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = TogetherInferenceAdapter(config)
+    await impl.initialize()
+    return impl
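Both adapters now expose an identical get_adapter_impl surface, which is what the "Adapter format" in the commit title refers to: remote providers are constructed through a uniform entry point instead of per-provider get_provider_impl functions. The registry side of the change is not among the visible hunks; the sketch below of how such an adapter might be declared uses assumed names (remote_provider_spec, AdapterSpec) and is illustrative only:

    # Assumed registry shape; none of these names appear in the visible
    # hunks, so treat every identifier here as hypothetical.
    from llama_toolchain.distribution.datatypes import (
        AdapterSpec,
        Api,
        remote_provider_spec,
    )

    fireworks_spec = remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_id="fireworks",
            pip_packages=["fireworks-ai"],
            module="llama_toolchain.inference.adapters.fireworks",
        ),
    )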
llama_toolchain/inference/adapters/together/together.py

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
 
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,18 +18,7 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
 from together import Together
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403
 
 from .config import TogetherImplConfig
@@ -40,18 +29,7 @@ TOGETHER_SUPPORTED_MODELS = {
 }
 
 
-async def get_provider_impl(
-    config: TogetherImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, TogetherImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = TogetherInference(config)
-    await impl.initialize()
-    return impl
-
-
-class TogetherInference(Inference):
+class TogetherInferenceAdapter(Inference):
     def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config
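Each adapter keeps a SUPPORTED_MODELS dict (its entries are elided by the hunk context) mapping llama model descriptors to provider-side model ids, resolved through resolve_model from llama_models.sku_list. A hedged sketch of the typical lookup — the helper and the dict contents are illustrative assumptions; only the names resolve_model and TOGETHER_SUPPORTED_MODELS come from the diff:

    # Illustrative helper; actual dict entries are not shown in the hunks.
    from llama_models.sku_list import resolve_model

    TOGETHER_SUPPORTED_MODELS = {
        # model descriptor -> Together model id (entries assumed)
    }


    def together_model_id(requested: str) -> str:
        # resolve_model returns the canonical model (or None if unknown).
        model = resolve_model(requested)
        if model is None or model.descriptor() not in TOGETHER_SUPPORTED_MODELS:
            raise ValueError(f"Model {requested} is not supported by the Together adapter")
        return TOGETHER_SUPPORTED_MODELS[model.descriptor()]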
llama_toolchain/inference/fireworks/__init__.py  (deleted, -8)

@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import FireworksImplConfig  # noqa
-from .fireworks import get_provider_impl  # noqa
llama_toolchain/inference/together/__init__.py  (deleted, -8)

@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import TogetherImplConfig  # noqa
-from .together import get_provider_impl  # noqa