From 0f062a15ec0068a1e3876ef19e2eb1e1fa9435f4 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Mon, 10 Feb 2025 20:37:13 -0800
Subject: [PATCH] clean up fireworks non-stream

---
 .../remote/inference/fireworks/fireworks.py       |  5 -----
 .../providers/utils/inference/openai_compat.py    | 15 +++++----------
 2 files changed, 5 insertions(+), 15 deletions(-)

diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 186036486..44666fa70 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -226,7 +226,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-
         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
         else:
@@ -235,10 +234,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
-        from rich.pretty import pprint
-
-        print("!! FIREWORKS STREAM PARAMS !!")
-        pprint(params)
 
         async def _to_async_generator():
             if "messages" in params:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 3eb0977d5..066fda2c1 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union
 
 from llama_models.datatypes import (
@@ -42,6 +42,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class OpenAICompatCompletionChoiceDelta(BaseModel):
     content: str
@@ -175,11 +177,6 @@ def process_chat_completion_response(
     formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
-    from rich.pretty import pprint
-
-    print("!! FIREWORKS NON_STREAM !!")
-    pprint(response)
-
     choice = response.choices[0]
 
     raw_message = formatter.decode_assistant_message_from_content(
@@ -191,21 +188,19 @@
     # response from the model.
     if raw_message.tool_calls:
         if not request.tools:
-            print("Parsed tool calls but no tools provided in request")
             raw_message.tool_calls = []
             raw_message.content = text_from_choice(choice)
-        # check if tool_calls is provided in the request
         elif request.tools:
+            # only return tool_calls if provided in the request
             new_tool_calls = []
             request_tools = {t.tool_name: t for t in request.tools}
             for t in raw_message.tool_calls:
                 if t.tool_name in request_tools:
                     new_tool_calls.append(t)
                 else:
-                    print(f"Tool {t.tool_name} not found in request tools")
+                    logger.warning(f"Tool {t.tool_name} not found in request tools")
 
             if len(new_tool_calls) < len(raw_message.tool_calls):
-                print("Some tool calls were not provided in the request")
                 raw_message.tool_calls = new_tool_calls
                 raw_message.content = text_from_choice(choice)
 
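
The substance of this cleanup is swapping ad-hoc print/pprint debugging for a
module-level stdlib logger. A minimal, self-contained sketch of that pattern
follows, using only the Python standard library; filter_tool_calls and the
string tool names are illustrative stand-ins, not code from this patch.

import logging

# Module-level logger, as introduced in openai_compat.py above.
logger = logging.getLogger(__name__)


def filter_tool_calls(parsed_names: list[str], request_tools: dict) -> list[str]:
    # Keep only tool calls whose names were offered in the request,
    # mirroring the filtering loop in process_chat_completion_response.
    kept = []
    for name in parsed_names:
        if name in request_tools:
            kept.append(name)
        else:
            # logger.warning goes through configured handlers and levels,
            # unlike the bare print() calls this patch removes.
            logger.warning(f"Tool {name} not found in request tools")
    return kept


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    print(filter_tool_calls(["web_search", "calculator"], {"web_search": None}))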