# What does this PR do?

This stubs in some OpenAI server-side compatibility with three new endpoints:

* /v1/openai/v1/models
* /v1/openai/v1/completions
* /v1/openai/v1/chat/completions

This gives common inference apps using OpenAI clients the ability to talk to Llama Stack using an endpoint like http://localhost:8321/v1/openai/v1. The two "v1" instances in there aren't awesome, but the thinking is that Llama Stack's API is v1 and our OpenAI compatibility layer is compatible with OpenAI v1. Also, some OpenAI clients implicitly assume the URL ends with "v1", so this gives maximum compatibility.

The OpenAI models endpoint is implemented in the routing layer and simply returns all the models Llama Stack knows about.

The following providers should work with the new OpenAI completions and chat/completions APIs:

* remote::anthropic (untested)
* remote::cerebras-openai-compat (untested)
* remote::fireworks (tested)
* remote::fireworks-openai-compat (untested)
* remote::gemini (untested)
* remote::groq-openai-compat (untested)
* remote::nvidia (tested)
* remote::ollama (tested)
* remote::openai (untested)
* remote::passthrough (untested)
* remote::sambanova-openai-compat (untested)
* remote::together (tested)
* remote::together-openai-compat (untested)
* remote::vllm (tested)

The goal is to support this for every inference provider: proxying directly to the provider's OpenAI endpoint for OpenAI-compatible providers, and, for providers without an OpenAI-compatible API, adding a mixin that translates incoming OpenAI requests into Llama Stack inference requests and translates the Llama Stack inference responses back into OpenAI responses.

This is related to #1817 but is a bit larger in scope than just chat completions, as I have real use-cases that need the older completions API as well.

## Test Plan

### vLLM

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run

LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

### ollama

```
INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run

LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0"
```

## Documentation

Run a Llama Stack distribution that uses one of the providers listed above. Then, use your favorite OpenAI client to send completion or chat completion requests with base_url set to http://localhost:8321/v1/openai/v1 (see the Python sketch after this description). Replace "localhost:8321" with the host and port of your Llama Stack server, if different.

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
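For concreteness, here is a minimal sketch of that flow using the official `openai` Python package. The host/port and the model id (`meta-llama/Llama-3.2-3B-Instruct`) are assumptions carried over from the test plan above; substitute whatever model your distribution actually serves.

```python
from openai import OpenAI

# Point the stock OpenAI client at Llama Stack's OpenAI-compatible base URL.
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="none",  # placeholder; supply real credentials only if your deployment requires them
)

# /v1/openai/v1/models -- list every model Llama Stack knows about
for model in client.models.list():
    print(model.id)

# /v1/openai/v1/chat/completions
chat = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",  # assumed model id
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(chat.choices[0].message.content)

# /v1/openai/v1/completions (the older completions API)
completion = client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",  # assumed model id
    prompt="The capital of France is",
    max_tokens=10,
)
print(completion.choices[0].text)
```

The same base_url works with any OpenAI-compatible client; nothing here is specific to the Python SDK.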
459 lines · 17 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import httpx
from ollama import AsyncClient
from openai import AsyncOpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GrammarResponseFormat,
    Inference,
    JsonSchemaResponseFormat,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
    content_has_media,
    convert_image_content_to_url,
    interleaved_content_as_str,
    request_has_media,
)

from .models import model_entries

logger = get_logger(name=__name__, category="inference")

class OllamaInferenceAdapter(
    Inference,
    ModelsProtocolPrivate,
):
    def __init__(self, url: str) -> None:
        self.register_helper = ModelRegistryHelper(model_entries)
        self.url = url

    @property
    def client(self) -> AsyncClient:
        return AsyncClient(host=self.url)

    @property
    def openai_client(self) -> AsyncOpenAI:
        # Ollama exposes an OpenAI-compatible API under /v1; the openai_* methods below proxy to it directly.
        return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.url}`...")
        try:
            await self.client.ps()
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e

    async def shutdown(self) -> None:
        pass

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def _get_model(self, model_id: str) -> Model:
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _stream_completion(
        self, request: CompletionRequest
    ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.generate(**params)
            async for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
                    text=chunk["response"],
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        params = await self._get_params(request)
        r = await self.client.generate(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r["done_reason"] if r["done"] else None,
            text=r["response"],
        )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )

        return process_completion_response(response)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            response_format=response_format,
            tool_config=tool_config,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        sampling_options = get_sampling_options(request.sampling_params)
        # This is needed since the Ollama API expects num_predict to be set
        # for early truncation instead of max_tokens.
        if sampling_options.get("max_tokens") is not None:
            sampling_options["num_predict"] = sampling_options["max_tokens"]

        input_dict: dict[str, Any] = {}
        media_present = request_has_media(request)
        llama_model = self.register_helper.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            if media_present or not llama_model:
                contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
                # flatten the list of lists
                input_dict["messages"] = [item for sublist in contents for item in sublist]
            else:
                input_dict["raw"] = True
                input_dict["prompt"] = await chat_completion_request_to_prompt(
                    request,
                    llama_model,
                )
        else:
            assert not media_present, "Ollama does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)
            input_dict["raw"] = True

        if fmt := request.response_format:
            if isinstance(fmt, JsonSchemaResponseFormat):
                input_dict["format"] = fmt.json_schema
            elif isinstance(fmt, GrammarResponseFormat):
                raise NotImplementedError("Grammar response format is not supported")
            else:
                raise ValueError(f"Unknown response format type: {fmt.type}")

        params = {
            "model": request.model,
            **input_dict,
            "options": sampling_options,
            "stream": request.stream,
        }
        logger.debug(f"params to ollama: {params}")

        return params

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        # _get_params only sets "messages" when the request maps onto Ollama's chat API;
        # otherwise fall back to raw generate with a pre-rendered prompt.
        if "messages" in params:
            r = await self.client.chat(**params)
        else:
            r = await self.client.generate(**params)

        if "message" in r:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["message"]["content"],
            )
        else:
            choice = OpenAICompatCompletionChoice(
                finish_reason=r["done_reason"] if r["done"] else None,
                text=r["response"],
            )
        response = OpenAICompatCompletionResponse(
            choices=[choice],
        )
        return process_chat_completion_response(response, request)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            if "messages" in params:
                s = await self.client.chat(**params)
            else:
                s = await self.client.generate(**params)
            async for chunk in s:
                if "message" in chunk:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["message"]["content"],
                    )
                else:
                    choice = OpenAICompatCompletionChoice(
                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
                        text=chunk["response"],
                    )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        model = await self._get_model(model_id)

        assert all(not content_has_media(content) for content in contents), (
            "Ollama does not support media for embeddings"
        )
        response = await self.client.embed(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(content) for content in contents],
        )
        embeddings = response["embeddings"]

        return EmbeddingsResponse(embeddings=embeddings)

    async def register_model(self, model: Model) -> Model:
        model = await self.register_helper.register_model(model)
        if model.model_type == ModelType.embedding:
            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
            await self.client.pull(model.provider_resource_id)
        # we use list() here instead of ps() -
        #  - ps() only lists running models, not available models
        #  - models not currently running are run by the ollama server as needed
        response = await self.client.list()
        available_models = [m["model"] for m in response["models"]]
        if model.provider_resource_id not in available_models:
            raise ValueError(
                f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
            )

        return model

    async def openai_completion(
        self,
        model: str,
        prompt: Union[str, List[str], List[int], List[List[int]]],
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        guided_choice: Optional[List[str]] = None,
        prompt_logprobs: Optional[int] = None,
    ) -> OpenAICompletion:
        if not isinstance(prompt, str):
            raise ValueError("Ollama does not support non-string prompts for completion")

        model_obj = await self._get_model(model)
        # Drop unset parameters and forward the rest verbatim to Ollama's OpenAI-compatible
        # completions endpoint. guided_choice and prompt_logprobs are accepted for API
        # compatibility but are not passed through.
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "prompt": prompt,
                "best_of": best_of,
                "echo": echo,
                "frequency_penalty": frequency_penalty,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_tokens": max_tokens,
                "n": n,
                "presence_penalty": presence_penalty,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.completions.create(**params)  # type: ignore

    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Dict[str, str]] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> OpenAIChatCompletion:
        model_obj = await self._get_model(model)
        # Drop unset parameters and forward the rest verbatim to Ollama's OpenAI-compatible
        # chat completions endpoint.
        params = {
            k: v
            for k, v in {
                "model": model_obj.provider_resource_id,
                "messages": messages,
                "frequency_penalty": frequency_penalty,
                "function_call": function_call,
                "functions": functions,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_completion_tokens": max_completion_tokens,
                "max_tokens": max_tokens,
                "n": n,
                "parallel_tool_calls": parallel_tool_calls,
                "presence_penalty": presence_penalty,
                "response_format": response_format,
                "seed": seed,
                "stop": stop,
                "stream": stream,
                "stream_options": stream_options,
                "temperature": temperature,
                "tool_choice": tool_choice,
                "tools": tools,
                "top_logprobs": top_logprobs,
                "top_p": top_p,
                "user": user,
            }.items()
            if v is not None
        }
        return await self.openai_client.chat.completions.create(**params)  # type: ignore

async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
    async def _convert_content(content) -> dict:
        if isinstance(content, ImageContentItem):
            return {
                "role": message.role,
                "images": [await convert_image_content_to_url(content, download=True, include_format=False)],
            }
        else:
            text = content.text if isinstance(content, TextContentItem) else content
            assert isinstance(text, str)
            return {
                "role": message.role,
                "content": text,
            }

    if isinstance(message.content, list):
        return [await _convert_content(c) for c in message.content]
    else:
        return [await _convert_content(message.content)]
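As a quick illustration of the converter above (and of why `_get_params` flattens the per-message lists), here is a hedged sketch of the shape it returns for a user message that mixes an image and text. The `UserMessage` construction and the exact `images` payload are assumptions; the image value is whatever `convert_image_content_to_url(..., download=True, include_format=False)` produces.

```python
# Hypothetical sketch, not part of the adapter file above.
#
# message = UserMessage(content=[ImageContentItem(...), TextContentItem(text="What is in this image?")])
# result = await convert_message_to_openai_dict_for_ollama(message)
#
# One Ollama-style dict is emitted per content item, with the role repeated:
# result == [
#     {"role": "user", "images": ["<downloaded image data>"]},
#     {"role": "user", "content": "What is in this image?"},
# ]
```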