diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 975ec4893..0df05d8c8 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -32,6 +32,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
     request_has_media,
 )
+from ..nvidia.openai_utils import _convert_tooldef_to_openai_tool, convert_openai_chat_completion_choice
 
 from .config import FireworksImplConfig
 
@@ -209,10 +210,12 @@ class FireworksInferenceAdapter(
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
+            print(params)
             r = await self._get_client().chat.completions.acreate(**params)
+            return convert_openai_chat_completion_choice(r.choices[0])
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter)
+            return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -240,14 +243,18 @@ class FireworksInferenceAdapter(
         media_present = request_has_media(request)
 
         if isinstance(request, ChatCompletionRequest):
-            if media_present:
-                input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m) for m in request.messages
+            input_dict["messages"] = [
+                await convert_message_to_openai_dict(m) for m in request.messages
+            ]
+            # print(input_dict["messages"])
+            if request.tool_choice == ToolChoice.required:
+                input_dict["tool_choice"] = "any"
+
+            if request.tools:
+                input_dict["tools"] = [
+                    _convert_tooldef_to_openai_tool(t) for t in request.tools
                 ]
-            else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
-                )
+            # print(input_dict)
         else:
             assert (
                 not media_present
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index ba8ff0fa4..1220f7ffa 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -10,12 +10,13 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
 
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
-    CompletionMessage,
+    # CompletionMessage,
     StopReason,
-    TokenLogProbs,
+    # TokenLogProbs,
     ToolCall,
     ToolDefinition,
 )
+from llama_stack.apis.inference import CompletionMessage, TokenLogProbs
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@@ -339,7 +340,7 @@ def _convert_openai_tool_calls(
 
 def _convert_openai_logprobs(
     logprobs: OpenAIChoiceLogprobs,
-) -> Optional[List[TokenLogProbs]]:
+) -> Optional[List[Any]]:
     """
     Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.