bunch more work to make adapters work

Ashwin Bharambe 2024-08-27 19:15:42 -07:00
parent 68f3db62e9
commit c4fe72c3a3
20 changed files with 461 additions and 173 deletions


@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.


@@ -4,5 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from .config import OllamaImplConfig # noqa
-from .ollama import get_provider_impl # noqa
+from .ollama import get_adapter_impl # noqa


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import AsyncGenerator, Dict
+from typing import AsyncGenerator
 import httpx
@@ -14,7 +14,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from ollama import AsyncClient
-from llama_toolchain.distribution.datatypes import Api, ProviderSpec
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from llama_toolchain.inference.api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -27,7 +27,6 @@ from llama_toolchain.inference.api import (
     ToolCallParseStatus,
 )
 from llama_toolchain.inference.prepare_messages import prepare_messages
-from .config import OllamaImplConfig
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
@@ -37,26 +36,21 @@ OLLAMA_SUPPORTED_SKUS = {
 }
-async def get_provider_impl(
-    config: OllamaImplConfig, _deps: Dict[Api, ProviderSpec]
-) -> Inference:
-    assert isinstance(
-        config, OllamaImplConfig
-    ), f"Unexpected config type: {type(config)}"
-    impl = OllamaInference(config)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    impl = OllamaInferenceAdapter(config.url)
     await impl.initialize()
     return impl
-class OllamaInference(Inference):
-    def __init__(self, config: OllamaImplConfig) -> None:
-        self.config = config
+class OllamaInferenceAdapter(Inference):
+    def __init__(self, url: str) -> None:
+        self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
     @property
     def client(self) -> AsyncClient:
-        return AsyncClient(host=self.config.url)
+        return AsyncClient(host=self.url)
     async def initialize(self) -> None:
         try:
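
Taken together, these hunks turn the inline Ollama provider into a remote adapter: the entry point is now get_adapter_impl(config: RemoteProviderConfig) and the adapter class only needs the server URL. A minimal usage sketch (not part of the diff), assuming RemoteProviderConfig is a pydantic model constructed with a url field — as the config.url access above suggests — and that the adapter module is importable at the path registered in the provider spec change further down:

# Hedged sketch: wiring the new Ollama adapter by hand.
# The import path and RemoteProviderConfig(url=...) constructor are assumptions.
import asyncio

from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from llama_toolchain.inference.adapters.ollama import get_adapter_impl


async def main() -> None:
    # URL that the now-deleted OllamaImplConfig used to default to
    config = RemoteProviderConfig(url="http://localhost:11434")
    # Builds OllamaInferenceAdapter(config.url) and awaits initialize()
    inference = await get_adapter_impl(config)
    # `inference` now implements the Inference API, proxying to the Ollama server.


asyncio.run(main())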


@@ -13,6 +13,8 @@ import httpx
 from pydantic import BaseModel
 from termcolor import cprint
+from llama_toolchain.distribution.datatypes import RemoteProviderConfig
 from .api import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -24,8 +26,8 @@ from .api import (
 from .event_logger import EventLogger
-async def get_client_impl(base_url: str):
-    return InferenceClient(base_url)
+async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
+    return InferenceClient(config.url)
 def encodable_dict(d: BaseModel):
@@ -34,7 +36,7 @@ def encodable_dict(d: BaseModel):
 class InferenceClient(Inference):
     def __init__(self, base_url: str):
-        print(f"Initializing client for {base_url}")
+        print(f"Inference passthrough to -> {base_url}")
         self.base_url = base_url
     async def initialize(self) -> None:
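
The HTTP passthrough client gets the same treatment: its factory is renamed from get_client_impl(base_url) to get_adapter_impl(config: RemoteProviderConfig), so a remote inference server can be wired the same way as the Ollama adapter. A rough sketch of that path; the module path llama_toolchain.inference.client is an assumption, since the diff does not show the file name:

# Hedged sketch only: connecting the passthrough client to a running server.
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from llama_toolchain.inference.client import get_adapter_impl  # assumed path


async def connect(url: str):
    config = RemoteProviderConfig(url=url)
    client = await get_adapter_impl(config)  # prints "Inference passthrough to -> <url>"
    return client  # an InferenceClient that forwards requests to `url`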


@@ -1,16 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
-@json_schema_type
-class OllamaImplConfig(BaseModel):
-    url: str = Field(
-        default="http://localhost:11434",
-        description="The URL for the ollama server",
-    )


@@ -6,7 +6,7 @@
 from typing import List
-from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_toolchain.distribution.datatypes import * # noqa: F403
 def available_inference_providers() -> List[ProviderSpec]:
@@ -27,13 +27,12 @@ def available_inference_providers() -> List[ProviderSpec]:
             module="llama_toolchain.inference.meta_reference",
             config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
         ),
-        InlineProviderSpec(
+        remote_provider_spec(
             api=Api.inference,
-            provider_id="meta-ollama",
-            pip_packages=[
-                "ollama",
-            ],
-            module="llama_toolchain.inference.ollama",
-            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+            adapter=AdapterSpec(
+                adapter_id="ollama",
+                pip_packages=[],
+                module="llama_toolchain.inference.adapters.ollama",
+            ),
         ),
     ]
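
With this hunk the Ollama entry is no longer an InlineProviderSpec carrying its own config class; it is a remote provider whose AdapterSpec points at the adapter module. Registering another remote backend would follow the same shape. The sketch below uses only the fields visible in this hunk; the "tgi" adapter id and module path are hypothetical placeholders, not part of the commit:

# Hypothetical example, mirroring the ollama registration above.
from llama_toolchain.distribution.datatypes import * # noqa: F403

example_spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_id="tgi",  # hypothetical adapter id
        pip_packages=[],  # any extra client packages the adapter needs
        module="llama_toolchain.inference.adapters.tgi",  # hypothetical module path
    ),
)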