llama-stack/llama_stack/providers/remote/inference/nvidia/openai_utils.py
Hardik Shah a51c8b4efc
Convert SamplingParams.strategy to a union (#767)
# What does this PR do?

Cleans up how we provide sampling params. Earlier, strategy was an enum
and all params (top_p, temperature, top_k) across all strategies were
grouped. We now have a strategy union object with each strategy (greedy,
top_p, top_k) having its corresponding params.
Earlier, 
```
class SamplingParams: 
    strategy: enum ()
    top_p, temperature, top_k and other params
```
However, the `strategy` field was not used by any provider. This made it
hard to know the exact sampling behavior from the params alone: you could
pass temperature, top_p and top_k together, and it was not clear how the
provider would interpret them.

Hence we introduced -- a union where the strategy and relevant params
are all clubbed together to avoid this confusion.

Have updated all providers, tests, notebooks, readme and other places
where sampling params were being used to use the new format.
   

## Test Plan
`pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py`
// inference on ollama, fireworks and together 
`with-proxy pytest -v -s -k "ollama"
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/inference/test_text_inference.py `
// agents on fireworks 
`pytest -v -s -k 'fireworks and create_agent'
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/agents/test_agents.py
--safety-shield="meta-llama/Llama-Guard-3-8B"`

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
2025-01-15 05:38:51 -08:00

637 lines
21 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from llama_models.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
)
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
ChatCompletionChunk as OpenAIChatCompletionChunk,
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
SystemMessage,
TokenLogProbs,
ToolResponseMessage,
UserMessage,
)
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
    """
    Convert a ToolDefinition to an OpenAI API-compatible dictionary.

    ToolDefinition:
        tool_name: str | BuiltinTool
        description: Optional[str]
        parameters: Optional[Dict[str, ToolParamDefinition]]

    ToolParamDefinition:
        param_type: str
        description: Optional[str]
        required: Optional[bool]
        default: Optional[Any]

    OpenAI spec -

    {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    param_name: {
                        "type": param_type,
                        "description": description,
                        "default": default,
                    },
                    ...
                },
                "required": [param_name, ...],
            },
        },
    }

    Optional pieces (the description, the parameters object, a parameter's
    description / default, and the "required" list) are only emitted when
    they carry a value.
    """
    # BuiltinTool is unwrapped to its .value for the wire format
    if isinstance(tool.tool_name, BuiltinTool):
        name = tool.tool_name.value  # TODO(mf): is this sufficient?
    else:
        name = tool.tool_name

    function: Dict[str, Any] = {"name": name}
    if tool.description:
        function["description"] = tool.description

    if tool.parameters:
        properties: Dict[str, Any] = {}
        required: List[str] = []
        for param_name, param in tool.parameters.items():
            schema: Dict[str, Any] = {"type": param.param_type}
            if param.description:
                schema["description"] = param.description
            # compare against None so that legitimate falsy defaults
            # (0, False, "") are not silently dropped from the schema
            if param.default is not None:
                schema["default"] = param.default
            if param.required:
                required.append(param_name)
            properties[param_name] = schema

        parameters: Dict[str, Any] = {"type": "object", "properties": properties}
        if required:
            parameters["required"] = required
        function["parameters"] = parameters

    return {"type": "function", "function": function}
def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
    """
    Convert a Message to an OpenAI API-compatible dictionary.

    A plain dict is accepted as well; it is first turned into the Message
    class selected by its "role" key ("user", "assistant", "tool", "system").

    Raises:
        ValueError: if a dict has no "role", if the role is not one of the
            four supported values, or if the message object is of an
            unsupported type.
    """
    # users can supply a dict instead of a Message object, we'll
    # convert it to a Message object and proceed with some type safety.
    if isinstance(message, dict):
        if "role" not in message:
            raise ValueError("role is required in message")
        if message["role"] == "user":
            message = UserMessage(**message)
        elif message["role"] == "assistant":
            message = CompletionMessage(**message)
        elif message["role"] == "tool":
            message = ToolResponseMessage(**message)
        elif message["role"] == "system":
            message = SystemMessage(**message)
        else:
            raise ValueError(f"Unsupported message role: {message['role']}")

    if isinstance(message, UserMessage):
        return OpenAIChatCompletionUserMessage(
            role="user",
            content=message.content,  # TODO(mf): handle image content
        )

    if isinstance(message, CompletionMessage):
        return OpenAIChatCompletionAssistantMessage(
            role="assistant",
            content=message.content,
            tool_calls=[
                OpenAIChatCompletionMessageToolCall(
                    id=tool_call.call_id,
                    function=OpenAIFunction(
                        # unwrap BuiltinTool to its string value, the same way
                        # tool definitions are converted in this module
                        name=(
                            tool_call.tool_name.value
                            if isinstance(tool_call.tool_name, BuiltinTool)
                            else tool_call.tool_name
                        ),
                        arguments=json.dumps(tool_call.arguments),
                    ),
                    type="function",
                )
                for tool_call in message.tool_calls
            ],
        )

    if isinstance(message, ToolResponseMessage):
        return OpenAIChatCompletionToolMessage(
            role="tool",
            tool_call_id=message.call_id,
            content=message.content,
        )

    if isinstance(message, SystemMessage):
        return OpenAIChatCompletionSystemMessage(
            role="system",
            content=message.content,
        )

    raise ValueError(f"Unsupported message type: {type(message)}")
def convert_chat_completion_request(
    request: ChatCompletionRequest,
    n: int = 1,
) -> dict:
    """
    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.

    Standard OpenAI fields are placed at the top level of the result; NIM
    extensions go under ``extra_body["nvext"]`` and a fixed User-Agent is
    set in ``extra_headers``.

    Args:
        request: the chat completion request to translate.
        n: value forwarded as the ``n`` field of the payload.

    Returns:
        The payload dictionary.

    Raises:
        ValueError: if ``response_format`` is set but is not a
            JsonSchemaResponseFormat, or if the sampling strategy is not one
            of the top_p / top_k / greedy strategy types.
    """
    # model -> model
    # messages -> messages
    # sampling_params  TODO(mattf): review strategy
    #  strategy=greedy -> nvext.top_k = -1, temperature = temperature (if the strategy has one)
    #  strategy=top_p -> nvext.top_k = -1, top_p = top_p, temperature = temperature
    #  strategy=top_k -> nvext.top_k = top_k
    #  max_tokens -> max_tokens
    #  repetition_penalty -> nvext.repetition_penalty
    # response_format -> GrammarResponseFormat TODO(mf)
    # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
    # tools -> tools
    # tool_choice ("auto", "required") -> tool_choice
    # tool_prompt_format -> TBD
    # stream -> stream
    # logprobs -> logprobs

    if request.response_format and not isinstance(
        request.response_format, JsonSchemaResponseFormat
    ):
        raise ValueError(
            f"Unsupported response format: {request.response_format}. "
            "Only JsonSchemaResponseFormat is supported."
        )

    # nvext is shared by reference with payload["extra_body"], so later
    # updates to it show up in the returned payload
    nvext = {}
    payload: Dict[str, Any] = dict(
        model=request.model,
        messages=[_convert_message(message) for message in request.messages],
        stream=request.stream,
        n=n,
        extra_body=dict(nvext=nvext),
        extra_headers={
            b"User-Agent": b"llama-stack: nvidia-inference-adapter",
        },
    )

    if request.response_format:
        # server bug - setting guided_json changes the behavior of response_format resulting in an error
        # payload.update(response_format="json_object")
        nvext.update(guided_json=request.response_format.json_schema)

    if request.tools:
        payload.update(
            tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
        )
        if request.tool_choice:
            payload.update(
                tool_choice=request.tool_choice.value
            )  # we cannot include tool_choice w/o tools, server will complain

    if request.logprobs:
        payload.update(logprobs=True)
        payload.update(top_logprobs=request.logprobs.top_k)

    if request.sampling_params:
        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)

        if request.sampling_params.max_tokens:
            payload.update(max_tokens=request.sampling_params.max_tokens)

        strategy = request.sampling_params.strategy
        if isinstance(strategy, TopPSamplingStrategy):
            nvext.update(top_k=-1)
            payload.update(top_p=strategy.top_p)
            payload.update(temperature=strategy.temperature)
        elif isinstance(strategy, TopKSamplingStrategy):
            if strategy.top_k != -1 and strategy.top_k < 1:
                warnings.warn("top_k must be -1 or >= 1")
            nvext.update(top_k=strategy.top_k)
        elif isinstance(strategy, GreedySamplingStrategy):
            nvext.update(top_k=-1)
            # NOTE(review): each strategy now carries only its own params and
            # greedy presumably has no temperature field (confirm against
            # llama_models.datatypes). Read it defensively so a greedy request
            # does not fail on a missing attribute.
            temperature = getattr(strategy, "temperature", None)
            if temperature is not None:
                payload.update(temperature=temperature)
        else:
            raise ValueError(f"Unsupported sampling strategy: {strategy}")

    return payload
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
    """
    Map an OpenAI chat completion finish_reason onto a StopReason.

    finish_reason: Literal["stop", "length", "tool_calls", ...]
        - stop: model hit a natural stop point or a provided stop sequence
        - length: maximum number of tokens specified in the request was reached
        - tool_calls: model called a tool

    ->

    class StopReason(Enum):
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"

    Any finish_reason without an explicit mapping yields
    StopReason.end_of_turn.
    """
    # TODO(mf): are end_of_turn and end_of_message semantics correct?
    known_reasons = {
        "stop": StopReason.end_of_turn,
        "length": StopReason.out_of_tokens,
        "tool_calls": StopReason.end_of_message,
    }
    try:
        return known_reasons[finish_reason]
    except KeyError:
        # unmapped values fall back to end_of_turn
        return StopReason.end_of_turn
def _convert_openai_tool_calls(
    tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
    """
    Turn a list of OpenAI ChatCompletionMessageToolCall into a list of ToolCall.

    OpenAI ChatCompletionMessageToolCall:
        id: str
        function: Function
        type: Literal["function"]

    OpenAI Function:
        arguments: str
        name: str

    ->

    ToolCall:
        call_id: str
        tool_name: str
        arguments: Dict[str, ...]

    The function's ``arguments`` string is decoded with json.loads. A missing
    or empty input produces an empty list.
    """
    converted = []
    # CompletionMessage tool_calls is not optional, so a falsy input becomes []
    for call in tool_calls or []:
        function = call.function
        converted.append(
            ToolCall(
                call_id=call.id,
                tool_name=function.name,
                arguments=json.loads(function.arguments),
            )
        )
    return converted
def _convert_openai_logprobs(
    logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
    """
    Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.

    OpenAI ChoiceLogprobs:
        content: Optional[List[ChatCompletionTokenLogprob]]

    OpenAI ChatCompletionTokenLogprob:
        token: str
        logprob: float
        top_logprobs: List[TopLogprob]

    OpenAI TopLogprob:
        token: str
        logprob: float

    ->

    TokenLogProbs:
        logprobs_by_token: Dict[str, float]
         - token, logprob

    Returns:
        None when ``logprobs`` is falsy or its ``content`` is None; otherwise
        one TokenLogProbs per entry of ``content``, built from that entry's
        ``top_logprobs``.
    """
    # content is Optional; without this check a ChoiceLogprobs that carries
    # no content would fail when iterated below
    if not logprobs or logprobs.content is None:
        return None

    return [
        TokenLogProbs(
            logprobs_by_token={
                top.token: top.logprob for top in token_logprob.top_logprobs
            }
        )
        for token_logprob in logprobs.content
    ]
def convert_openai_chat_completion_choice(
    choice: OpenAIChoice,
) -> ChatCompletionResponse:
    """
    Build a ChatCompletionResponse from an OpenAI Choice.

    OpenAI Choice:
        message: ChatCompletionMessage
        finish_reason: str
        logprobs: Optional[ChoiceLogprobs]

    OpenAI ChatCompletionMessage:
        role: Literal["assistant"]
        content: Optional[str]
        tool_calls: Optional[List[ChatCompletionMessageToolCall]]

    ->

    ChatCompletionResponse:
        completion_message: CompletionMessage
        logprobs: Optional[List[TokenLogProbs]]

    CompletionMessage:
        role: Literal["assistant"]
        content: str | ImageMedia | List[str | ImageMedia]
        stop_reason: StopReason
        tool_calls: List[ToolCall]

    class StopReason(Enum):
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"

    The choice must carry a truthy ``message`` and ``finish_reason``; both
    are checked with assertions before conversion.
    """
    message = getattr(choice, "message", None)
    assert message, "error in server response: message not found"

    finish_reason = getattr(choice, "finish_reason", None)
    assert finish_reason, "error in server response: finish_reason not found"

    completion_message = CompletionMessage(
        # CompletionMessage content is not optional
        content=message.content or "",
        stop_reason=_convert_openai_finish_reason(finish_reason),
        tool_calls=_convert_openai_tool_calls(message.tool_calls),
    )
    return ChatCompletionResponse(
        completion_message=completion_message,
        logprobs=_convert_openai_logprobs(choice.logprobs),
    )
async def convert_openai_chat_completion_stream(
    stream: AsyncStream[OpenAIChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
    """
    Convert a stream of OpenAI chat completion chunks into a stream
    of ChatCompletionResponseStreamChunk.

    The first emitted event has type ``start``, every following per-chunk
    event has type ``progress``, and a final ``complete`` event with an empty
    text delta carries the stop_reason collected from the stream.
    """

    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
    def _event_type_generator() -> (
        Generator[ChatCompletionResponseEventType, None, None]
    ):
        yield ChatCompletionResponseEventType.start
        while True:
            yield ChatCompletionResponseEventType.progress

    event_type = _event_type_generator()

    # we implement NIM specific semantics, the main difference from OpenAI
    # is that tool_calls are always produced as a complete call. there is no
    # intermediate / partial tool call streamed. because of this, we can
    # simplify the logic and not concern ourselves with parse_status of
    # started/in_progress/failed. we can always assume success.
    #
    # a stream of ChatCompletionResponseStreamChunk consists of
    #  0. a start event
    #  1. zero or more progress events
    #   - each progress event has a delta
    #   - each progress event may have a stop_reason
    #   - each progress event may have logprobs
    #   - each progress event may have tool_calls
    #     if a progress event has tool_calls,
    #      it is fully formed and
    #      can be emitted with a parse_status of success
    #  2. a complete event

    stop_reason = None

    async for chunk in stream:
        if not chunk.choices:
            # a chunk without choices has nothing to convert; skip it
            # rather than fail on choices[0]
            continue
        choice = chunk.choices[0]  # assuming only one choice per chunk

        # we assume there's only one finish_reason in the stream.
        # _convert_openai_finish_reason never returns a falsy value (it
        # defaults to end_of_turn), so only overwrite stop_reason when this
        # chunk carries a finish_reason or nothing has been recorded yet.
        # otherwise a later chunk without finish_reason would reset a real
        # reason (e.g. out_of_tokens) back to the default.
        if choice.finish_reason or stop_reason is None:
            stop_reason = _convert_openai_finish_reason(choice.finish_reason)

        # if there's a tool call, emit an event for each tool in the list
        # if tool call and content, emit both separately

        if choice.delta.tool_calls:
            # the call may have content and a tool call. ChatCompletionResponseEvent
            # does not support both, so we emit the content first
            if choice.delta.content:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=next(event_type),
                        delta=TextDelta(text=choice.delta.content),
                        logprobs=_convert_openai_logprobs(choice.logprobs),
                    )
                )

            # it is possible to have parallel tool calls in stream, but
            # ChatCompletionResponseEvent only supports one per stream
            if len(choice.delta.tool_calls) > 1:
                warnings.warn(
                    "multiple tool calls found in a single delta, using the first, ignoring the rest"
                )

            # NIM only produces fully formed tool calls, so we can assume success
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
                    delta=ToolCallDelta(
                        content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
                        parse_status=ToolCallParseStatus.succeeded,
                    ),
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
                    delta=TextDelta(text=choice.delta.content or ""),
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta=TextDelta(text=""),
            stop_reason=stop_reason,
        )
    )
def convert_completion_request(
    request: CompletionRequest,
    n: int = 1,
) -> dict:
    """
    Convert a CompletionRequest to an OpenAI API-compatible dictionary.

    Standard OpenAI fields are placed at the top level of the result; NIM
    extensions go under ``extra_body["nvext"]`` and a fixed User-Agent is
    set in ``extra_headers``.

    Args:
        request: the completion request to translate.
        n: value forwarded as the ``n`` field of the payload.

    Returns:
        The payload dictionary.

    Raises:
        ValueError: if ``response_format`` is set but is not a
            JsonSchemaResponseFormat, or if the sampling strategy is not one
            of the top_p / top_k / greedy strategy types.
    """
    # model -> model
    # prompt -> prompt
    # sampling_params  TODO(mattf): review strategy
    #  strategy=greedy -> nvext.top_k = -1, temperature = temperature (if the strategy has one)
    #  strategy=top_p -> nvext.top_k = -1, top_p = top_p, temperature = temperature
    #  strategy=top_k -> nvext.top_k = top_k
    #  max_tokens -> max_tokens
    #  repetition_penalty -> nvext.repetition_penalty
    # response_format -> nvext.guided_json
    # stream -> stream
    # logprobs.top_k -> logprobs

    # only the json_schema of a JsonSchemaResponseFormat can be forwarded;
    # reject anything else up front, as the chat request conversion does
    if request.response_format and not isinstance(
        request.response_format, JsonSchemaResponseFormat
    ):
        raise ValueError(
            f"Unsupported response format: {request.response_format}. "
            "Only JsonSchemaResponseFormat is supported."
        )

    # nvext is shared by reference with payload["extra_body"], so later
    # updates to it show up in the returned payload
    nvext = {}
    payload: Dict[str, Any] = dict(
        model=request.model,
        prompt=request.content,
        stream=request.stream,
        extra_body=dict(nvext=nvext),
        extra_headers={
            b"User-Agent": b"llama-stack: nvidia-inference-adapter",
        },
        n=n,
    )

    if request.response_format:
        # this is not openai compliant, it is a nim extension
        nvext.update(guided_json=request.response_format.json_schema)

    if request.logprobs:
        payload.update(logprobs=request.logprobs.top_k)

    if request.sampling_params:
        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)

        if request.sampling_params.max_tokens:
            payload.update(max_tokens=request.sampling_params.max_tokens)

        # strategy is a union of strategy objects, each carrying its own
        # parameters, so dispatch on its type (as the chat request conversion
        # does) instead of comparing it to a string
        strategy = request.sampling_params.strategy
        if isinstance(strategy, TopPSamplingStrategy):
            nvext.update(top_k=-1)
            payload.update(top_p=strategy.top_p)
            payload.update(temperature=strategy.temperature)
        elif isinstance(strategy, TopKSamplingStrategy):
            if strategy.top_k != -1 and strategy.top_k < 1:
                warnings.warn("top_k must be -1 or >= 1")
            nvext.update(top_k=strategy.top_k)
        elif isinstance(strategy, GreedySamplingStrategy):
            nvext.update(top_k=-1)
            # NOTE(review): greedy presumably has no temperature field
            # (confirm against llama_models.datatypes); read it defensively.
            temperature = getattr(strategy, "temperature", None)
            if temperature is not None:
                payload.update(temperature=temperature)
        else:
            raise ValueError(f"Unsupported sampling strategy: {strategy}")

    return payload
def _convert_openai_completion_logprobs(
    logprobs: Optional[OpenAICompletionLogprobs],
) -> Optional[List[TokenLogProbs]]:
    """
    Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.

    Each entry of ``logprobs.top_logprobs`` is passed through as the
    ``logprobs_by_token`` of one TokenLogProbs.

    Returns:
        None when ``logprobs`` is falsy or has no ``top_logprobs``;
        otherwise one TokenLogProbs per ``top_logprobs`` entry.
    """
    # top_logprobs is an optional field on the OpenAI completion Logprobs
    # model; treat a missing value like missing logprobs instead of failing
    # while iterating over None
    if not logprobs or logprobs.top_logprobs is None:
        return None

    return [
        TokenLogProbs(logprobs_by_token=token_logprobs)
        for token_logprobs in logprobs.top_logprobs
    ]
def convert_openai_completion_choice(
    choice: OpenAIChoice,
) -> CompletionResponse:
    """
    Build a CompletionResponse from an OpenAI Completion Choice.

    The choice's ``text`` becomes the content, its ``finish_reason`` is
    mapped to a StopReason, and its ``logprobs`` are converted to
    TokenLogProbs (or None).

    NOTE(review): the parameter is annotated with the chat Choice type, but
    the body reads ``choice.text`` - confirm the intended type.
    """
    text = choice.text
    stop_reason = _convert_openai_finish_reason(choice.finish_reason)
    token_logprobs = _convert_openai_completion_logprobs(choice.logprobs)
    return CompletionResponse(
        content=text,
        stop_reason=stop_reason,
        logprobs=token_logprobs,
    )
async def convert_openai_completion_stream(
    stream: AsyncStream[OpenAICompletion],
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
    """
    Convert a stream of OpenAI Completions into a stream
    of CompletionResponseStreamChunks.

    One chunk is emitted per incoming completion that has at least one
    choice; only the first choice is used. ``stop_reason`` is set only on
    chunks whose choice carries a finish_reason and is None otherwise.
    """
    async for chunk in stream:
        if not chunk.choices:
            # a chunk without choices has nothing to convert; skip it
            # rather than fail on choices[0]
            continue
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason
        yield CompletionResponseStreamChunk(
            delta=TextDelta(text=choice.text),
            # _convert_openai_finish_reason maps a missing finish_reason to
            # end_of_turn; converting unconditionally would mark every
            # intermediate chunk as the end of the turn
            stop_reason=(
                _convert_openai_finish_reason(finish_reason)
                if finish_reason
                else None
            ),
            logprobs=_convert_openai_completion_logprobs(choice.logprobs),
        )