llama-stack/llama_stack/providers/remote/inference/sambanova/sambanova.py
Ben Browning 2b2db5fbda
feat: OpenAI-Compatible models, completions, chat/completions (#1894)
# What does this PR do?

This stubs in some OpenAI server-side compatibility with three new
endpoints:

/v1/openai/v1/models
/v1/openai/v1/completions
/v1/openai/v1/chat/completions

This gives common inference apps using OpenAI clients the ability to
talk to Llama Stack using an endpoint like
http://localhost:8321/v1/openai/v1 .

The two "v1" instances in there aren't awesome, but the thinking is that
Llama Stack's API is v1 and then our OpenAI compatibility layer is
compatible with OpenAI V1. And, some OpenAI clients implicitly assume
the URL ends with "v1", so this gives maximum compatibility.

The openai models endpoint is implemented in the routing layer, and just
returns all the models Llama Stack knows about.

The following providers should be working with the new OpenAI
completions and chat/completions API:
* remote::anthropic (untested)
* remote::cerebras-openai-compat (untested)
* remote::fireworks (tested)
* remote::fireworks-openai-compat (untested)
* remote::gemini (untested)
* remote::groq-openai-compat (untested)
* remote::nvidia (tested)
* remote::ollama (tested)
* remote::openai (untested)
* remote::passthrough (untested)
* remote::sambanova-openai-compat (untested)
* remote::together (tested)
* remote::together-openai-compat (untested)
* remote::vllm (tested)

The goal is to support this for every inference provider - proxying
directly to the provider's OpenAI endpoint for OpenAI-compatible
providers. For providers that don't have an OpenAI-compatible API, we'll
add a mixin to translate incoming OpenAI requests to Llama Stack
inference requests and translate the Llama Stack inference responses to
OpenAI responses.

This is related to #1817 but is a bit larger in scope than just chat
completions, as I have real use-cases that need the older completions
API as well.

## Test Plan

### vLLM

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run

LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

### ollama
```
INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run

LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0"
```



## Documentation

Run a Llama Stack distribution that uses one of the providers mentioned
in the list above. Then, use your favorite OpenAI client to send
completion or chat completion requests with the base_url set to
http://localhost:8321/v1/openai/v1 . Replace "localhost:8321" with the
host and port of your Llama Stack server, if different.

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
2025-04-11 13:14:17 -07:00

309 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from enum import Enum
from typing import AsyncGenerator, List, Optional

from openai import AsyncOpenAI, OpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionMessage,
    EmbeddingsResponse,
    EmbeddingTaskType,
    GreedySamplingStrategy,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    StopReason,
    SystemMessage,
    TextTruncation,
    ToolCall,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
    ToolResponseMessage,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
    UserMessage,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionUnsupportedMixin,
    OpenAICompletionUnsupportedMixin,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_content_to_url,
)

from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(
    ModelRegistryHelper,
    Inference,
    OpenAIChatCompletionUnsupportedMixin,
    OpenAICompletionUnsupportedMixin,
):
    """Llama Stack inference provider backed by SambaNova's OpenAI-compatible API.

    Chat completions are translated into OpenAI-style request dicts, sent with the
    ``openai`` client, and the responses are translated back into Llama Stack types.
    ``completion`` and ``embeddings`` are not implemented and raise ``NotImplementedError``.
    The OpenAI-compat endpoints come from the ``OpenAI*UnsupportedMixin`` bases
    (presumably these reject the call — confirm in openai_compat).
    """

    def __init__(self, config: SambaNovaImplConfig) -> None:
        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
        self.config = config

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    def _get_client(self) -> OpenAI:
        """Return a new *synchronous* OpenAI client for the configured SambaNova endpoint.

        Kept for backward compatibility only; the request paths in this class use
        ``_get_async_client`` so network I/O does not block the event loop.
        """
        return OpenAI(base_url=self.config.url, api_key=self.config.api_key)

    def _get_async_client(self) -> AsyncOpenAI:
        """Return a new asynchronous OpenAI client for the configured SambaNova endpoint.

        Callers use it as an ``async with`` context manager so its connection pool is
        closed once the request (or stream) is done.
        """
        return AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_key)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Text completion is not supported by this adapter."""
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        tool_config: Optional[ToolConfig] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion against SambaNova.

        Resolves ``model_id`` through the model store to the provider's resource id,
        converts the request to an OpenAI-style dict and dispatches it.

        Returns:
            An async generator of stream chunks when ``stream`` is true, otherwise a
            ``ChatCompletionResponse``.
        """
        # NOTE(review): response_format, tool_choice and tool_prompt_format are accepted
        # but not forwarded to SambaNova by this method; presumably the caller folds
        # tool_choice/tool_prompt_format into tool_config — confirm.
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )
        request_sambanova = await self.convert_chat_completion_request(request)
        if stream:
            return self._stream_chat_completion(request_sambanova)
        return await self._nonstream_chat_completion(request_sambanova)

    async def _nonstream_chat_completion(self, request: dict) -> ChatCompletionResponse:
        """Send a non-streaming OpenAI-style request dict and convert the first choice."""
        async with self._get_async_client() as client:
            response = await client.chat.completions.create(**request)
        choice = response.choices[0]
        return ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=choice.message.content or "",
                stop_reason=self.convert_to_sambanova_finish_reason(choice.finish_reason),
                tool_calls=self.convert_to_sambanova_tool_calls(choice.message.tool_calls),
            ),
            logprobs=None,
        )

    async def _stream_chat_completion(self, request: dict) -> AsyncGenerator:
        """Send a streaming OpenAI-style request dict and yield Llama Stack stream chunks."""

        async def _raw_chunks():
            # The client must stay open while the stream is consumed; leaving the
            # ``async with`` (when the generator finishes) closes it.
            async with self._get_async_client() as client:
                stream = await client.chat.completions.create(**request)
                async for chunk in stream:
                    yield chunk

        async for chunk in process_chat_completion_stream_response(_raw_chunks(), request):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        """Embeddings are not supported by this adapter."""
        raise NotImplementedError()

    async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict:
        """Build the keyword arguments for ``chat.completions.create`` from ``request``.

        Logprobs are always disabled. ``tools`` is only included when at least one
        tool is defined, so no ``"tools": null`` is sent.
        """
        compatible_request = self.convert_sampling_params(request.sampling_params)
        compatible_request["model"] = request.model
        compatible_request["messages"] = await self.convert_to_sambanova_messages(request.messages)
        compatible_request["stream"] = request.stream
        compatible_request["logprobs"] = False
        # The openai client expects str header names/values (Mapping[str, str]).
        compatible_request["extra_headers"] = {
            "User-Agent": "llama-stack: sambanova-inference-adapter",
        }
        tools = self.convert_to_sambanova_tool(request.tools)
        if tools:
            compatible_request["tools"] = tools
        return compatible_request

    def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict:
        """Translate Llama Stack ``SamplingParams`` into OpenAI-style request fields.

        Args:
            sampling_params: Sampling configuration; a falsy value yields ``{}``.
            legacy: If true, emit ``max_tokens`` instead of ``max_completion_tokens``.

        Returns:
            A new dict of request fields; ``top_k`` (non-standard in the OpenAI API)
            is placed under ``extra_body``.
        """
        params: dict = {}
        if not sampling_params:
            return params

        # repetition_penalty is conventionally multiplicative (1.0 = no penalty) while
        # OpenAI's frequency_penalty is additive (0.0 = no penalty). Forwarding the
        # neutral default 1.0 would penalize every request, so only send explicit values.
        penalty = sampling_params.repetition_penalty
        if penalty is not None and penalty != 1.0:
            params["frequency_penalty"] = penalty

        if sampling_params.max_tokens:
            params["max_tokens" if legacy else "max_completion_tokens"] = sampling_params.max_tokens

        strategy = sampling_params.strategy
        if isinstance(strategy, TopPSamplingStrategy):
            if strategy.temperature is not None:
                params["temperature"] = strategy.temperature
            if strategy.top_p is not None:
                params["top_p"] = strategy.top_p
        elif isinstance(strategy, TopKSamplingStrategy):
            # extra_body may not exist yet; create it instead of raising KeyError.
            params.setdefault("extra_body", {})["top_k"] = strategy.top_k
        elif isinstance(strategy, GreedySamplingStrategy):
            params["temperature"] = 0.0
        return params

    async def convert_to_sambanova_messages(self, messages: List[Message]) -> List[dict]:
        """Convert Llama Stack messages into OpenAI-style message dicts.

        Assistant tool calls are emitted only when present, because an empty
        ``tool_calls`` array is rejected by strict OpenAI-compatible servers.
        """
        conversation = []
        for message in messages:
            converted: dict = {"content": await self.convert_to_sambanova_content(message)}
            if isinstance(message, UserMessage):
                converted["role"] = "user"
            elif isinstance(message, CompletionMessage):
                converted["role"] = "assistant"
                tool_calls = [self._convert_tool_call_to_sambanova(call) for call in message.tool_calls or []]
                if tool_calls:
                    converted["tool_calls"] = tool_calls
            elif isinstance(message, ToolResponseMessage):
                converted["role"] = "tool"
                converted["tool_call_id"] = message.call_id
            elif isinstance(message, SystemMessage):
                converted["role"] = "system"
            conversation.append(converted)
        return conversation

    @staticmethod
    def _convert_tool_call_to_sambanova(tool_call: ToolCall) -> dict:
        """Convert one Llama Stack ``ToolCall`` into an OpenAI ``tool_calls`` entry."""
        name = tool_call.tool_name
        if isinstance(name, Enum):
            # Builtin tool names are enum members; the API needs the plain string.
            name = name.value
        raw_json = getattr(tool_call, "arguments_json", None)
        if raw_json is not None:
            arguments = raw_json
        elif isinstance(tool_call.arguments, str):
            # Already serialized; json.dumps would double-encode it.
            arguments = tool_call.arguments
        else:
            arguments = json.dumps(tool_call.arguments)
        return {
            "id": tool_call.call_id,
            "function": {
                "name": name,
                "arguments": arguments,
            },
            "type": "function",
        }

    async def convert_to_sambanova_content(self, message: Message) -> str | List[dict]:
        """Convert a message's content into OpenAI message content.

        Strings pass through unchanged; a list of items, or a single text/image item,
        becomes a list of ``{"type": "text"|"image_url", ...}`` parts.
        """

        async def _convert_content(content) -> dict:
            if isinstance(content, ImageContentItem):
                url = await convert_image_content_to_url(content, download=True)
                # Lower-case the "data:<media type>" prefix (e.g. image/PNG -> image/png);
                # presumably SambaNova rejects upper-case media types — confirm.
                # partition() avoids an IndexError when the URL has no ";base64" marker.
                prefix, marker, payload = url.partition(";base64")
                if marker:
                    url = f"{prefix.lower()}{marker}{payload}"
                return {
                    "type": "image_url",
                    "image_url": {"url": url},
                }
            text = content.text if isinstance(content, TextContentItem) else content
            assert isinstance(text, str)
            return {"type": "text", "text": text}

        content = message.content
        if isinstance(content, list):
            return [await _convert_content(item) for item in content]
        if isinstance(content, (ImageContentItem, TextContentItem)):
            # A bare content item is not valid OpenAI content; wrap it in a list.
            return [await _convert_content(content)]
        return content

    def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> Optional[List[dict]]:
        """Convert Llama Stack tool definitions into OpenAI ``function`` tools.

        Returns:
            ``None`` when ``tools`` is ``None`` or empty, otherwise the list of
            OpenAI-style tool dicts with a JSON-schema ``object`` for the parameters.
        """
        if tools is None:
            return None
        compatible_tools = []
        for tool in tools:
            properties = {}
            required = []
            for param_name, param in (tool.parameters or {}).items():
                schema = {"type": param.param_type}
                if param.description:
                    schema["description"] = param.description
                # Compare against None so falsy defaults like 0, False or "" are kept.
                if param.default is not None:
                    schema["default"] = param.default
                properties[param_name] = schema
                if param.required:
                    required.append(param_name)
            name = tool.tool_name
            if isinstance(name, Enum):
                name = name.value
            compatible_tools.append(
                {
                    "type": "function",
                    "function": {
                        "name": name,
                        "description": tool.description,
                        "parameters": {
                            "type": "object",
                            "properties": properties,
                            "required": required,
                        },
                    },
                }
            )
        return compatible_tools or None

    def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
        """Map an OpenAI ``finish_reason`` to a ``StopReason`` (unknown -> end_of_turn)."""
        return {
            "stop": StopReason.end_of_turn,
            "length": StopReason.out_of_tokens,
            "tool_calls": StopReason.end_of_message,
        }.get(finish_reason, StopReason.end_of_turn)

    def convert_to_sambanova_tool_calls(
        self,
        tool_calls,
    ) -> List[ToolCall]:
        """Convert OpenAI response tool calls into Llama Stack ``ToolCall`` objects.

        If a call's arguments are not valid JSON, the raw string is kept as the
        arguments instead of failing the whole response.
        """
        if not tool_calls:
            return []
        converted = []
        for call in tool_calls:
            raw_arguments = call.function.arguments
            try:
                arguments = json.loads(raw_arguments)
            except json.JSONDecodeError:
                arguments = raw_arguments
            converted.append(
                ToolCall(
                    call_id=call.id,
                    tool_name=call.function.name,
                    arguments=arguments,
                    arguments_json=raw_arguments,
                )
            )
        return converted