Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-08 19:10:56 +00:00
Fix fireworks and update the test

Sadly, don't look for eom_id / eot_id, since providers don't return the last token.
parent bbd3a02615
commit dba7caf1d0

4 changed files with 37 additions and 37 deletions
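Note on the approach: because these providers stop streaming before the final token, the adapter can no longer inspect the last token id to distinguish eom_id from eot_id. A minimal sketch of the fallback this implies, assuming the "length" finish_reason that OpenAI-style completion APIs report for truncated generations and the StopReason type used elsewhere in llama-stack (import path is an assumption):

from llama_models.llama3.api.datatypes import StopReason  # assumed import path

def infer_stop_reason(finish_reason: str | None) -> StopReason:
    # "length" means the generation hit its max_tokens budget.
    if finish_reason == "length":
        return StopReason.out_of_tokens
    # Any other finish maps to a normal end of turn, since the
    # eom_id / eot_id stop tokens never arrive from the provider.
    return StopReason.end_of_turn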
@@ -27,6 +27,8 @@ FIREWORKS_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
     "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
     "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
+    "Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
+    "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
 }
@@ -36,8 +38,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
         )
         self.config = config
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)

     @property
     def client(self) -> Fireworks:
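The tokenizer moves onto self here because the completion path below needs it again: the encoded dialog has to be decoded back into a plain-text prompt before it can be sent to Fireworks.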
@@ -59,17 +61,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def _messages_to_fireworks_messages(self, messages: list[Message]) -> list:
-        fireworks_messages = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            fireworks_messages.append({"role": role, "content": message.content})
-
-        return fireworks_messages
-
     def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -102,15 +93,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         )

         messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+        # Fireworks always prepends with BOS
+        if prompt.startswith("<|begin_of_text|>"):
+            prompt = prompt[len("<|begin_of_text|>") :]
+
         # accumulate sampling params and other options to pass to fireworks
         options = self.get_fireworks_chat_options(request)
+        options.setdefault("max_tokens", 512)

         fireworks_model = self.map_to_provider_model(request.model)

         if not request.stream:
-            r = await self.client.chat.completions.acreate(
+            r = await self.client.completion.acreate(
                 model=fireworks_model,
-                messages=self._messages_to_fireworks_messages(messages),
+                prompt=prompt,
                 stream=False,
                 **options,
             )
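A standalone sketch of the new prompt construction, using the llama_models calls visible in the hunk above; the import paths and the UserMessage type are assumptions based on how llama-stack uses these classes elsewhere:

from llama_models.llama3.api.chat_format import ChatFormat  # assumed paths
from llama_models.llama3.api.datatypes import UserMessage
from llama_models.llama3.api.tokenizer import Tokenizer

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

# Encode the dialog to token ids, then decode back to a plain-text prompt
# suitable for a raw completion endpoint.
dialog = [UserMessage(content="What is the capital of France?")]
model_input = formatter.encode_dialog_prompt(dialog)
prompt = tokenizer.decode(model_input.tokens)

# Fireworks prepends BOS itself, so drop the one the tokenizer added
# to avoid sending a doubled <|begin_of_text|> marker.
if prompt.startswith("<|begin_of_text|>"):
    prompt = prompt[len("<|begin_of_text|>") :]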
@@ -122,7 +120,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
                 stop_reason = StopReason.out_of_tokens

             completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].message.content, stop_reason
+                r.choices[0].text, stop_reason
             )

             yield ChatCompletionResponse(
@@ -141,9 +139,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         ipython = False
         stop_reason = None

-        async for chunk in self.client.chat.completions.acreate(
+        async for chunk in self.client.completion.acreate(
             model=fireworks_model,
-            messages=self._messages_to_fireworks_messages(messages),
+            prompt=prompt,
             stream=True,
             **options,
         ):
@@ -157,7 +155,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
                 stop_reason = StopReason.out_of_tokens
                 break

-            text = chunk.choices[0].delta.content
+            text = chunk.choices[0].text
             if text is None:
                 continue
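With the raw completion API, streamed chunks expose their text at choices[0].text directly; there is no chat-style delta wrapper. The text field can apparently still be None on some chunks, so the existing continue guard is kept.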
@@ -64,17 +64,6 @@ class TogetherInferenceAdapter(
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def _messages_to_together_messages(self, messages: list[Message]) -> list:
-        together_messages = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            together_messages.append({"role": role, "content": message.content})
-
-        return together_messages
-
     def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
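The Together hunk above removes the parallel role-mapping helper, presumably for the same reason: once prompts are pre-formatted and sent to a completion endpoint, the chat-message conversion shim has no callers left.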