This commit is contained in:
Aidan Do 2024-12-19 20:13:29 +11:00
parent 94645dd5f6
commit a2be32c27d
2 changed files with 19 additions and 11 deletions

View file

@ -32,6 +32,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
request_has_media,
)
from ..nvidia.openai_utils import _convert_tooldef_to_openai_tool, convert_openai_chat_completion_choice
from .config import FireworksImplConfig
@ -209,7 +210,9 @@ class FireworksInferenceAdapter(
) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
print(params)
r = await self._get_client().chat.completions.acreate(**params)
return convert_openai_chat_completion_choice(r.choices[0])
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)
@ -240,14 +243,18 @@ class FireworksInferenceAdapter(
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
input_dict["messages"] = [
await convert_message_to_openai_dict(m) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
# print(input_dict["messages"])
if request.tool_choice == ToolChoice.required:
input_dict["tool_choice"] = "any"
if request.tools:
input_dict["tools"] = [
_convert_tooldef_to_openai_tool(t) for t in request.tools
]
# print(input_dict)
else:
assert (
not media_present

View file

@ -10,12 +10,13 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
# CompletionMessage,
StopReason,
TokenLogProbs,
# TokenLogProbs,
ToolCall,
ToolDefinition,
)
from llama_stack.apis.inference import CompletionMessage, TokenLogProbs
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@ -339,7 +340,7 @@ def _convert_openai_tool_calls(
def _convert_openai_logprobs(
logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
) -> Optional[List[Any]]:
"""
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.