forked from phoenix-oss/llama-stack-mirror
fix: OpenAI Completions API and Fireworks (#1997)
# What does this PR do? We were passing a dict into the compat mixin for OpenAI Completions when using Llama models with Fireworks, and that was breaking some strong typing code that was added in openai_compat.py. We shouldn't have been converting these params to a dict in that case anyway, so this adjusts things to pass the params in as their actual original types when calling the OpenAIChatCompletionToLlamaStackMixin. ## Test Plan All of the fireworks provider verification tests were failing due to some OpenAI compatibility cleanup in #1962. The changes in that PR were good to make, and this just cleans up the fireworks provider code to stop passing in untyped dicts to some of those `openai_compat.py` methods since we have the original strongly-typed parameters we can pass in. ``` llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml ``` ``` python -m pytest -s -v tests/verifications/openai_api/test_chat_completion.py --provider=fireworks-llama-stack ``` Before this PR, all of the fireworks OpenAI verification tests were failing. Now, most of them are passing. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
0d06c654d0
commit
602e949a46
1 changed files with 33 additions and 7 deletions
|
@ -362,6 +362,39 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
@ -387,11 +420,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
user=user,
|
||||
)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
|
||||
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue