Merge branch 'main' into rag-metadata-support

Francisco Arceo 2025-05-14 14:19:50 -06:00 committed by GitHub
commit b51427716d
10 changed files with 447 additions and 45 deletions


@@ -44,7 +44,7 @@ jobs:
- name: Install minikube
if: ${{ matrix.auth-provider == 'kubernetes' }}
uses: medyagh/setup-minikube@latest
uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19
- name: Start minikube
if: ${{ matrix.auth-provider == 'kubernetes' }}


@@ -106,6 +106,14 @@ repos:
pass_filenames: false
require_serial: true
files: ^llama_stack/apis/|^docs/openapi_generator/
- id: check-workflows-use-hashes
name: Check GitHub Actions use SHA-pinned actions
entry: ./scripts/check-workflows-use-hashes.sh
language: system
pass_filenames: false
require_serial: true
always_run: true
files: ^\.github/workflows/.*\.ya?ml$
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks


@@ -4,27 +4,60 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
# the models w/ "openai/" prefix are the litellm specific model names.
# they should be deprecated in favor of the canonical openai model names.
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openai/chatgpt-4o-latest",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-4",
"gpt-4-turbo",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-mini",
"gpt-4o-audio-preview",
"chatgpt-4o-latest",
"o1",
"o1-mini",
"o3-mini",
"o4-mini",
]
@dataclass
class EmbeddingModelInfo:
"""Structured representation of embedding model information."""
embedding_dimension: int
context_length: int
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-small",
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1536, "context_length": 8192},
),
ProviderModelEntry(
provider_model_id="openai/text-embedding-3-large",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 3072, "context_length": 8192},
),
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
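As a rough sketch of what this refactor produces (using a simplified stand-in for ProviderModelEntry rather than the real llama_stack class), the comprehension over EMBEDDING_MODEL_IDS gives the litellm-prefixed and canonical embedding IDs identical metadata:

from dataclasses import dataclass, field

# Simplified stand-in for llama_stack's ProviderModelEntry, for illustration only.
@dataclass
class FakeProviderModelEntry:
    provider_model_id: str
    model_type: str = "llm"
    metadata: dict = field(default_factory=dict)

@dataclass
class EmbeddingModelInfo:
    embedding_dimension: int
    context_length: int

EMBEDDING_MODEL_IDS = {
    "openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
    "text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
}

entries = [
    FakeProviderModelEntry(
        provider_model_id=model_id,
        model_type="embedding",
        metadata={
            "embedding_dimension": info.embedding_dimension,
            "context_length": info.context_length,
        },
    )
    for model_id, info in EMBEDDING_MODEL_IDS.items()
]

# Both the prefixed and canonical IDs carry the same embedding metadata.
assert entries[0].metadata == entries[1].metadata == {
    "embedding_dimension": 1536,
    "context_length": 8192,
}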


@@ -19,6 +19,13 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="openai_api_key",
)
self.config = config
# we set is_openai_compat so users can use the canonical
# openai model names like "gpt-4" or "gpt-3.5-turbo"
# and the model name will be translated to litellm's
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
async def initialize(self) -> None:
await super().initialize()


@@ -158,33 +158,29 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn)
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
event_type = ChatCompletionResponseEventType.start
tool_call_buf = UnparseableToolCall()
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
continue
choice = chunk.choices[0]
if choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
if choice.finish_reason:
args_str = tool_call_buf.arguments
args = None
try:
args = {} if not args_str else json.loads(args_str)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
if args:
yield ChatCompletionResponseStreamChunk(
def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_buf: UnparseableToolCall,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
if finish_reason is not None:
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
else:
stop_reason = StopReason.end_of_message
if tool_call_buf.tool_name:
# at least one tool call request is received
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
event_type=current_event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
@@ -196,8 +192,12 @@ async def _process_vllm_chat_completion_stream_response(
),
)
)
elif args_str:
yield ChatCompletionResponseStreamChunk(
)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
@@ -206,14 +206,50 @@ async def _process_vllm_chat_completion_stream_response(
),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
)
)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=last_chunk_content or ""),
logprobs=None,
stop_reason=stop_reason,
)
)
)
return chunks
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
event_type = ChatCompletionResponseEventType.start
tool_call_buf = UnparseableToolCall()
end_of_stream_processed = False
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
tool_call = convert_tool_call(choice.delta.tool_calls[0])
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
# TODO: remove str() when dict type for 'arguments' is no longer allowed
tool_call_buf.arguments += str(tool_call.arguments)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_buf=tool_call_buf,
)
for c in chunks:
yield c
end_of_stream_processed = True
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@@ -224,6 +260,17 @@ async def _process_vllm_chat_completion_stream_response(
)
event_type = ChatCompletionResponseEventType.progress
if end_of_stream_processed:
return
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
)
for c in chunks:
yield c
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
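The behavioral change that the new tests at the bottom of this commit exercise is the `tool_call_buf.arguments or "{}"` fallback above: an empty argument buffer is now parsed as an empty dict instead of causing the tool call to be dropped. A minimal sketch of just that parsing step (the helper name here is illustrative, not part of the codebase):

import json

def parse_buffered_tool_args(buffered: str) -> dict:
    # Mirrors the fallback in _process_vllm_chat_completion_end_of_stream:
    # an empty buffer becomes "{}", so a tool call without arguments still
    # yields a valid (empty) argument dict rather than being skipped.
    return json.loads(buffered or "{}")

assert parse_buffered_tool_args("") == {}
assert parse_buffered_tool_args('{"arg1": 0, "arg2": 100}') == {"arg1": 0, "arg2": 100}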


@@ -62,6 +62,9 @@ class LiteLLMOpenAIMixin(
Inference,
NeedsRequestProviderData,
):
# TODO: avoid exposing the litellm specific model names to the user.
# potential change: add a prefix param that gets added to the model name
# when calling litellm.
def __init__(
self,
model_entries,
@@ -92,7 +95,9 @@ class LiteLLMOpenAIMixin(
return model
def get_litellm_model_name(self, model_id: str) -> str:
return "openai/" + model_id if self.is_openai_compat else model_id
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id
async def completion(
self,
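A standalone sketch of the mapping get_litellm_model_name now implements (an illustration, not the mixin code itself):

def to_litellm_model_name(model_id: str, is_openai_compat: bool) -> str:
    # Canonical OpenAI names are prefixed for litellm; names that already
    # carry the "openai/" prefix pass through unchanged.
    if is_openai_compat and not model_id.startswith("openai/"):
        return "openai/" + model_id
    return model_id

assert to_litellm_model_name("gpt-4", is_openai_compat=True) == "openai/gpt-4"
assert to_litellm_model_name("openai/gpt-4", is_openai_compat=True) == "openai/gpt-4"
# Without compat mode, model IDs are passed through as-is.
assert to_litellm_model_name("gpt-4", is_openai_compat=False) == "gpt-4"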


@@ -149,6 +149,76 @@ models:
provider_id: openai
provider_model_id: openai/chatgpt-4o-latest
model_type: llm
- metadata: {}
model_id: gpt-3.5-turbo-0125
provider_id: openai
provider_model_id: gpt-3.5-turbo-0125
model_type: llm
- metadata: {}
model_id: gpt-3.5-turbo
provider_id: openai
provider_model_id: gpt-3.5-turbo
model_type: llm
- metadata: {}
model_id: gpt-3.5-turbo-instruct
provider_id: openai
provider_model_id: gpt-3.5-turbo-instruct
model_type: llm
- metadata: {}
model_id: gpt-4
provider_id: openai
provider_model_id: gpt-4
model_type: llm
- metadata: {}
model_id: gpt-4-turbo
provider_id: openai
provider_model_id: gpt-4-turbo
model_type: llm
- metadata: {}
model_id: gpt-4o
provider_id: openai
provider_model_id: gpt-4o
model_type: llm
- metadata: {}
model_id: gpt-4o-2024-08-06
provider_id: openai
provider_model_id: gpt-4o-2024-08-06
model_type: llm
- metadata: {}
model_id: gpt-4o-mini
provider_id: openai
provider_model_id: gpt-4o-mini
model_type: llm
- metadata: {}
model_id: gpt-4o-audio-preview
provider_id: openai
provider_model_id: gpt-4o-audio-preview
model_type: llm
- metadata: {}
model_id: chatgpt-4o-latest
provider_id: openai
provider_model_id: chatgpt-4o-latest
model_type: llm
- metadata: {}
model_id: o1
provider_id: openai
provider_model_id: o1
model_type: llm
- metadata: {}
model_id: o1-mini
provider_id: openai
provider_model_id: o1-mini
model_type: llm
- metadata: {}
model_id: o3-mini
provider_id: openai
provider_model_id: o3-mini
model_type: llm
- metadata: {}
model_id: o4-mini
provider_id: openai
provider_model_id: o4-mini
model_type: llm
- metadata:
embedding_dimension: 1536
context_length: 8192
@@ -163,6 +233,20 @@ models:
provider_id: openai
provider_model_id: openai/text-embedding-3-large
model_type: embedding
- metadata:
embedding_dimension: 1536
context_length: 8192
model_id: text-embedding-3-small
provider_id: openai
provider_model_id: text-embedding-3-small
model_type: embedding
- metadata:
embedding_dimension: 3072
context_length: 8192
model_id: text-embedding-3-large
provider_id: openai
provider_model_id: text-embedding-3-large
model_type: embedding
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
provider_id: fireworks


@@ -151,6 +151,76 @@ models:
provider_id: openai
provider_model_id: openai/chatgpt-4o-latest
model_type: llm
- metadata: {}
model_id: gpt-3.5-turbo-0125
provider_id: openai
provider_model_id: gpt-3.5-turbo-0125
model_type: llm
- metadata: {}
model_id: gpt-3.5-turbo
provider_id: openai
provider_model_id: gpt-3.5-turbo
model_type: llm
- metadata: {}
model_id: gpt-3.5-turbo-instruct
provider_id: openai
provider_model_id: gpt-3.5-turbo-instruct
model_type: llm
- metadata: {}
model_id: gpt-4
provider_id: openai
provider_model_id: gpt-4
model_type: llm
- metadata: {}
model_id: gpt-4-turbo
provider_id: openai
provider_model_id: gpt-4-turbo
model_type: llm
- metadata: {}
model_id: gpt-4o
provider_id: openai
provider_model_id: gpt-4o
model_type: llm
- metadata: {}
model_id: gpt-4o-2024-08-06
provider_id: openai
provider_model_id: gpt-4o-2024-08-06
model_type: llm
- metadata: {}
model_id: gpt-4o-mini
provider_id: openai
provider_model_id: gpt-4o-mini
model_type: llm
- metadata: {}
model_id: gpt-4o-audio-preview
provider_id: openai
provider_model_id: gpt-4o-audio-preview
model_type: llm
- metadata: {}
model_id: chatgpt-4o-latest
provider_id: openai
provider_model_id: chatgpt-4o-latest
model_type: llm
- metadata: {}
model_id: o1
provider_id: openai
provider_model_id: o1
model_type: llm
- metadata: {}
model_id: o1-mini
provider_id: openai
provider_model_id: o1-mini
model_type: llm
- metadata: {}
model_id: o3-mini
provider_id: openai
provider_model_id: o3-mini
model_type: llm
- metadata: {}
model_id: o4-mini
provider_id: openai
provider_model_id: o4-mini
model_type: llm
- metadata:
embedding_dimension: 1536
context_length: 8192
@@ -165,6 +235,20 @@ models:
provider_id: openai
provider_model_id: openai/text-embedding-3-large
model_type: embedding
- metadata:
embedding_dimension: 1536
context_length: 8192
model_id: text-embedding-3-small
provider_id: openai
provider_model_id: text-embedding-3-small
model_type: embedding
- metadata:
embedding_dimension: 3072
context_length: 8192
model_id: text-embedding-3-large
provider_id: openai
provider_model_id: text-embedding-3-large
model_type: embedding
- metadata: {}
model_id: accounts/fireworks/models/llama-v3p1-8b-instruct
provider_id: fireworks-openai-compat


@@ -0,0 +1,32 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
#
# Fails if any GitHub Actions workflow uses an external action without a full SHA pin.
set -euo pipefail
failed=0
# Find all workflow YAML files
for file in $(find .github/workflows/ -type f \( -name "*.yml" -o -name "*.yaml" \)); do
IFS=$'\n'
# Grep for `uses:` lines that look like actions
for line in $(grep -E '^.*uses:[^@]+@[^ ]+' "$file"); do
# Extract the ref part after the last @
ref=$(echo "$line" | sed -E 's/.*@([A-Za-z0-9._-]+).*/\1/')
# Check if ref is a 40-character hex string (full SHA).
#
# Note: strictly speaking, this could also be a tag or branch name, but
# we'd have to pull this info from the remote. Meh.
if ! [[ $ref =~ ^[0-9a-fA-F]{40}$ ]]; then
echo "ERROR: $file uses non-SHA action ref: $line"
failed=1
fi
done
done
exit $failed
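For illustration only, the same check expressed in Python (the pre-commit hook runs the bash script above; the function name and regexes here are hypothetical): a ref passes only when it is a 40-character hex SHA.

import re
import sys
from pathlib import Path

SHA_RE = re.compile(r"^[0-9a-fA-F]{40}$")
USES_RE = re.compile(r"uses:\s*\S+@(\S+)")

def find_unpinned_actions(workflows_dir: str = ".github/workflows") -> list[str]:
    """Return 'file: line' entries for uses: refs that are not full SHAs."""
    problems = []
    for path in Path(workflows_dir).glob("*.y*ml"):
        for line in path.read_text().splitlines():
            match = USES_RE.search(line)
            if match and not SHA_RE.match(match.group(1)):
                problems.append(f"{path}: {line.strip()}")
    return problems

if __name__ == "__main__":
    issues = find_unpinned_actions()
    for issue in issues:
        print(f"ERROR: non-SHA action ref: {issue}")
    sys.exit(1 if issues else 0)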


@@ -374,3 +374,105 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
finish reason.
We want to make sure that this case is recognized and handled correctly, i.e., as a valid end of message.
"""
mock_tool_name = "mock_tool"
mock_tool_arguments = {"arg1": 0, "arg2": 100}
mock_tool_arguments_str = '"{\\"arg1\\": 0, \\"arg2\\": 100}"'
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": mock_tool_arguments_str,
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
Tool calls with no arguments should be treated as regular tool calls, which was not the case until now.
"""
mock_tool_name = "mock_tool"
async def mock_stream():
mock_chunks = [
OpenAIChatCompletionChunk(
id="chunk-1",
created=1,
model="foo",
object="chat.completion.chunk",
choices=[
{
"delta": {
"content": None,
"tool_calls": [
{
"index": 0,
"id": "mock_id",
"type": "function",
"function": {
"name": mock_tool_name,
"arguments": "",
},
}
],
},
"finish_reason": None,
"logprobs": None,
"index": 0,
}
],
),
]
for chunk in mock_chunks:
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
assert len(chunks) == 2
assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
assert chunks[-2].event.delta.type == "tool_call"
assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
assert chunks[-2].event.delta.tool_call.arguments == {}