llama-stack/llama_stack/providers/utils/inference/openai_compat.py
Hardik Shah a51c8b4efc
Convert SamplingParams.strategy to a union (#767)
# What does this PR do?

Cleans up how we provide sampling params. Earlier, strategy was an enum
and all params (top_p, temperature, top_k) across all strategies were
grouped. We now have a strategy union object with each strategy (greedy,
top_p, top_k) having its corresponding params.
Earlier, 
```
class SamplingParams: 
    strategy: enum ()
    top_p, temperature, top_k and other params
```
However, the `strategy` field was not being used in any providers making
it confusing to know the exact sampling behavior purely based on the
params since you could pass temperature, top_p, top_k and how the
provider would interpret those would not be clear.

Hence we introduced -- a union where the strategy and relevant params
are all clubbed together to avoid this confusion.

Have updated all providers, tests, notebooks, readme and otehr places
where sampling params was being used to use the new format.
   

## Test Plan
`pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py`
// inference on ollama, fireworks and together 
`with-proxy pytest -v -s -k "ollama"
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/inference/test_text_inference.py `
// agents on fireworks 
`pytest -v -s -k 'fireworks and create_agent'
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/agents/test_agents.py
--safety-shield="meta-llama/Llama-Guard-3-8B"`

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
2025-01-15 05:38:51 -08:00

322 lines
10 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
GreedySamplingStrategy,
SamplingParams,
StopReason,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
Message,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)
class OpenAICompatCompletionChoiceDelta(BaseModel):
content: str
class OpenAICompatCompletionChoice(BaseModel):
finish_reason: Optional[str] = None
text: Optional[str] = None
delta: Optional[OpenAICompatCompletionChoiceDelta] = None
class OpenAICompatCompletionResponse(BaseModel):
choices: List[OpenAICompatCompletionChoice]
def get_sampling_strategy_options(params: SamplingParams) -> dict:
options = {}
if isinstance(params.strategy, GreedySamplingStrategy):
options["temperature"] = 0.0
elif isinstance(params.strategy, TopPSamplingStrategy):
options["temperature"] = params.strategy.temperature
options["top_p"] = params.strategy.top_p
elif isinstance(params.strategy, TopKSamplingStrategy):
options["top_k"] = params.strategy.top_k
else:
raise ValueError(f"Unsupported sampling strategy: {params.strategy}")
return options
def get_sampling_options(params: SamplingParams) -> dict:
options = {}
if params:
options.update(get_sampling_strategy_options(params))
if params.max_tokens:
options["max_tokens"] = params.max_tokens
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty
return options
def text_from_choice(choice) -> str:
if hasattr(choice, "delta") and choice.delta:
return choice.delta.content
if hasattr(choice, "message"):
return choice.message.content
return choice.text
def get_stop_reason(finish_reason: str) -> StopReason:
if finish_reason in ["stop", "eos"]:
return StopReason.end_of_turn
elif finish_reason == "eom":
return StopReason.end_of_message
elif finish_reason == "length":
return StopReason.out_of_tokens
return StopReason.out_of_tokens
def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
)
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
)
def process_chat_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> ChatCompletionResponse:
choice = response.choices[0]
raw_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=None,
)
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0]
finish_reason = choice.finish_reason
text = text_from_choice(choice)
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
)
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
yield CompletionResponseStreamChunk(
delta="",
stop_reason=stop_reason,
)
async def process_chat_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0]
finish_reason = choice.finish_reason
if finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = text_from_choice(choice)
if not text:
# Sometimes you get empty chunks from providers
continue
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
if ipython:
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=text),
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
async def convert_message_to_openai_dict(
message: Message, download: bool = False
) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"type": "image_url",
"image_url": {
"url": await convert_image_content_to_url(
content, download=download
),
},
}
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {"type": "text", "text": text}
if isinstance(message.content, list):
content = [await _convert_content(c) for c in message.content]
else:
content = [await _convert_content(message.content)]
return {
"role": message.role,
"content": content,
}