forked from phoenix-oss/llama-stack-mirror
A little clean up for the Fireworks and Together adapters
This commit is contained in:
parent 225cd75074
commit 6ad7365676

2 changed files with 31 additions and 164 deletions
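
Both adapters get the same two changes: a shared ChatFormat/Tokenizer pair now handles assistant-message decoding (replacing a ~70-line helper that was copy-pasted into each file), and prepare_messages(request) is computed once per request instead of passing request.messages straight through. A minimal sketch of the shared pattern, with a placeholder adapter class standing in for the real ones:

    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer


    class ExampleAdapter:
        """Placeholder for FireworksInferenceAdapter / TogetherInferenceAdapter."""

        def __init__(self, config) -> None:
            self.config = config
            # One tokenizer-backed formatter per adapter instance; decoding is
            # delegated to llama-models instead of a local helper function.
            self.formatter = ChatFormat(Tokenizer.get_instance())

        def decode(self, content: str, stop_reason):
            # Was: decode_assistant_message_from_content(content, stop_reason)
            return self.formatter.decode_assistant_message_from_content(
                content, stop_reason
            )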
Fireworks adapter:

@@ -4,22 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import uuid
 from typing import AsyncGenerator
 
 from fireworks.client import Fireworks
 
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    CompletionMessage,
-    Message,
-    StopReason,
-    ToolCall,
-)
-from llama_models.llama3.api.tool_utils import ToolUtils
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
 from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import FireworksImplConfig
 
@@ -33,6 +28,8 @@ FIREWORKS_SUPPORTED_MODELS = {
 class FireworksInferenceAdapter(Inference):
     def __init__(self, config: FireworksImplConfig) -> None:
         self.config = config
+        tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> Fireworks:
@@ -80,6 +77,8 @@ class FireworksInferenceAdapter(Inference):
         return options
 
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        messages = prepare_messages(request)
+
         # accumulate sampling params and other options to pass to fireworks
         options = self.get_fireworks_chat_options(request)
         fireworks_model = self.resolve_fireworks_model(request.model)
@@ -87,7 +86,7 @@ class FireworksInferenceAdapter(Inference):
         if not request.stream:
             r = await self.client.chat.completions.acreate(
                 model=fireworks_model,
-                messages=self._messages_to_fireworks_messages(request.messages),
+                messages=self._messages_to_fireworks_messages(messages),
                 stream=False,
                 **options,
             )
@@ -98,10 +97,10 @@ class FireworksInferenceAdapter(Inference):
                 elif r.choices[0].finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
 
-            completion_message = decode_assistant_message_from_content(
-                r.choices[0].message.content,
-                stop_reason,
+            completion_message = self.formatter.decode_assistant_message_from_content(
+                r.choices[0].message.content, stop_reason
             )
 
             yield ChatCompletionResponse(
                 completion_message=completion_message,
                 logprobs=None,
@@ -120,7 +119,7 @@ class FireworksInferenceAdapter(Inference):
 
         async for chunk in self.client.chat.completions.acreate(
             model=fireworks_model,
-            messages=self._messages_to_fireworks_messages(request.messages),
+            messages=self._messages_to_fireworks_messages(messages),
             stream=True,
             **options,
         ):
@@ -187,7 +186,9 @@ class FireworksInferenceAdapter(Inference):
             )
 
         # parse tool calls and report errors
-        message = decode_assistant_message_from_content(buffer, stop_reason)
+        message = self.formatter.decode_assistant_message_from_content(
+            buffer, stop_reason
+        )
         parsed_tool_calls = len(message.tool_calls) > 0
         if ipython and not parsed_tool_calls:
             yield ChatCompletionResponseStreamChunk(
@@ -220,70 +221,3 @@ class FireworksInferenceAdapter(Inference):
                 stop_reason=stop_reason,
             )
         )
-
-
-# TODO: Consolidate this with impl in llama-models
-def decode_assistant_message_from_content(
-    content: str,
-    stop_reason: StopReason,
-) -> CompletionMessage:
-    ipython = content.startswith("<|python_tag|>")
-    if ipython:
-        content = content[len("<|python_tag|>") :]
-
-    if content.endswith("<|eot_id|>"):
-        content = content[: -len("<|eot_id|>")]
-        stop_reason = StopReason.end_of_turn
-    elif content.endswith("<|eom_id|>"):
-        content = content[: -len("<|eom_id|>")]
-        stop_reason = StopReason.end_of_message
-
-    tool_name = None
-    tool_arguments = {}
-
-    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
-    if custom_tool_info is not None:
-        tool_name, tool_arguments = custom_tool_info
-        # Sometimes when agent has custom tools alongside builin tools
-        # Agent responds for builtin tool calls in the format of the custom tools
-        # This code tries to handle that case
-        if tool_name in BuiltinTool.__members__:
-            tool_name = BuiltinTool[tool_name]
-            tool_arguments = {
-                "query": list(tool_arguments.values())[0],
-            }
-    else:
-        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
-        if builtin_tool_info is not None:
-            tool_name, query = builtin_tool_info
-            tool_arguments = {
-                "query": query,
-            }
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
-        elif ipython:
-            tool_name = BuiltinTool.code_interpreter
-            tool_arguments = {
-                "code": content,
-            }
-
-    tool_calls = []
-    if tool_name is not None and tool_arguments is not None:
-        call_id = str(uuid.uuid4())
-        tool_calls.append(
-            ToolCall(
-                call_id=call_id,
-                tool_name=tool_name,
-                arguments=tool_arguments,
-            )
-        )
-        content = ""
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
-    return CompletionMessage(
-        content=content,
-        stop_reason=stop_reason,
-        tool_calls=tool_calls,
-    )
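
The Together adapter below gets the identical treatment. For reference, the new request flow shared by both adapters, condensed into a standalone sketch (the converter and client here are stand-ins for the adapters' private helpers and provider SDK clients; the async acreate call mirrors the Fireworks path above, while Together's client is synchronous):

    from llama_toolchain.inference.prepare_messages import prepare_messages


    def to_provider_messages(messages):
        # Stand-in for _messages_to_fireworks_messages / _messages_to_together_messages.
        return [{"role": m.role, "content": m.content} for m in messages]


    async def send(client, model, request, options):
        # prepare_messages() runs once per request; the prepared list (not
        # request.messages) is what gets converted and sent, in both the
        # streaming and non-streaming branches.
        messages = prepare_messages(request)
        return await client.chat.completions.acreate(
            model=model,
            messages=to_provider_messages(messages),
            stream=False,
            **options,
        )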
Together adapter:

@@ -4,21 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import uuid
 from typing import AsyncGenerator
 
-from llama_models.llama3.api.datatypes import (
-    BuiltinTool,
-    CompletionMessage,
-    Message,
-    StopReason,
-    ToolCall,
-)
-from llama_models.llama3.api.tool_utils import ToolUtils
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 from together import Together
 
 from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.inference.prepare_messages import prepare_messages
 
 from .config import TogetherImplConfig
 
@@ -32,6 +28,8 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(Inference):
     def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config
+        tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(tokenizer)
 
     @property
     def client(self) -> Together:
@@ -82,12 +80,13 @@ class TogetherInferenceAdapter(Inference):
         # accumulate sampling params and other options to pass to together
         options = self.get_together_chat_options(request)
         together_model = self.resolve_together_model(request.model)
+        messages = prepare_messages(request)
 
         if not request.stream:
            # TODO: might need to add back an async here
             r = self.client.chat.completions.create(
                 model=together_model,
-                messages=self._messages_to_together_messages(request.messages),
+                messages=self._messages_to_together_messages(messages),
                 stream=False,
                 **options,
             )
@@ -101,9 +100,8 @@ class TogetherInferenceAdapter(Inference):
                 elif r.choices[0].finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
 
-            completion_message = decode_assistant_message_from_content(
-                r.choices[0].message.content,
-                stop_reason,
+            completion_message = self.formatter.decode_assistant_message_from_content(
+                r.choices[0].message.content, stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -123,7 +121,7 @@ class TogetherInferenceAdapter(Inference):
 
         for chunk in self.client.chat.completions.create(
             model=together_model,
-            messages=self._messages_to_together_messages(request.messages),
+            messages=self._messages_to_together_messages(messages),
             stream=True,
             **options,
         ):
@@ -194,7 +192,9 @@ class TogetherInferenceAdapter(Inference):
             )
 
         # parse tool calls and report errors
-        message = decode_assistant_message_from_content(buffer, stop_reason)
+        message = self.formatter.decode_assistant_message_from_content(
+            buffer, stop_reason
+        )
        parsed_tool_calls = len(message.tool_calls) > 0
        if ipython and not parsed_tool_calls:
            yield ChatCompletionResponseStreamChunk(
@@ -227,70 +227,3 @@ class TogetherInferenceAdapter(Inference):
                 stop_reason=stop_reason,
             )
         )
-
-
-# TODO: Consolidate this with impl in llama-models
-def decode_assistant_message_from_content(
-    content: str,
-    stop_reason: StopReason,
-) -> CompletionMessage:
-    ipython = content.startswith("<|python_tag|>")
-    if ipython:
-        content = content[len("<|python_tag|>") :]
-
-    if content.endswith("<|eot_id|>"):
-        content = content[: -len("<|eot_id|>")]
-        stop_reason = StopReason.end_of_turn
-    elif content.endswith("<|eom_id|>"):
-        content = content[: -len("<|eom_id|>")]
-        stop_reason = StopReason.end_of_message
-
-    tool_name = None
-    tool_arguments = {}
-
-    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
-    if custom_tool_info is not None:
-        tool_name, tool_arguments = custom_tool_info
-        # Sometimes when agent has custom tools alongside builin tools
-        # Agent responds for builtin tool calls in the format of the custom tools
-        # This code tries to handle that case
-        if tool_name in BuiltinTool.__members__:
-            tool_name = BuiltinTool[tool_name]
-            tool_arguments = {
-                "query": list(tool_arguments.values())[0],
-            }
-    else:
-        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
-        if builtin_tool_info is not None:
-            tool_name, query = builtin_tool_info
-            tool_arguments = {
-                "query": query,
-            }
-            if tool_name in BuiltinTool.__members__:
-                tool_name = BuiltinTool[tool_name]
-        elif ipython:
-            tool_name = BuiltinTool.code_interpreter
-            tool_arguments = {
-                "code": content,
-            }
-
-    tool_calls = []
-    if tool_name is not None and tool_arguments is not None:
-        call_id = str(uuid.uuid4())
-        tool_calls.append(
-            ToolCall(
-                call_id=call_id,
-                tool_name=tool_name,
-                arguments=tool_arguments,
-            )
-        )
-        content = ""
-
-    if stop_reason is None:
-        stop_reason = StopReason.out_of_tokens
-
-    return CompletionMessage(
-        content=content,
-        stop_reason=stop_reason,
-        tool_calls=tool_calls,
-    )
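
Net effect: the ~70-line decode_assistant_message_from_content() helper that was duplicated verbatim in both adapters (and carried a "TODO: Consolidate this with impl in llama-models") is deleted; both files now rely on the llama-models implementation via ChatFormat. A quick sanity check one could run, assuming a Llama 3 tokenizer is available locally:

    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.datatypes import StopReason
    from llama_models.llama3.api.tokenizer import Tokenizer

    formatter = ChatFormat(Tokenizer.get_instance())
    message = formatter.decode_assistant_message_from_content(
        "Hello!<|eot_id|>", StopReason.end_of_turn
    )
    # Expect (per the behavior of the deleted helper) the stop token stripped
    # and no tool calls for plain text content.
    print(message.content, message.stop_reason, message.tool_calls)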