feat: use openai-python for openai inference provider (#2193)

# What does this PR do?

fixes #2121

this implementation splits reponsibility between litellm and openai
libraries -

 | Inference Method           | Implementation Source    |
 |----------------------------|--------------------------|
 | completion                 | LiteLLMOpenAIMixin       |
 | chat_completion            | LiteLLMOpenAIMixin       |
 | embedding                  | LiteLLMOpenAIMixin       |
 | batch_completion           | LiteLLMOpenAIMixin       |
 | batch_chat_completion      | LiteLLMOpenAIMixin       |
 | openai_completion          | AsyncOpenAI              |
 | openai_chat_completion     | AsyncOpenAI              |

## Test Plan

smoke test with -
```
$ OPENAI_API_KEY=$LLAMA_API_KEY OPENAI_BASE_URL=https://api.llama.com/compat/v1 llama stack build --image-type conda --image-name openai --providers inference=remote::openai --run

$ llama-stack-client models register Llama-4-Scout-17B-16E-Instruct-FP8

$ curl "http://localhost:8321/v1/openai/v1/chat/completions" -H "Content-Type: application/json" \ -d '{
      "model": "Llama-4-Scout-17B-16E-Instruct-FP8",
      "messages": [
        {"role": "user", "content": "Hello Llama! Can you give me a quick intro?"}
      ]
}'
{"id":"AmPwrrkc5JgVjejPdIPrpT2","choices":[{"finish_reason":"stop","index":0,"logprobs":{"content":null,"refusal":null},"message":{"content":"Hello! I'm Llama, a Meta-designed model that adapts to your conversational style. Whether you need quick answers, deep dives into ideas, or just want to vent, joke, or brainstorm—I'm here for it. What’s on your mind?","refusal":"","role":"assistant","annotations":null,"audio":null,"function_call":null,"tool_calls":null,"id":"AmPwrrkc5JgVjejPdIPrpT2"}}],"created":1747410061,"model":"Llama-4-Scout-17B-16E-Instruct-FP8","object":"chat.completions","service_tier":null,"system_fingerprint":null,"usage":{"completion_tokens":54,"prompt_tokens":22,"total_tokens":76,"completion_tokens_details":null,"prompt_tokens_details":null}}
```

and run full test suite.
This commit is contained in:
Matthew Farrellee 2025-05-16 15:57:56 -04:00 committed by GitHub
parent 953ccffca2
commit 64f8d4c3ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,12 +4,41 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
#
# This OpenAI adapter implements Inference methods using two clients -
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
# | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI |
# | openai_chat_completion | AsyncOpenAI |
#
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
@ -26,9 +55,113 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
self._openai_client = AsyncOpenAI(
api_key=self.config.api_key,
)
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._openai_client.completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._openai_client.chat.completions.create(**params)