Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-02 16:54:42 +00:00
This commit is contained in:
parent 94645dd5f6
commit a2be32c27d
2 changed files with 19 additions and 11 deletions
@@ -32,6 +32,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
     request_has_media,
 )
+from ..nvidia.openai_utils import _convert_tooldef_to_openai_tool, convert_openai_chat_completion_choice

 from .config import FireworksImplConfig

@@ -209,7 +210,9 @@ class FireworksInferenceAdapter(
    ) -> ChatCompletionResponse:
        params = await self._get_params(request)
        if "messages" in params:
            print(params)
            r = await self._get_client().chat.completions.acreate(**params)
            return convert_openai_chat_completion_choice(r.choices[0])
        else:
            r = await self._get_client().completion.acreate(**params)
            return process_chat_completion_response(r, self.formatter)
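The branch above keys off whether _get_params produced an OpenAI-style messages payload or a flat prompt string. As a rough illustration (the dict contents and model id below are invented, not taken from this commit), the two shapes it dispatches on look like:

# Illustrative only: hypothetical param dicts of the kind _get_params might return.
chat_style_params = {
    "model": "fireworks-model-id",  # placeholder, not a real model name
    "messages": [{"role": "user", "content": "Describe this image."}],
}
prompt_style_params = {
    "model": "fireworks-model-id",  # placeholder, not a real model name
    "prompt": "<|begin_of_text|>...",  # formatter-rendered prompt text
}
# "messages" in params -> chat.completions.acreate(**params), converted with
#                         convert_openai_chat_completion_choice
# otherwise            -> completion.acreate(**params), processed with
#                         process_chat_completion_response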
@@ -240,14 +243,18 @@ class FireworksInferenceAdapter(
        media_present = request_has_media(request)

        if isinstance(request, ChatCompletionRequest):
            if media_present:
                input_dict["messages"] = [
                    await convert_message_to_openai_dict(m) for m in request.messages
                ]
            else:
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request, self.get_llama_model(request.model), self.formatter
                )
            # print(input_dict["messages"])
            if request.tool_choice == ToolChoice.required:
                input_dict["tool_choice"] = "any"

            if request.tools:
                input_dict["tools"] = [
                    _convert_tooldef_to_openai_tool(t) for t in request.tools
                ]
            # print(input_dict)
        else:
            assert (
                not media_present
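The new tools handling relies on the nvidia helper to translate llama-stack tool definitions into OpenAI-style function tools, and it maps ToolChoice.required onto Fireworks' "any" tool_choice value. A hand-written sketch of the shape such an entry is expected to take (the weather tool itself is invented for illustration):

# Illustrative only: the kind of entry input_dict["tools"] is expected to hold
# after _convert_tooldef_to_openai_tool runs on a hypothetical weather tool.
example_tool_entry = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City to look up"},
            },
            "required": ["city"],
        },
    },
}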
@@ -10,12 +10,13 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
-    CompletionMessage,
+    # CompletionMessage,
     StopReason,
-    TokenLogProbs,
+    # TokenLogProbs,
     ToolCall,
     ToolDefinition,
 )
+from llama_stack.apis.inference import CompletionMessage, TokenLogProbs
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@@ -339,7 +340,7 @@ def _convert_openai_tool_calls(


 def _convert_openai_logprobs(
     logprobs: OpenAIChoiceLogprobs,
-) -> Optional[List[TokenLogProbs]]:
+) -> Optional[List[Any]]:
     """
     Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
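With the annotation loosened to Optional[List[Any]], the conversion the docstring describes can still be sketched against the OpenAI logprobs layout. The sketch below assumes logprobs.content is a list of per-token entries carrying top_logprobs with .token and .logprob fields, and that TokenLogProbs accepts a logprobs_by_token mapping; it is not the body from this commit.

# Minimal sketch, not the committed implementation; assumes the file's existing
# imports (Any, List, Optional from typing, TokenLogProbs from
# llama_stack.apis.inference, OpenAIChoiceLogprobs from the openai package).
def _convert_openai_logprobs_sketch(logprobs):
    if logprobs is None or not logprobs.content:
        return None
    # One entry per generated token: map each candidate token to its log-probability.
    return [
        TokenLogProbs(
            logprobs_by_token={top.token: top.logprob for top in entry.top_logprobs}
        )
        for entry in logprobs.content
    ]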