diff --git a/llama_toolchain/inference/adapters/fireworks/fireworks.py b/llama_toolchain/inference/adapters/fireworks/fireworks.py
index c9d6e38fd..b0eb41017 100644
--- a/llama_toolchain/inference/adapters/fireworks/fireworks.py
+++ b/llama_toolchain/inference/adapters/fireworks/fireworks.py
@@ -4,22 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import uuid
 from typing import AsyncGenerator
 
 from fireworks.client import Fireworks
+from llama_models.llama3.api.chat_format import ChatFormat
 
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    CompletionMessage,
-    Message,
-    StopReason,
-    ToolCall,
-)
-from llama_models.llama3.api.tool_utils import ToolUtils
+from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
 from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import FireworksImplConfig
 
@@ -33,6 +28,8 @@ FIREWORKS_SUPPORTED_MODELS = {
 class FireworksInferenceAdapter(Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
         self.config = config
+        tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> Fireworks:
@@ -80,6 +77,8 @@ class FireworksInferenceAdapter(Inference):
         return options
 
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        messages = prepare_messages(request)
+
         # accumulate sampling params and other options to pass to fireworks
         options = self.get_fireworks_chat_options(request)
         fireworks_model = self.resolve_fireworks_model(request.model)
@@ -87,7 +86,7 @@
         if not request.stream:
             r = await self.client.chat.completions.acreate(
                 model=fireworks_model,
-                messages=self._messages_to_fireworks_messages(request.messages),
+                messages=self._messages_to_fireworks_messages(messages),
                 stream=False,
                 **options,
             )
@@ -98,10 +97,10 @@
             elif r.choices[0].finish_reason == "length":
                 stop_reason = StopReason.out_of_tokens
 
-            completion_message = decode_assistant_message_from_content(
-                r.choices[0].message.content,
-                stop_reason,
+            completion_message = self.formatter.decode_assistant_message_from_content(
+                r.choices[0].message.content, stop_reason
             )
+
             yield ChatCompletionResponse(
                 completion_message=completion_message,
                 logprobs=None,
@@ -120,7 +119,7 @@
 
             async for chunk in self.client.chat.completions.acreate(
                 model=fireworks_model,
-                messages=self._messages_to_fireworks_messages(request.messages),
+                messages=self._messages_to_fireworks_messages(messages),
                 stream=True,
                 **options,
             ):
@@ -187,7 +186,9 @@
             )
 
         # parse tool calls and report errors
-        message = decode_assistant_message_from_content(buffer, stop_reason)
+        message = self.formatter.decode_assistant_message_from_content(
+            buffer, stop_reason
+        )
         parsed_tool_calls = len(message.tool_calls) > 0
         if ipython and not parsed_tool_calls:
             yield ChatCompletionResponseStreamChunk(
@@ -220,70 +221,3 @@ class FireworksInferenceAdapter(Inference):
                 stop_reason=stop_reason,
             )
         )
-
-
-# TODO: Consolidate this with impl in llama-models
-def decode_assistant_message_from_content(
-    content: str,
-    stop_reason: StopReason,
-) -> CompletionMessage:
-    ipython = content.startswith("<|python_tag|>")
-    if ipython:
-        content = content[len("<|python_tag|>") :]
-
-    if content.endswith("<|eot_id|>"):
-        content = content[: -len("<|eot_id|>")]
-        stop_reason = StopReason.end_of_turn
-    elif content.endswith("<|eom_id|>"):
-        content = content[: -len("<|eom_id|>")]
-        stop_reason = StopReason.end_of_message
-
-    tool_name = None
-    tool_arguments = {}
-
-    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
-    if custom_tool_info is not None:
-        tool_name, tool_arguments = custom_tool_info
-        # Sometimes when agent has custom tools alongside builin tools
-        # Agent responds for builtin tool calls in the format of the custom tools
-        # This code tries to handle that case
-        if tool_name in BuiltinTool.__members__:
-            tool_name = BuiltinTool[tool_name]
-            tool_arguments = {
-                "query": list(tool_arguments.values())[0],
-            }
-    else:
-        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
-        if builtin_tool_info is not None:
-            tool_name, query = builtin_tool_info
-            tool_arguments = {
-                "query": query,
-            }
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
-        elif ipython:
-            tool_name = BuiltinTool.code_interpreter
-            tool_arguments = {
-                "code": content,
-            }
-
-    tool_calls = []
-    if tool_name is not None and tool_arguments is not None:
-        call_id = str(uuid.uuid4())
-        tool_calls.append(
-            ToolCall(
-                call_id=call_id,
-                tool_name=tool_name,
-                arguments=tool_arguments,
-            )
-        )
-        content = ""
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
-    return CompletionMessage(
-        content=content,
-        stop_reason=stop_reason,
-        tool_calls=tool_calls,
-    )
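The same change lands in the Together adapter below. The consolidation: instead of a per-adapter module-level decode_assistant_message_from_content helper, each adapter now builds a ChatFormat once in __init__ and delegates decoding to it, and incoming requests are routed through prepare_messages before being converted to the provider's message format. A minimal sketch of the shared decode path, using only the names and call shapes visible in this patch (prepare_messages' internals are not shown in the diff and are assumed):

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

# Built once per adapter instance in __init__, mirroring the patch above.
formatter = ChatFormat(Tokenizer.get_instance())

# Provider output text plus a provisional stop reason (None until the provider
# reports one) comes back as a CompletionMessage carrying content, stop_reason,
# and any parsed tool_calls.
message = formatter.decode_assistant_message_from_content("Hello!<|eot_id|>", None)
print(message.stop_reason, message.tool_calls)
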
diff --git a/llama_toolchain/inference/adapters/together/together.py b/llama_toolchain/inference/adapters/together/together.py
index b8f63df65..4800de6ad 100644
--- a/llama_toolchain/inference/adapters/together/together.py
+++ b/llama_toolchain/inference/adapters/together/together.py
@@ -4,21 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import uuid
 from typing import AsyncGenerator
 
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    CompletionMessage,
-    Message,
-    StopReason,
-    ToolCall,
-)
-from llama_models.llama3.api.tool_utils import ToolUtils
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from together import Together
 
 from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import TogetherImplConfig
 
@@ -32,6 +28,8 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(Inference):
     def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config
+        tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> Together:
@@ -82,12 +80,13 @@
         # accumulate sampling params and other options to pass to together
         options = self.get_together_chat_options(request)
         together_model = self.resolve_together_model(request.model)
+        messages = prepare_messages(request)
 
         if not request.stream:
             # TODO: might need to add back an async here
             r = self.client.chat.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(request.messages),
+                messages=self._messages_to_together_messages(messages),
                 stream=False,
                 **options,
             )
@@ -101,9 +100,8 @@
             elif r.choices[0].finish_reason == "length":
                 stop_reason = StopReason.out_of_tokens
 
-            completion_message = decode_assistant_message_from_content(
-                r.choices[0].message.content,
-                stop_reason,
+            completion_message = self.formatter.decode_assistant_message_from_content(
+                r.choices[0].message.content, stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -123,7 +121,7 @@
 
             for chunk in self.client.chat.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(request.messages),
+                messages=self._messages_to_together_messages(messages),
                 stream=True,
                 **options,
             ):
@@ -194,7 +192,9 @@
             )
 
         # parse tool calls and report errors
-        message = decode_assistant_message_from_content(buffer, stop_reason)
+        message = self.formatter.decode_assistant_message_from_content(
+            buffer, stop_reason
+        )
         parsed_tool_calls = len(message.tool_calls) > 0
         if ipython and not parsed_tool_calls:
             yield ChatCompletionResponseStreamChunk(
@@ -227,70 +227,3 @@ class TogetherInferenceAdapter(Inference):
                 stop_reason=stop_reason,
             )
         )
-
-
-# TODO: Consolidate this with impl in llama-models
-def decode_assistant_message_from_content(
-    content: str,
-    stop_reason: StopReason,
-) -> CompletionMessage:
-    ipython = content.startswith("<|python_tag|>")
-    if ipython:
-        content = content[len("<|python_tag|>") :]
-
-    if content.endswith("<|eot_id|>"):
-        content = content[: -len("<|eot_id|>")]
-        stop_reason = StopReason.end_of_turn
-    elif content.endswith("<|eom_id|>"):
-        content = content[: -len("<|eom_id|>")]
-        stop_reason = StopReason.end_of_message
-
-    tool_name = None
-    tool_arguments = {}
-
-    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
-    if custom_tool_info is not None:
-        tool_name, tool_arguments = custom_tool_info
-        # Sometimes when agent has custom tools alongside builin tools
-        # Agent responds for builtin tool calls in the format of the custom tools
-        # This code tries to handle that case
-        if tool_name in BuiltinTool.__members__:
-            tool_name = BuiltinTool[tool_name]
-            tool_arguments = {
-                "query": list(tool_arguments.values())[0],
-            }
-    else:
-        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
-        if builtin_tool_info is not None:
-            tool_name, query = builtin_tool_info
-            tool_arguments = {
-                "query": query,
-            }
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
-        elif ipython:
-            tool_name = BuiltinTool.code_interpreter
-            tool_arguments = {
-                "code": content,
-            }
-
-    tool_calls = []
-    if tool_name is not None and tool_arguments is not None:
-        call_id = str(uuid.uuid4())
-        tool_calls.append(
-            ToolCall(
-                call_id=call_id,
-                tool_name=tool_name,
-                arguments=tool_arguments,
-            )
-        )
-        content = ""
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
-    return CompletionMessage(
-        content=content,
-        stop_reason=stop_reason,
-        tool_calls=tool_calls,
-    )
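
For reference, the behavior encoded in the deleted helper, which the shared formatter is expected to preserve. These expectations are inferred from the removed lines above, not verified against llama-models, so treat this as a sanity-check sketch rather than a specification:

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import BuiltinTool, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer

formatter = ChatFormat(Tokenizer.get_instance())

# A "<|python_tag|>" prefix marked an ipython turn: the deleted helper emitted
# a code_interpreter ToolCall carrying the code, emptied the content, and
# mapped the trailing "<|eom_id|>" to StopReason.end_of_message.
msg = formatter.decode_assistant_message_from_content(
    "<|python_tag|>import math\nprint(math.pi)<|eom_id|>", None
)
assert msg.stop_reason == StopReason.end_of_message
assert msg.tool_calls and msg.tool_calls[0].tool_name == BuiltinTool.code_interpreter

# A bare "<|eot_id|>" suffix mapped to StopReason.end_of_turn with no tool calls,
# and a missing stop token fell back to StopReason.out_of_tokens.
msg = formatter.decode_assistant_message_from_content("All done.<|eot_id|>", None)
assert msg.stop_reason == StopReason.end_of_turn and not msg.tool_calls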