mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
clean up fireworks non-stream
parent c9ef88854c
commit 0f062a15ec

2 changed files with 5 additions and 15 deletions
@@ -226,7 +226,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)

         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
         else:
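For illustration, a runnable sketch of the endpoint dispatch this hunk shows, with the Fireworks client stubbed out. Treating completion.acreate as the non-chat fallback is an assumption inferred from the else: branch, not a line shown in the diff:

import asyncio
from types import SimpleNamespace

async def _fake_acreate(**params):
    # Stand-in for the Fireworks SDK call; just echoes the payload back.
    return {"echo": params}

# Stub mimicking the client surface the adapter uses (an assumption).
client = SimpleNamespace(
    chat=SimpleNamespace(completions=SimpleNamespace(acreate=_fake_acreate)),
    completion=SimpleNamespace(acreate=_fake_acreate),
)

async def dispatch(params: dict):
    # Chat-style payloads carry "messages"; prompt-style payloads do not.
    if "messages" in params:
        return await client.chat.completions.acreate(**params)
    return await client.completion.acreate(**params)

print(asyncio.run(dispatch({"messages": [{"role": "user", "content": "hi"}]})))
print(asyncio.run(dispatch({"prompt": "hi"})))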
@@ -235,10 +234,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-        from rich.pretty import pprint
-
-        print("!! FIREWORKS STREAM PARAMS !!")
-        pprint(params)

         async def _to_async_generator():
             if "messages" in params:
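The streaming twin keeps the same endpoint dispatch but wraps it in a nested async generator, as this hunk's _to_async_generator shows. A minimal sketch of that shape; it assumes the client's acreate returns an async iterator when params carries a stream flag, which this hunk does not show:

from typing import AsyncGenerator

async def _stream(params: dict, client) -> AsyncGenerator:
    async def _to_async_generator():
        # Same dispatch as the non-stream path; the stream flag inside
        # params is assumed, not shown in the hunk.
        if "messages" in params:
            stream = client.chat.completions.acreate(**params)
        else:
            stream = client.completion.acreate(**params)
        async for chunk in stream:
            yield chunk

    async for chunk in _to_async_generator():
        yield chunk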
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union

 from llama_models.datatypes import (
@@ -42,6 +42,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )

+logger = logging.getLogger(__name__)
+

 class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
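With a module-level logger in place, the warnings introduced further down become opt-in through standard logging configuration instead of unconditional prints. For example (the "llama_stack" ancestor logger name is an assumption based on getLogger(__name__) and the package layout):

import logging

# Because the module uses getLogger(__name__), its logger name tracks the
# package path; configuring the (assumed) "llama_stack" ancestor logger
# enables these warnings for every module under it.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("llama_stack").setLevel(logging.WARNING)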
@@ -175,11 +177,6 @@ def process_chat_completion_response(
     formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
-    from rich.pretty import pprint
-
-    print("!! FIREWORKS NON_STREAM !!")
-    pprint(response)
-
     choice = response.choices[0]

     raw_message = formatter.decode_assistant_message_from_content(
@@ -191,21 +188,19 @@ def process_chat_completion_response(
     # response from the model.
     if raw_message.tool_calls:
         if not request.tools:
-            print("Parsed tool calls but no tools provided in request")
             raw_message.tool_calls = []
             raw_message.content = text_from_choice(choice)
         # check if tool_calls is provided in the request
         elif request.tools:
             # only return tool_calls if provided in the request
             new_tool_calls = []
             request_tools = {t.tool_name: t for t in request.tools}
             for t in raw_message.tool_calls:
                 if t.tool_name in request_tools:
                     new_tool_calls.append(t)
                 else:
-                    print(f"Tool {t.tool_name} not found in request tools")
+                    logger.warning(f"Tool {t.tool_name} not found in request tools")

             if len(new_tool_calls) < len(raw_message.tool_calls):
-                print("Some tool calls were not provided in the request")
                 raw_message.tool_calls = new_tool_calls
                 raw_message.content = text_from_choice(choice)
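Net effect of this hunk: tool calls parsed out of the model's text survive only if the request actually declared them, and drops are reported through the logger rather than print. A self-contained sketch of that filtering pattern; ToolCall and ToolDefinition here are minimal stand-ins so the example runs on its own:

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class ToolCall:
    tool_name: str

@dataclass
class ToolDefinition:
    tool_name: str

def filter_tool_calls(parsed_calls, declared_tools):
    # Keep only calls whose names were declared in the request; warn on drops.
    allowed = {t.tool_name for t in declared_tools}
    kept = []
    for call in parsed_calls:
        if call.tool_name in allowed:
            kept.append(call)
        else:
            logger.warning(f"Tool {call.tool_name} not found in request tools")
    return kept

# Example: the model hallucinated "send_email", which was never declared.
calls = [ToolCall("get_weather"), ToolCall("send_email")]
tools = [ToolDefinition("get_weather")]
assert [c.tool_name for c in filter_tool_calls(calls, tools)] == ["get_weather"]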