forked from phoenix-oss/llama-stack-mirror
feat: OpenAI-Compatible models, completions, chat/completions (#1894)
# What does this PR do? This stubs in some OpenAI server-side compatibility with three new endpoints: /v1/openai/v1/models /v1/openai/v1/completions /v1/openai/v1/chat/completions This gives common inference apps using OpenAI clients the ability to talk to Llama Stack using an endpoint like http://localhost:8321/v1/openai/v1 . The two "v1" instances in there isn't awesome, but the thinking is that Llama Stack's API is v1 and then our OpenAI compatibility layer is compatible with OpenAI V1. And, some OpenAI clients implicitly assume the URL ends with "v1", so this gives maximum compatibility. The openai models endpoint is implemented in the routing layer, and just returns all the models Llama Stack knows about. The following providers should be working with the new OpenAI completions and chat/completions API: * remote::anthropic (untested) * remote::cerebras-openai-compat (untested) * remote::fireworks (tested) * remote::fireworks-openai-compat (untested) * remote::gemini (untested) * remote::groq-openai-compat (untested) * remote::nvidia (tested) * remote::ollama (tested) * remote::openai (untested) * remote::passthrough (untested) * remote::sambanova-openai-compat (untested) * remote::together (tested) * remote::together-openai-compat (untested) * remote::vllm (tested) The goal to support this for every inference provider - proxying directly to the provider's OpenAI endpoint for OpenAI-compatible providers. For providers that don't have an OpenAI-compatible API, we'll add a mixin to translate incoming OpenAI requests to Llama Stack inference requests and translate the Llama Stack inference responses to OpenAI responses. This is related to #1817 but is a bit larger in scope than just chat completions, as I have real use-cases that need the older completions API as well. ## Test Plan ### vLLM ``` VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct" ``` ### ollama ``` INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0" ``` ## Documentation Run a Llama Stack distribution that uses one of the providers mentioned in the list above. Then, use your favorite OpenAI client to send completion or chat completion requests with the base_url set to http://localhost:8321/v1/openai/v1 . Replace "localhost:8321" with the host and port of your Llama Stack server, if different. --------- Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
24d70cedca
commit
2b2db5fbda
27 changed files with 3265 additions and 20 deletions
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
|
@ -38,6 +39,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
@ -67,7 +69,10 @@ from .models import model_entries
|
|||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
class OllamaInferenceAdapter(
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.register_helper = ModelRegistryHelper(model_entries)
|
||||
self.url = url
|
||||
|
@ -76,6 +81,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
def client(self) -> AsyncClient:
|
||||
return AsyncClient(host=self.url)
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
try:
|
||||
|
@ -319,6 +328,115 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
return model
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError("Ollama does not support non-string prompts for completion")
|
||||
|
||||
model_obj = await self._get_model(model)
|
||||
params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"model": model_obj.provider_resource_id,
|
||||
"prompt": prompt,
|
||||
"best_of": best_of,
|
||||
"echo": echo,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"logit_bias": logit_bias,
|
||||
"logprobs": logprobs,
|
||||
"max_tokens": max_tokens,
|
||||
"n": n,
|
||||
"presence_penalty": presence_penalty,
|
||||
"seed": seed,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"stream_options": stream_options,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"user": user,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
return await self.openai_client.completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
model_obj = await self._get_model(model)
|
||||
params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"model": model_obj.provider_resource_id,
|
||||
"messages": messages,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"function_call": function_call,
|
||||
"functions": functions,
|
||||
"logit_bias": logit_bias,
|
||||
"logprobs": logprobs,
|
||||
"max_completion_tokens": max_completion_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"n": n,
|
||||
"parallel_tool_calls": parallel_tool_calls,
|
||||
"presence_penalty": presence_penalty,
|
||||
"response_format": response_format,
|
||||
"seed": seed,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
"stream_options": stream_options,
|
||||
"temperature": temperature,
|
||||
"tool_choice": tool_choice,
|
||||
"tools": tools,
|
||||
"top_logprobs": top_logprobs,
|
||||
"top_p": top_p,
|
||||
"user": user,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
return await self.openai_client.chat.completions.create(**params) # type: ignore
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue