mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Make Fireworks and Together into the Adapter format

parent a23a6ab95b, commit f1244f6d9e

10 changed files with 56 additions and 83 deletions
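This change converts the Fireworks and Together inference providers into the remote adapter format: each provider now lives under llama_toolchain/inference/adapters/ and exposes an async get_adapter_impl() factory instead of get_provider_impl(), the RemoteProviderConfig subclass check is dropped from instantiate_provider(), and the distribution specs reference the adapters through remote_provider_id().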
@@ -8,7 +8,7 @@ import asyncio
 import importlib
 from typing import Any, Dict

-from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
+from .datatypes import ProviderSpec, RemoteProviderSpec


 def instantiate_class_type(fully_qualified_name):
@@ -28,10 +28,6 @@ def instantiate_provider(
     config_type = instantiate_class_type(provider_spec.config_class)
     if isinstance(provider_spec, RemoteProviderSpec):
         if provider_spec.adapter:
-            if not issubclass(config_type, RemoteProviderConfig):
-                raise ValueError(
-                    f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
-                )
             method = "get_adapter_impl"
         else:
             method = "get_client_impl"
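To make the dispatch concrete, here is a minimal sketch (not code from this commit) of how a resolver like instantiate_provider() can pick the factory by name once the subclass check is gone; the provider_spec.module attribute is an assumption about the surrounding datatypes:

import importlib


async def instantiate_remote(provider_spec, config, deps=None):
    # Adapter-backed remote providers expose get_adapter_impl();
    # plain passthrough clients expose get_client_impl().
    method = "get_adapter_impl" if provider_spec.adapter else "get_client_impl"
    module = importlib.import_module(provider_spec.module)  # attribute assumed
    fn = getattr(module, method)
    return await fn(config, deps)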
@@ -39,21 +39,23 @@ def available_distribution_specs() -> List[DistributionSpec]:
             },
         ),
         DistributionSpec(
-            spec_id="remote-fireworks",
+            distribution_id="local-plus-fireworks-inference",
             description="Use Fireworks.ai for running LLM inference",
-            provider_specs={
-                Api.inference: providers[Api.inference]["fireworks"],
-                Api.safety: providers[Api.safety]["meta-reference"],
-                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+            providers={
+                Api.inference: remote_provider_id("fireworks"),
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
+                Api.memory: "meta-reference-faiss",
             },
         ),
         DistributionSpec(
-            spec_id="remote-together",
+            distribution_id="local-plus-together-inference",
             description="Use Together.ai for running LLM inference",
-            provider_specs={
-                Api.inference: providers[Api.inference]["together"],
-                Api.safety: providers[Api.safety]["meta-reference"],
-                Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
+            providers={
+                Api.inference: remote_provider_id("together"),
+                Api.safety: "meta-reference",
+                Api.agentic_system: "meta-reference",
+                Api.memory: "meta-reference-faiss",
             },
         ),
     ]
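remote_provider_id() itself is not shown in this diff; a plausible sketch of such a helper, with the exact prefix being an assumption:

def remote_provider_id(adapter_id: str) -> str:
    # Namespace the adapter id so the registry can tell remote adapters
    # apart from inline providers; the "remote::" prefix is a guess.
    return f"remote::{adapter_id}"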
llama_toolchain/inference/adapters/fireworks/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import FireworksImplConfig
+
+
+async def get_adapter_impl(config: FireworksImplConfig, _deps) -> Inference:
+    from .fireworks import FireworksInferenceAdapter
+
+    assert isinstance(
+        config, FireworksImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = FireworksInferenceAdapter(config)
+    await impl.initialize()
+    return impl
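Hypothetical usage of the new factory; the url and api_key fields on FireworksImplConfig are assumptions, since config.py is not part of this diff:

import asyncio

from llama_toolchain.inference.adapters.fireworks import get_adapter_impl
from llama_toolchain.inference.adapters.fireworks.config import FireworksImplConfig


async def main() -> None:
    config = FireworksImplConfig(  # field names assumed, not shown in this diff
        url="https://api.fireworks.ai/inference",
        api_key="fw-...",
    )
    impl = await get_adapter_impl(config, _deps={})  # returns an initialized adapter


asyncio.run(main())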
@@ -5,9 +5,9 @@
 # the root directory of this source tree.

 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

-import httpx
+from fireworks.client import Fireworks

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,20 +18,8 @@ from llama_models.llama3.api.datatypes import (
 )
 from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
-from fireworks.client import Fireworks
-
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+
+from llama_toolchain.inference.api import *  # noqa: F403

 from .config import FireworksImplConfig

@@ -42,18 +30,7 @@ FIREWORKS_SUPPORTED_MODELS = {
 }


-async def get_provider_impl(
-    config: FireworksImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, FireworksImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = FireworksInference(config)
-    await impl.initialize()
-    return impl
-
-
-class FireworksInference(Inference):
+class FireworksInferenceAdapter(Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
         self.config = config
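The adapter now constructs the official Fireworks client directly. A hedged sketch of the kind of call its chat_completion path presumably makes; the OpenAI-compatible chat.completions.create surface is an assumption about the fireworks-ai SDK, not code from this commit:

from fireworks.client import Fireworks

client = Fireworks(api_key="fw-...")  # hypothetical key
response = client.chat.completions.create(  # assumed OpenAI-compatible surface
    model="accounts/fireworks/models/llama-v3-8b-instruct",  # hypothetical model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=False,
)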
llama_toolchain/inference/adapters/together/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import TogetherImplConfig
+
+
+async def get_adapter_impl(config: TogetherImplConfig, _deps) -> Inference:
+    from .together import TogetherInferenceAdapter
+
+    assert isinstance(
+        config, TogetherImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = TogetherInferenceAdapter(config)
+    await impl.initialize()
+    return impl
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import uuid
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -18,18 +18,7 @@ from llama_models.llama3.api.tool_utils import ToolUtils
 from llama_models.sku_list import resolve_model
 from together import Together

-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
-from llama_toolchain.inference.api import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionRequest,
-    Inference,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_toolchain.inference.api import *  # noqa: F403

 from .config import TogetherImplConfig

@@ -40,18 +29,7 @@ TOGETHER_SUPPORTED_MODELS = {
 }


-async def get_provider_impl(
-    config: TogetherImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, TogetherImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = TogetherInference(config)
-    await impl.initialize()
-    return impl
-
-
-class TogetherInference(Inference):
+class TogetherInferenceAdapter(Inference):
    def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config
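Because both adapters now subclass Inference and are built through the same get_adapter_impl() entry point, one smoke test can drive either. A sketch, assuming a chat_completion(request) async-generator method and the request/message types wildcard-imported above:

from llama_models.llama3.api.datatypes import UserMessage

from llama_toolchain.inference.api import ChatCompletionRequest


async def smoke_test(impl, model: str) -> None:
    # `impl` is whatever get_adapter_impl() returned for Fireworks or Together.
    request = ChatCompletionRequest(
        model=model,
        messages=[UserMessage(content="Say hi in one word.")],
        stream=True,  # field assumed; streaming is implied by the StreamChunk type
    )
    async for chunk in impl.chat_completion(request):  # signature assumed
        print(chunk)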
@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import FireworksImplConfig  # noqa
-from .fireworks import get_provider_impl  # noqa
@@ -1,8 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .config import TogetherImplConfig  # noqa
-from .together import get_provider_impl  # noqa