From e61b3d91ef83cf4d8d4bdcc90a45ad8c8d84b5f3 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Mon, 26 Aug 2024 14:27:20 -0700
Subject: [PATCH] use a single impl for ChatFormat.decode_assistant_message

---
 llama_toolchain/inference/ollama/ollama.py | 90 +++-------------------
 1 file changed, 10 insertions(+), 80 deletions(-)

diff --git a/llama_toolchain/inference/ollama/ollama.py b/llama_toolchain/inference/ollama/ollama.py
index 235cb20cc..b1e1ca09c 100644
--- a/llama_toolchain/inference/ollama/ollama.py
+++ b/llama_toolchain/inference/ollama/ollama.py
@@ -4,19 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import uuid
 from typing import AsyncGenerator, Dict
 
 import httpx
 
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    CompletionMessage,
-    Message,
-    StopReason,
-    ToolCall,
-)
-from llama_models.llama3.api.tool_utils import ToolUtils
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
 from ollama import AsyncClient
@@ -57,6 +51,8 @@ async def get_provider_impl(
 class OllamaInference(Inference):
     def __init__(self, config: OllamaImplConfig) -> None:
         self.config = config
+        tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> AsyncClient:
@@ -144,9 +140,8 @@ class OllamaInference(Inference):
                 elif r["done_reason"] == "length":
                     stop_reason = StopReason.out_of_tokens
 
-            completion_message = decode_assistant_message_from_content(
-                r["message"]["content"],
-                stop_reason,
+            completion_message = self.formatter.decode_assistant_message_from_content(
+                r["message"]["content"], stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -229,7 +224,9 @@ class OllamaInference(Inference):
             )
 
             # parse tool calls and report errors
-            message = decode_assistant_message_from_content(buffer, stop_reason)
+            message = self.formatter.decode_assistant_message_from_content(
+                buffer, stop_reason
+            )
             parsed_tool_calls = len(message.tool_calls) > 0
             if ipython and not parsed_tool_calls:
                 yield ChatCompletionResponseStreamChunk(
@@ -262,70 +259,3 @@ class OllamaInference(Inference):
                     stop_reason=stop_reason,
                 )
             )
-
-
-# TODO: Consolidate this with impl in llama-models
-def decode_assistant_message_from_content(
-    content: str,
-    stop_reason: StopReason,
-) -> CompletionMessage:
-    ipython = content.startswith("<|python_tag|>")
-    if ipython:
-        content = content[len("<|python_tag|>") :]
-
-    if content.endswith("<|eot_id|>"):
-        content = content[: -len("<|eot_id|>")]
-        stop_reason = StopReason.end_of_turn
-    elif content.endswith("<|eom_id|>"):
-        content = content[: -len("<|eom_id|>")]
-        stop_reason = StopReason.end_of_message
-
-    tool_name = None
-    tool_arguments = {}
-
-    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
-    if custom_tool_info is not None:
-        tool_name, tool_arguments = custom_tool_info
-        # Sometimes when agent has custom tools alongside builin tools
-        # Agent responds for builtin tool calls in the format of the custom tools
-        # This code tries to handle that case
-        if tool_name in BuiltinTool.__members__:
-            tool_name = BuiltinTool[tool_name]
-            tool_arguments = {
-                "query": list(tool_arguments.values())[0],
-            }
-    else:
-        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
-        if builtin_tool_info is not None:
-            tool_name, query = builtin_tool_info
-            tool_arguments = {
-                "query": query,
-            }
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
-        elif ipython:
-            tool_name = BuiltinTool.code_interpreter
-            tool_arguments = {
-                "code": content,
-            }
-
-    tool_calls = []
-    if tool_name is not None and tool_arguments is not None:
-        call_id = str(uuid.uuid4())
-        tool_calls.append(
-            ToolCall(
-                call_id=call_id,
-                tool_name=tool_name,
-                arguments=tool_arguments,
-            )
-        )
-        content = ""
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
-    return CompletionMessage(
-        content=content,
-        stop_reason=stop_reason,
-        tool_calls=tool_calls,
-    )
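---

Usage reference: the patch swaps the module-local decode_assistant_message_from_content
helper for the shared implementation in llama-models, constructing the formatter once in
__init__ via ChatFormat(Tokenizer.get_instance()). Below is a minimal sketch, not part of
the patch, of exercising the shared decoder directly; it assumes the llama-models method
keeps the (content, stop_reason) contract and special-token handling of the removed
helper, as the call sites in the diff suggest.

    # Sketch only. Decode a raw assistant response with the shared
    # llama-models decoder that this change adopts.
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    formatter = ChatFormat(Tokenizer.get_instance())

    # <|python_tag|> marks a code-interpreter call; <|eom_id|> means the model
    # expects a tool response before continuing the turn. Per the removed
    # helper's logic, this should yield StopReason.end_of_message and a single
    # code_interpreter tool call carrying the code as its argument.
    raw = "<|python_tag|>print('hello world')<|eom_id|>"
    message = formatter.decode_assistant_message_from_content(raw, None)

    print(message.stop_reason)
    for call in message.tool_calls:
        print(call.tool_name, call.arguments)

Routing both the streaming and non-streaming paths through one formatter instance removes
the duplication that the deleted "TODO: Consolidate this with impl in llama-models"
comment warned about, so Ollama's tool-call parsing can no longer drift from the
canonical ChatFormat behavior.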