diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 76218e87e..9a1ec7ee0 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -4,12 +4,41 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import logging
+from collections.abc import AsyncIterator
+from typing import Any
+
+from openai import AsyncOpenAI
+
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
 
 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES
 
+logger = logging.getLogger(__name__)
+
+#
+# This OpenAI adapter implements Inference methods using two clients -
+#
+# | Inference Method           | Implementation Source    |
+# |----------------------------|--------------------------|
+# | completion                 | LiteLLMOpenAIMixin       |
+# | chat_completion            | LiteLLMOpenAIMixin       |
+# | embedding                  | LiteLLMOpenAIMixin       |
+# | batch_completion           | LiteLLMOpenAIMixin       |
+# | batch_chat_completion      | LiteLLMOpenAIMixin       |
+# | openai_completion          | AsyncOpenAI              |
+# | openai_chat_completion     | AsyncOpenAI              |
+#
 
 
 class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
     def __init__(self, config: OpenAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
@@ -26,9 +55,113 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # if we do not set this, users will be exposed to the
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
+        self._openai_client = AsyncOpenAI(
+            api_key=self.config.api_key,
+        )
 
     async def initialize(self) -> None:
         await super().initialize()
 
     async def shutdown(self) -> None:
         await super().shutdown()
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: str | list[str] | list[int] | list[list[int]],
+        best_of: int | None = None,
+        echo: bool | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        presence_penalty: float | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+        guided_choice: list[str] | None = None,
+        prompt_logprobs: int | None = None,
+    ) -> OpenAICompletion:
+        if guided_choice is not None:
+            logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
+        if prompt_logprobs is not None:
+            logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
+
+        params = await prepare_openai_completion_params(
+            model=(await self.model_store.get_model(model)).provider_resource_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._openai_client.completions.create(**params)
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        params = await prepare_openai_completion_params(
+            model=(await self.model_store.get_model(model)).provider_resource_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+        return await self._openai_client.chat.completions.create(**params)