Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
bunch more work to make adapters work

parent 68f3db62e9
commit c4fe72c3a3

20 changed files with 461 additions and 173 deletions
llama_toolchain/inference/adapters/__init__.py (new file, 5 additions)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import OllamaImplConfig  # noqa
-from .ollama import get_provider_impl  # noqa
+from .ollama import get_adapter_impl  # noqa

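The package now exports get_adapter_impl instead of get_provider_impl, matching the convention that every adapter module exposes a get_adapter_impl(config) entry point which the distribution layer can resolve by module path. Below is a minimal sketch of how such resolution can work; the helper name instantiate_adapter and its signature are illustrative, not the actual loader in llama_toolchain:

import importlib

# Illustrative resolver: given a module path from a provider spec, import the
# module and call the get_adapter_impl(config) entry point it exports.
async def instantiate_adapter(module_path: str, config):
    module = importlib.import_module(module_path)
    return await module.get_adapter_impl(config)
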
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
 
 import httpx
 
@@ -14,7 +14,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from ollama import AsyncClient
 
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,7 +27,6 @@ from llama_toolchain.inference.api import (
     ToolCallParseStatus,
 )
 from llama_toolchain.inference.prepare_messages import prepare_messages
-from .config import OllamaImplConfig
 
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -37,26 +36,21 @@ OLLAMA_SUPPORTED_SKUS = {
 }
 
 
-async def get_provider_impl(
-    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, OllamaImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = OllamaInference(config)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    impl = OllamaInferenceAdapter(config.url)
     await impl.initialize()
     return impl
 
 
-class OllamaInference(Inference):
-    def __init__(self, config: OllamaImplConfig) -> None:
-        self.config = config
+class OllamaInferenceAdapter(Inference):
+    def __init__(self, url: str) -> None:
+        self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> AsyncClient:
-        return AsyncClient(host=self.config.url)
+        return AsyncClient(host=self.url)
 
     async def initialize(self) -> None:
         try:

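The bespoke OllamaImplConfig plus the Dict[Api, ProviderSpec] plumbing gives way to a single generic RemoteProviderConfig, from which the adapter only needs the URL. A sketch of the shape that usage implies; the real class lives in llama_toolchain.distribution.datatypes and may carry additional fields:

from pydantic import BaseModel

# Minimal stand-in for RemoteProviderConfig, inferred from the config.url
# access in get_adapter_impl above; the actual class may define more fields.
class RemoteProviderConfig(BaseModel):
    url: str

# Usage mirroring the new entry point (values are hypothetical):
#   config = RemoteProviderConfig(url="http://localhost:11434")
#   impl = await get_adapter_impl(config)  # initialized OllamaInferenceAdapter
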
@@ -13,6 +13,8 @@ import httpx
 from pydantic import BaseModel
 from termcolor import cprint
 
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
+
 from .api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -24,8 +26,8 @@ from .api import (
 from .event_logger import EventLogger
 
 
-async def get_client_impl(base_url: str):
-    return InferenceClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    return InferenceClient(config.url)
 
 
 def encodable_dict(d: BaseModel):
@@ -34,7 +36,7 @@ def encodable_dict(d: BaseModel):
 
 class InferenceClient(Inference):
     def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Inference passthrough to -> {base_url}")
         self.base_url = base_url
 
     async def initialize(self) -> None:

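With this change the test client plugs into the same adapter convention: get_adapter_impl builds an InferenceClient pointed at a remote server, so a running stack can be consumed through the same interface as an in-process provider. The request logic itself is elided from this diff; here is a hedged sketch of what a passthrough call over httpx could look like, with the endpoint path and payload shape assumed rather than taken from the source:

import httpx

# Assumed endpoint and payload, for illustration only; the real
# InferenceClient request methods are not shown in this diff.
async def ping_passthrough(base_url: str) -> None:
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{base_url}/inference/chat_completion",
            json={"model": "llama3-8b", "messages": []},
            timeout=30.0,
        )
        response.raise_for_status()
        print(response.json())
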
@@ -1,16 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
-
-
-@json_schema_type
-class OllamaImplConfig(BaseModel):
-    url: str = Field(
-        default="http://localhost:11434",
-        description="The URL for the ollama server",
-    )

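Deleting this file removes the ollama-specific config class and its baked-in default URL; callers now pass the generic RemoteProviderConfig instead. If the old default is still wanted, it can be supplied wherever the config is constructed, for example with a small helper like this hypothetical one (not part of this commit):

from llama_toolchain.distribution.datatypes import RemoteProviderConfig

# Hypothetical convenience helper reproducing the deleted default URL.
def default_ollama_config() -> RemoteProviderConfig:
    return RemoteProviderConfig(url="http://localhost:11434")
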
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.distribution.datatypes import *  # noqa: F403
 
 
 def available_inference_providers() -> List[ProviderSpec]:
@@ -27,13 +27,12 @@ def available_inference_providers() -> List[ProviderSpec]:
             module="llama_toolchain.inference.meta_reference",
             config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
         ),
-        InlineProviderSpec(
+        remote_provider_spec(
             api=Api.inference,
-            provider_id="meta-ollama",
-            pip_packages=[
-                "ollama",
-            ],
-            module="llama_toolchain.inference.ollama",
-            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+            adapter=AdapterSpec(
+                adapter_id="ollama",
+                pip_packages=[],
+                module="llama_toolchain.inference.adapters.ollama",
+            ),
         ),
     ]

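The registry entry for Ollama switches from an in-process InlineProviderSpec to remote_provider_spec wrapping an AdapterSpec, which separates what to install (pip_packages) from what to import (module). Any further remote backend would presumably follow the same pattern; a sketch with invented identifiers:

# Hypothetical second remote adapter registered the same way; the adapter_id,
# pip_packages, and module values here are invented for illustration.
remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_id="my-backend",
        pip_packages=["my-backend-client"],
        module="llama_toolchain.inference.adapters.my_backend",
    ),
)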