mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
Several smaller fixes to make adapters work
Also, reorganized the pattern of __init__ inside providers so configuration can stay lightweight
This commit is contained in:
parent
2a1552a5eb
commit
45987996c4
23 changed files with 164 additions and 160 deletions
|
|
@ -4,4 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .ollama import get_provider_impl # noqa
|
||||
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -14,34 +14,18 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from llama_models.sku_list import resolve_model
|
||||
from ollama import AsyncClient
|
||||
|
||||
from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
||||
from llama_toolchain.inference.api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
Inference,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.inference.prepare_messages import prepare_messages
|
||||
|
||||
# TODO: Eventually this will move to the llama cli model list command
|
||||
# mapping of Model SKUs to ollama models
|
||||
OLLAMA_SUPPORTED_SKUS = {
|
||||
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
|
||||
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
}
|
||||
|
||||
|
||||
async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
|
||||
impl = OllamaInferenceAdapter(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
|
@ -26,7 +26,7 @@ from .api import (
|
|||
from .event_logger import EventLogger
|
||||
|
||||
|
||||
async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
|
||||
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
|
||||
return InferenceClient(config.url)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,4 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import MetaReferenceImplConfig # noqa
|
||||
from .inference import get_provider_impl # noqa
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceImplConfig, _deps):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -6,12 +6,11 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, Dict, Union
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.inference.api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
|
@ -27,18 +26,6 @@ from .config import MetaReferenceImplConfig
|
|||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def available_inference_providers() -> List[ProviderSpec]:
|
|||
module="llama_toolchain.inference.meta_reference",
|
||||
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
|
||||
),
|
||||
adapter_provider_spec(
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_id="ollama",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue