Several smaller fixes to make adapters work

Also, reorganized the __init__ pattern inside providers so that
configuration can stay lightweight.
Ashwin Bharambe 2024-08-28 09:42:08 -07:00
parent 2a1552a5eb
commit 45987996c4
23 changed files with 164 additions and 160 deletions
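The heart of the reorganization: each provider package's __init__ now exposes a small async factory, and the heavyweight import is deferred until that factory actually runs, so importing the package just to read its config types stays cheap. A minimal sketch of the pattern for a hypothetical "acme" remote adapter (RemoteProviderConfig is the real datatype used below; the acme names are invented for illustration):

# Sketch only: the new provider __init__ pattern, shown for a made-up adapter.
from llama_toolchain.distribution.datatypes import RemoteProviderConfig

async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    # Heavy dependencies are imported lazily, inside the factory.
    from .acme import AcmeInferenceAdapter

    impl = AcmeInferenceAdapter(config.url)
    await impl.initialize()
    return impl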

View file

@@ -4,4 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .ollama import get_provider_impl # noqa
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
    from .ollama import OllamaInferenceAdapter
    impl = OllamaInferenceAdapter(config.url)
    await impl.initialize()
    return impl

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator
from typing import AsyncGenerator
import httpx
@@ -14,34 +14,18 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from ollama import AsyncClient
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    Inference,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.inference.prepare_messages import prepare_messages
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
    # "Meta-Llama3.1-8B-Instruct": "llama3.1",
    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}
async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
    impl = OllamaInferenceAdapter(config.url)
    await impl.initialize()
    return impl
class OllamaInferenceAdapter(Inference):
    def __init__(self, url: str) -> None:
        self.url = url
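For reference, a minimal sketch of how the SKU table above can be consulted to turn a Llama model SKU into its Ollama tag. The helper name is illustrative and not part of this commit:

# Illustrative helper (not part of this diff): map a supported Llama SKU to
# the Ollama model tag, failing loudly for anything outside the table.
OLLAMA_SUPPORTED_SKUS = {
    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}

def resolve_ollama_model(model_sku: str) -> str:
    if model_sku not in OLLAMA_SUPPORTED_SKUS:
        raise ValueError(
            f"Unsupported SKU {model_sku}; supported: {sorted(OLLAMA_SUPPORTED_SKUS)}"
        )
    return OLLAMA_SUPPORTED_SKUS[model_sku]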

View file

@@ -6,7 +6,7 @@
import asyncio
import json
from typing import AsyncGenerator
from typing import Any, AsyncGenerator
import fire
import httpx
@@ -26,7 +26,7 @@ from .api import (
from .event_logger import EventLogger
async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
    return InferenceClient(config.url)

View file

@@ -5,4 +5,15 @@
# the root directory of this source tree.
from .config import MetaReferenceImplConfig # noqa
from .inference import get_provider_impl # noqa
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
    from .inference import MetaReferenceInferenceImpl
    assert isinstance(
        config, MetaReferenceImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = MetaReferenceInferenceImpl(config)
    await impl.initialize()
    return impl

View file

@@ -6,12 +6,11 @@
import asyncio
from typing import AsyncIterator, Dict, Union
from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    ChatCompletionResponse,
@@ -27,18 +26,6 @@ from .config import MetaReferenceImplConfig
from .model_parallel import LlamaModelParallelGenerator
async def get_provider_impl(
    config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
):
    assert isinstance(
        config, MetaReferenceImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = MetaReferenceInferenceImpl(config)
    await impl.initialize()
    return impl
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
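A self-contained sketch, separate from the implementation above, of what Semaphore(1) buys here: concurrent callers are serialized so the single model-parallel process never sees two requests at once. All names in the snippet are invented for the demo:

import asyncio

# Demo of the Semaphore(1) pattern; fake_generation stands in for the real
# model-parallel generator call.
SEMAPHORE = asyncio.Semaphore(1)

async def serialized(coro_factory):
    # Callers queue here; only one of them reaches the backing process at a time.
    async with SEMAPHORE:
        return await coro_factory()

async def fake_generation():
    await asyncio.sleep(0.1)
    return "done"

async def main():
    # Both requests complete, but strictly one after the other.
    print(await asyncio.gather(serialized(fake_generation), serialized(fake_generation)))

if __name__ == "__main__":
    asyncio.run(main())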

View file

@@ -27,7 +27,7 @@ def available_inference_providers() -> List[ProviderSpec]:
module="llama_toolchain.inference.meta_reference",
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
),
adapter_provider_spec(
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="ollama",