Mirror of https://github.com/meta-llama/llama-stack.git
Getting closer to a distro definition; distro install + configure works
Commit 041cafbee3 (parent dac2b5a1ed)
11 changed files with 471 additions and 130 deletions
llama_toolchain/inference/adapters.py (new file, +33)
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
+
+
+def available_inference_adapters() -> List[Adapter]:
+    return [
+        SourceAdapter(
+            api_surface=ApiSurface.inference,
+            adapter_id="meta-reference",
+            pip_packages=[
+                "torch",
+                "zmq",
+            ],
+            module="llama_toolchain.inference.inference",
+            config_class="llama_toolchain.inference.inference.InlineImplConfig",
+        ),
+        SourceAdapter(
+            api_surface=ApiSurface.inference,
+            adapter_id="meta-ollama",
+            pip_packages=[
+                "ollama",
+            ],
+            module="llama_toolchain.inference.ollama",
+            config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
+        ),
+    ]
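For context on how these SourceAdapter entries are meant to be consumed, here is a minimal editorial sketch, not code from this commit: a distro install step would pip-install each adapter's pip_packages, and a configure step would resolve the adapter's module string to its get_adapter_impl factory. The importlib-based lookup and the helper name below are assumptions about that wiring; only adapter_id, pip_packages, module, and config_class come from adapters.py above.

# Hypothetical sketch of adapter resolution (not part of this commit).
import importlib

from llama_toolchain.inference.adapters import available_inference_adapters


def resolve_inference_adapter(adapter_id: str):
    # Pick the registered adapter by its id, e.g. "meta-reference" or "meta-ollama".
    adapter = next(
        a for a in available_inference_adapters() if a.adapter_id == adapter_id
    )
    # A distro install step would `pip install` adapter.pip_packages before this point.
    module = importlib.import_module(adapter.module)
    # Both adapter modules in this commit expose get_adapter_impl(config).
    return module.get_adapter_impl


# e.g. factory = resolve_inference_adapter("meta-ollama")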
llama_toolchain/inference/inference.py

@@ -16,8 +16,8 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
-    ChatCompletionResponse,
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
@@ -25,6 +25,13 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator
 
 
+def get_adapter_impl(config: InlineImplConfig) -> Inference:
+    assert isinstance(
+        config, InlineImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    return InferenceImpl(config)
+
+
 class InferenceImpl(Inference):
 
     def __init__(self, config: InlineImplConfig) -> None:
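Both the meta-reference module above and the Ollama module below now expose an identically shaped get_adapter_impl(config) -> Inference entry point, which is what lets adapters.py refer to them purely by module path plus config_class string. A hedged illustration of that implied contract follows; the Protocol name is hypothetical and not defined anywhere in llama_toolchain.

# Editorial illustration only: the factory shape both adapter modules now share.
from typing import Any, Protocol

from llama_toolchain.inference.api.endpoints import Inference


class AdapterFactory(Protocol):  # hypothetical name, not part of the codebase
    def __call__(self, config: Any) -> Inference: ...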
llama_toolchain/inference/ollama.py

@@ -1,9 +1,14 @@
-import httpx
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 import uuid
 
 from typing import AsyncGenerator
 
-from ollama import AsyncClient
+import httpx
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
@@ -14,6 +19,8 @@ from llama_models.llama3_1.api.datatypes import (
 )
 from llama_models.llama3_1.api.tool_utils import ToolUtils
 
+from ollama import AsyncClient
+
 from .api.config import OllamaImplConfig
 from .api.datatypes import (
     ChatCompletionResponseEvent,
@@ -22,14 +29,20 @@ from .api.datatypes import (
     ToolCallParseStatus,
 )
 from .api.endpoints import (
-    ChatCompletionResponse,
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
 )
 
 
+def get_adapter_impl(config: OllamaImplConfig) -> Inference:
+    assert isinstance(
+        config, OllamaImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    return OllamaInference(config)
+
+
 class OllamaInference(Inference):
 
@@ -41,9 +54,13 @@ class OllamaInference(Inference):
         self.client = AsyncClient(host=self.config.url)
         try:
             status = await self.client.pull(self.model)
-            assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
         except httpx.ConnectError:
-            print("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            print(
+                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+            )
             raise
 
     async def shutdown(self) -> None:
@@ -55,9 +72,7 @@
     def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
         ollama_messages = []
         for message in messages:
-            ollama_messages.append(
-                {"role": message.role, "content": message.content}
-            )
+            ollama_messages.append({"role": message.role, "content": message.content})
 
         return ollama_messages
 
@@ -67,16 +82,16 @@
             model=self.model,
             messages=self._messages_to_ollama_messages(request.messages),
             stream=False,
-            #TODO: add support for options like temp, top_p, max_seq_length, etc
+            # TODO: add support for options like temp, top_p, max_seq_length, etc
         )
-        if r['done']:
-            if r['done_reason'] == 'stop':
+        if r["done"]:
+            if r["done_reason"] == "stop":
                 stop_reason = StopReason.end_of_turn
-            elif r['done_reason'] == 'length':
+            elif r["done_reason"] == "length":
                 stop_reason = StopReason.out_of_tokens
 
         completion_message = decode_assistant_message_from_content(
-            r['message']['content'],
+            r["message"]["content"],
             stop_reason,
         )
         yield ChatCompletionResponse(
@@ -94,7 +109,7 @@
         stream = await self.client.chat(
             model=self.model,
             messages=self._messages_to_ollama_messages(request.messages),
-            stream=True
+            stream=True,
         )
 
         buffer = ""
@@ -103,14 +118,14 @@
 
         async for chunk in stream:
             # check if ollama is done
-            if chunk['done']:
-                if chunk['done_reason'] == 'stop':
+            if chunk["done"]:
+                if chunk["done_reason"] == "stop":
                     stop_reason = StopReason.end_of_turn
-                elif chunk['done_reason'] == 'length':
+                elif chunk["done_reason"] == "length":
                     stop_reason = StopReason.out_of_tokens
                 break
 
-            text = chunk['message']['content']
+            text = chunk["message"]["content"]
 
             # check if its a tool call ( aka starts with <|python_tag|> )
             if not ipython and text.startswith("<|python_tag|>"):
@@ -197,7 +212,7 @@
         )
 
 
-#TODO: Consolidate this with impl in llama-models
+# TODO: Consolidate this with impl in llama-models
 def decode_assistant_message_from_content(
     content: str,
     stop_reason: StopReason,
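As an aside, the done/done_reason branches that this diff reformats in both the streaming and non-streaming paths encode the same mapping from Ollama's completion metadata to StopReason. Below is a small editorial sketch of that mapping, assuming StopReason is the enum imported from the llama_models datatypes at the top of the file; the actual code keeps this logic inline in ollama.py.

# Editorial sketch only, not code from this commit.
from typing import Optional

from llama_models.llama3_1.api.datatypes import StopReason


def ollama_stop_reason(chunk: dict) -> Optional[StopReason]:
    # Ollama marks its final chunk/response with done=True and a done_reason of
    # "stop" (natural end of turn) or "length" (ran out of tokens).
    if not chunk.get("done"):
        return None
    if chunk.get("done_reason") == "stop":
        return StopReason.end_of_turn
    if chunk.get("done_reason") == "length":
        return StopReason.out_of_tokens
    return None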