From 64f8d4c3adeddbad8133d2b7188c394c3e87f3ca Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 16 May 2025 15:57:56 -0400
Subject: [PATCH] feat: use openai-python for openai inference provider (#2193)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

fixes #2121

this implementation splits responsibility between the litellm and openai libraries -

| Inference Method           | Implementation Source    |
|----------------------------|--------------------------|
| completion                 | LiteLLMOpenAIMixin       |
| chat_completion            | LiteLLMOpenAIMixin       |
| embedding                  | LiteLLMOpenAIMixin       |
| batch_completion           | LiteLLMOpenAIMixin       |
| batch_chat_completion      | LiteLLMOpenAIMixin       |
| openai_completion          | AsyncOpenAI              |
| openai_chat_completion     | AsyncOpenAI              |

## Test Plan

smoke test with -

```
$ OPENAI_API_KEY=$LLAMA_API_KEY OPENAI_BASE_URL=https://api.llama.com/compat/v1 llama stack build --image-type conda --image-name openai --providers inference=remote::openai --run

$ llama-stack-client models register Llama-4-Scout-17B-16E-Instruct-FP8

$ curl "http://localhost:8321/v1/openai/v1/chat/completions" -H "Content-Type: application/json" \
 -d '{
   "model": "Llama-4-Scout-17B-16E-Instruct-FP8",
   "messages": [
     {"role": "user", "content": "Hello Llama! Can you give me a quick intro?"}
   ]
 }'
{"id":"AmPwrrkc5JgVjejPdIPrpT2","choices":[{"finish_reason":"stop","index":0,"logprobs":{"content":null,"refusal":null},"message":{"content":"Hello! I'm Llama, a Meta-designed model that adapts to your conversational style. Whether you need quick answers, deep dives into ideas, or just want to vent, joke, or brainstorm—I'm here for it. What’s on your mind?","refusal":"","role":"assistant","annotations":null,"audio":null,"function_call":null,"tool_calls":null,"id":"AmPwrrkc5JgVjejPdIPrpT2"}}],"created":1747410061,"model":"Llama-4-Scout-17B-16E-Instruct-FP8","object":"chat.completions","service_tier":null,"system_fingerprint":null,"usage":{"completion_tokens":54,"prompt_tokens":22,"total_tokens":76,"completion_tokens_details":null,"prompt_tokens_details":null}}
```

and run full test suite.
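The same request can also be driven through the `openai` Python client pointed at the stack's OpenAI-compatible base path. A minimal sketch of that variant of the smoke test; the `api_key` value is a placeholder assumption, since the curl call above sends no Authorization header (adjust if your deployment enforces auth):

```python
from openai import OpenAI

# Point the client at the stack's OpenAI-compatible path used by the curl call above.
# The api_key value is a placeholder; the local smoke test did not require one.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.chat.completions.create(
    model="Llama-4-Scout-17B-16E-Instruct-FP8",
    messages=[{"role": "user", "content": "Hello Llama! Can you give me a quick intro?"}],
)
print(response.choices[0].message.content)
```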
---
 .../remote/inference/openai/openai.py         | 133 ++++++++++++++++++
 1 file changed, 133 insertions(+)

diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 76218e87e..9a1ec7ee0 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -4,12 +4,41 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import logging
+from collections.abc import AsyncIterator
+from typing import Any
+
+from openai import AsyncOpenAI
+
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
 
 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES
 
+logger = logging.getLogger(__name__)
+
+#
+# This OpenAI adapter implements Inference methods using two clients -
+#
+# | Inference Method           | Implementation Source    |
+# |----------------------------|--------------------------|
+# | completion                 | LiteLLMOpenAIMixin       |
+# | chat_completion            | LiteLLMOpenAIMixin       |
+# | embedding                  | LiteLLMOpenAIMixin       |
+# | batch_completion           | LiteLLMOpenAIMixin       |
+# | batch_chat_completion      | LiteLLMOpenAIMixin       |
+# | openai_completion          | AsyncOpenAI              |
+# | openai_chat_completion     | AsyncOpenAI              |
+#
 
 class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
     def __init__(self, config: OpenAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
@@ -26,9 +55,113 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # if we do not set this, users will be exposed to the
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
+        self._openai_client = AsyncOpenAI(
+            api_key=self.config.api_key,
+        )
 
     async def initialize(self) -> None:
         await super().initialize()
 
     async def shutdown(self) -> None:
         await super().shutdown()
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
+    ) -> OpenAICompletion:
+        if guided_choice is not None:
+            logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
+        if prompt_logprobs is not None:
+            logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
Ignoring.") + + params = await prepare_openai_completion_params( + model=(await self.model_store.get_model(model)).provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + return await self._openai_client.completions.create(**params) + + async def openai_chat_completion( + self, + model: str, + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + params = await prepare_openai_completion_params( + model=(await self.model_store.get_model(model)).provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + return await self._openai_client.chat.completions.create(**params)