forked from phoenix-oss/llama-stack-mirror
A little clean up for the Fireworks and Together adapters
This commit is contained in:
parent
225cd75074
commit
6ad7365676
2 changed files with 31 additions and 164 deletions
|
@ -4,22 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
BuiltinTool,
|
||||
CompletionMessage,
|
||||
Message,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
)
|
||||
from llama_models.llama3.api.tool_utils import ToolUtils
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_toolchain.inference.api import * # noqa: F403
|
||||
from llama_toolchain.inference.prepare_messages import prepare_messages
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
@ -33,6 +28,8 @@ FIREWORKS_SUPPORTED_MODELS = {
|
|||
class FireworksInferenceAdapter(Inference):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> Fireworks:
|
||||
|
@ -80,6 +77,8 @@ class FireworksInferenceAdapter(Inference):
|
|||
return options
|
||||
|
||||
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
messages = prepare_messages(request)
|
||||
|
||||
# accumulate sampling params and other options to pass to fireworks
|
||||
options = self.get_fireworks_chat_options(request)
|
||||
fireworks_model = self.resolve_fireworks_model(request.model)
|
||||
|
@ -87,7 +86,7 @@ class FireworksInferenceAdapter(Inference):
|
|||
if not request.stream:
|
||||
r = await self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(request.messages),
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
|
@ -98,10 +97,10 @@ class FireworksInferenceAdapter(Inference):
|
|||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = decode_assistant_message_from_content(
|
||||
r.choices[0].message.content,
|
||||
stop_reason,
|
||||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r.choices[0].message.content, stop_reason
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
|
@ -120,7 +119,7 @@ class FireworksInferenceAdapter(Inference):
|
|||
|
||||
async for chunk in self.client.chat.completions.acreate(
|
||||
model=fireworks_model,
|
||||
messages=self._messages_to_fireworks_messages(request.messages),
|
||||
messages=self._messages_to_fireworks_messages(messages),
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
|
@ -187,7 +186,9 @@ class FireworksInferenceAdapter(Inference):
|
|||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = decode_assistant_message_from_content(buffer, stop_reason)
|
||||
message = self.formatter.decode_assistant_message_from_content(
|
||||
buffer, stop_reason
|
||||
)
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
|
@ -220,70 +221,3 @@ class FireworksInferenceAdapter(Inference):
|
|||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# TODO: Consolidate this with impl in llama-models
|
||||
def decode_assistant_message_from_content(
|
||||
content: str,
|
||||
stop_reason: StopReason,
|
||||
) -> CompletionMessage:
|
||||
ipython = content.startswith("<|python_tag|>")
|
||||
if ipython:
|
||||
content = content[len("<|python_tag|>") :]
|
||||
|
||||
if content.endswith("<|eot_id|>"):
|
||||
content = content[: -len("<|eot_id|>")]
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif content.endswith("<|eom_id|>"):
|
||||
content = content[: -len("<|eom_id|>")]
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
tool_name = None
|
||||
tool_arguments = {}
|
||||
|
||||
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||
if custom_tool_info is not None:
|
||||
tool_name, tool_arguments = custom_tool_info
|
||||
# Sometimes when agent has custom tools alongside builin tools
|
||||
# Agent responds for builtin tool calls in the format of the custom tools
|
||||
# This code tries to handle that case
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
}
|
||||
else:
|
||||
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||
if builtin_tool_info is not None:
|
||||
tool_name, query = builtin_tool_info
|
||||
tool_arguments = {
|
||||
"query": query,
|
||||
}
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
elif ipython:
|
||||
tool_name = BuiltinTool.code_interpreter
|
||||
tool_arguments = {
|
||||
"code": content,
|
||||
}
|
||||
|
||||
tool_calls = []
|
||||
if tool_name is not None and tool_arguments is not None:
|
||||
call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
return CompletionMessage(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue