getting closer to a distro definition, distro install + configure works

Ashwin Bharambe 2024-08-01 22:59:11 -07:00
parent dac2b5a1ed
commit 041cafbee3
11 changed files with 471 additions and 130 deletions

View file

@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
def available_inference_adapters() -> List[Adapter]:
return [
SourceAdapter(
api_surface=ApiSurface.inference,
adapter_id="meta-reference",
pip_packages=[
"torch",
"zmq",
],
module="llama_toolchain.inference.inference",
config_class="llama_toolchain.inference.inference.InlineImplConfig",
),
SourceAdapter(
api_surface=ApiSurface.inference,
adapter_id="meta-ollama",
pip_packages=[
"ollama",
],
module="llama_toolchain.inference.ollama",
config_class="llama_toolchain.inference.ollama.OllamaImplConfig",
),
]
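
For orientation, a rough usage sketch (not part of this commit) of how a distro install step might consume the registry above: look up an adapter by adapter_id and derive a pip command from its pip_packages. The find_adapter helper is hypothetical and relies on available_inference_adapters() defined in the new file above.

# Hypothetical sketch; find_adapter is illustrative, not defined anywhere in this commit.
from llama_toolchain.distribution.datatypes import Adapter

def find_adapter(adapter_id: str) -> Adapter:
    # scan the registry declared above for a matching adapter_id
    for adapter in available_inference_adapters():
        if adapter.adapter_id == adapter_id:
            return adapter
    raise ValueError(f"Unknown inference adapter: {adapter_id}")

adapter = find_adapter("meta-ollama")
# Assumed install step: one pip invocation over the adapter's declared packages.
print("pip install " + " ".join(adapter.pip_packages))  # -> pip install ollama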

View file

@@ -16,8 +16,8 @@ from .api.datatypes import (
ToolCallParseStatus,
)
from .api.endpoints import (
-    ChatCompletionResponse,
    ChatCompletionRequest,
+    ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
@@ -25,6 +25,13 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator
def get_adapter_impl(config: InlineImplConfig) -> Inference:
assert isinstance(
config, InlineImplConfig
), f"Unexpected config type: {type(config)}"
return InferenceImpl(config)
class InferenceImpl(Inference):
def __init__(self, config: InlineImplConfig) -> None:
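
The get_adapter_impl entry point added above pairs with the module and config_class strings declared in the adapter registry. Below is a hypothetical loader sketch (an assumption about how a distribution might wire the two together, not code from this commit); instantiate_adapter is an illustrative name.

# Hypothetical loader sketch: resolve the dotted paths declared on a SourceAdapter
# and hand a typed config object to the module's get_adapter_impl entry point.
import importlib

def instantiate_adapter(adapter, config_dict: dict):
    impl_module = importlib.import_module(adapter.module)  # e.g. llama_toolchain.inference.inference
    config_module_path, _, config_cls_name = adapter.config_class.rpartition(".")
    config_cls = getattr(importlib.import_module(config_module_path), config_cls_name)
    config = config_cls(**config_dict)  # e.g. InlineImplConfig(...)
    return impl_module.get_adapter_impl(config)  # returns an Inference implementation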

View file

@@ -1,9 +1,14 @@
-import httpx
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
import uuid
from typing import AsyncGenerator
-from ollama import AsyncClient
+import httpx
from llama_models.llama3_1.api.datatypes import (
BuiltinTool,
@@ -14,6 +19,8 @@ from llama_models.llama3_1.api.datatypes import (
)
from llama_models.llama3_1.api.tool_utils import ToolUtils
+from ollama import AsyncClient
from .api.config import OllamaImplConfig
from .api.datatypes import (
ChatCompletionResponseEvent,
@@ -22,14 +29,20 @@ from .api.datatypes import (
ToolCallParseStatus,
)
from .api.endpoints import (
-    ChatCompletionResponse,
    ChatCompletionRequest,
+    ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
Inference,
)
def get_adapter_impl(config: OllamaImplConfig) -> Inference:
assert isinstance(
config, OllamaImplConfig
), f"Unexpected config type: {type(config)}"
return OllamaInference(config)
class OllamaInference(Inference):
@@ -41,9 +54,13 @@ class OllamaInference(Inference):
self.client = AsyncClient(host=self.config.url)
try:
status = await self.client.pull(self.model)
-            assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
+            assert (
+                status["status"] == "success"
+            ), f"Failed to pull model {self.model} in ollama"
except httpx.ConnectError:
-            print("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            print(
+                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
+            )
raise
async def shutdown(self) -> None:
@@ -55,9 +72,7 @@ class OllamaInference(Inference):
def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
ollama_messages = []
for message in messages:
-            ollama_messages.append(
-                {"role": message.role, "content": message.content}
-            )
+            ollama_messages.append({"role": message.role, "content": message.content})
return ollama_messages
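
A small illustration of the conversion above (UserMessage is assumed here to come from llama_models.llama3_1.api.datatypes; it does not appear in this hunk):

# Illustrative only: the shape _messages_to_ollama_messages produces for the Ollama client.
from llama_models.llama3_1.api.datatypes import UserMessage

history = [UserMessage(content="What is the capital of France?")]
# self._messages_to_ollama_messages(history)
# -> [{"role": "user", "content": "What is the capital of France?"}]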
@@ -67,16 +82,16 @@ class OllamaInference(Inference):
model=self.model,
messages=self._messages_to_ollama_messages(request.messages),
stream=False,
-                #TODO: add support for options like temp, top_p, max_seq_length, etc
+                # TODO: add support for options like temp, top_p, max_seq_length, etc
)
-            if r['done']:
-                if r['done_reason'] == 'stop':
+            if r["done"]:
+                if r["done_reason"] == "stop":
                    stop_reason = StopReason.end_of_turn
-                elif r['done_reason'] == 'length':
+                elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = decode_assistant_message_from_content(
-                r['message']['content'],
+                r["message"]["content"],
stop_reason,
)
yield ChatCompletionResponse(
@@ -94,7 +109,7 @@ class OllamaInference(Inference):
stream = await self.client.chat(
model=self.model,
messages=self._messages_to_ollama_messages(request.messages),
-                stream=True
+                stream=True,
)
buffer = ""
@@ -103,14 +118,14 @@
async for chunk in stream:
# check if ollama is done
-                if chunk['done']:
-                    if chunk['done_reason'] == 'stop':
+                if chunk["done"]:
+                    if chunk["done_reason"] == "stop":
                        stop_reason = StopReason.end_of_turn
-                    elif chunk['done_reason'] == 'length':
+                    elif chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
-                text = chunk['message']['content']
+                text = chunk["message"]["content"]
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
@@ -197,7 +212,7 @@ class OllamaInference(Inference):
)
-#TODO: Consolidate this with impl in llama-models
+# TODO: Consolidate this with impl in llama-models
def decode_assistant_message_from_content(
content: str,
stop_reason: StopReason,