clean up fireworks non-stream

This commit is contained in:
Xi Yan 2025-02-10 20:37:13 -08:00
parent c9ef88854c
commit 0f062a15ec
2 changed files with 5 additions and 15 deletions

View file

@ -226,7 +226,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self._get_client().chat.completions.acreate(**params)
else:
@ -235,10 +234,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
from rich.pretty import pprint
print("!! FIREWORKS STREAM PARAMS !!")
pprint(params)
async def _to_async_generator():
if "messages" in params:

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, Dict, List, Optional, Union
from llama_models.datatypes import (
@ -42,6 +42,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)
logger = logging.getLogger(__name__)
class OpenAICompatCompletionChoiceDelta(BaseModel):
content: str
@ -175,11 +177,6 @@ def process_chat_completion_response(
formatter: ChatFormat,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
from rich.pretty import pprint
print("!! FIREWORKS NON_STREAM !!")
pprint(response)
choice = response.choices[0]
raw_message = formatter.decode_assistant_message_from_content(
@ -191,21 +188,19 @@ def process_chat_completion_response(
# response from the model.
if raw_message.tool_calls:
if not request.tools:
print("Parsed tool calls but no tools provided in request")
raw_message.tool_calls = []
raw_message.content = text_from_choice(choice)
# check if tool_calls is provided in the request
elif request.tools:
# only return tool_calls if provided in the request
new_tool_calls = []
request_tools = {t.tool_name: t for t in request.tools}
for t in raw_message.tool_calls:
if t.tool_name in request_tools:
new_tool_calls.append(t)
else:
print(f"Tool {t.tool_name} not found in request tools")
logger.warning(f"Tool {t.tool_name} not found in request tools")
if len(new_tool_calls) < len(raw_message.tool_calls):
print("Some tool calls were not provided in the request")
raw_message.tool_calls = new_tool_calls
raw_message.content = text_from_choice(choice)