Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
feat: introduce llama4 support (#1877)
As title says. Details in README, elsewhere.
parent 23a99a4b22
commit b8f1561956
61 changed files with 205222 additions and 6439 deletions
@@ -27,7 +27,7 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
+            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3, ModelFamily.llama4}
             or is_supported_safety_model(m)
         )
     ]
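Note: a tiny self-contained sketch of the effect of widening this allow-set (family names are plain strings here purely for illustration; llama-stack uses the ModelFamily enum shown in the hunk):

# Models whose family is in the allow-set are kept; llama4 now passes the filter.
ALLOWED_FAMILIES = {"llama3_1", "llama3_2", "llama3_3", "llama4"}  # llama4 newly included

registered = ["llama3_1", "llama3_3", "llama4", "llama2"]
supported = [family for family in registered if family in ALLOWED_FAMILIES]
print(supported)  # ['llama3_1', 'llama3_3', 'llama4']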
@@ -33,9 +33,7 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
@@ -55,10 +53,22 @@ class LiteLLMOpenAIMixin(
     Inference,
     NeedsRequestProviderData,
 ):
-    def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
+    def __init__(
+        self,
+        model_entries,
+        api_key_from_config: Optional[str],
+        provider_data_api_key_field: str,
+        openai_compat_api_base: str | None = None,
+    ):
         ModelRegistryHelper.__init__(self, model_entries)
         self.api_key_from_config = api_key_from_config
         self.provider_data_api_key_field = provider_data_api_key_field
+        self.api_base = openai_compat_api_base
+
+        if openai_compat_api_base:
+            self.is_openai_compat = True
+        else:
+            self.is_openai_compat = False
 
     async def initialize(self):
         pass
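A standalone sketch of the constructor pattern this hunk introduces: an optional OpenAI-compatible base URL that, when set, flips the adapter into OpenAI-compat mode. The class below is a simplified stand-in, not the real LiteLLMOpenAIMixin:

from typing import Optional


class OpenAICompatAdapterSketch:
    """Simplified stand-in showing how openai_compat_api_base drives is_openai_compat."""

    def __init__(
        self,
        api_key_from_config: Optional[str],
        openai_compat_api_base: Optional[str] = None,
    ):
        self.api_key_from_config = api_key_from_config
        self.api_base = openai_compat_api_base
        # OpenAI-compat mode is implied by the presence of a base URL.
        self.is_openai_compat = openai_compat_api_base is not None


# Hypothetical usage with placeholder values:
adapter = OpenAICompatAdapterSketch("sk-example", openai_compat_api_base="https://example.com/v1")
assert adapter.is_openai_compat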
@@ -98,6 +108,7 @@ class LiteLLMOpenAIMixin(
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if sampling_params is None:
            sampling_params = SamplingParams()
+
         model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
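The None-then-default guard shown above is the usual idiom for avoiding a shared mutable default argument; a generic sketch (SamplingParams here is a stand-in dataclass, not the llama-stack type):

from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:  # stand-in for llama_stack's SamplingParams
    temperature: float = 0.0
    max_tokens: int = 512


def chat(prompt: str, sampling_params: Optional[SamplingParams] = None) -> str:
    # Defaulting to None and constructing inside the call gives every
    # invocation a fresh SamplingParams instead of one shared instance.
    if sampling_params is None:
        sampling_params = SamplingParams()
    return f"{prompt} (temperature={sampling_params.temperature})"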
@@ -111,6 +122,9 @@ class LiteLLMOpenAIMixin(
         )
 
         params = await self._get_params(request)
+        if self.is_openai_compat:
+            params["model"] = "openai/" + params["model"]
+
         logger.debug(f"params to litellm (openai compat): {params}")
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
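For context, litellm treats a model name prefixed with "openai/" as a request to its OpenAI-compatible client, so combining the prefix with an api_base points the same call at any OpenAI-compatible server. A hedged sketch of such a call; the endpoint URL, API key, and model name are placeholders, not values from this commit:

import litellm

response = litellm.completion(
    model="openai/llama4-scout",          # "openai/" prefix -> OpenAI-compatible route
    api_base="http://localhost:8000/v1",  # placeholder base URL
    api_key="not-a-real-key",             # placeholder key
    messages=[{"role": "user", "content": "Say hello."}],
    stream=False,
)
print(response.choices[0].message.content)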
@@ -208,6 +222,7 @@ class LiteLLMOpenAIMixin(
         return {
             "model": request.model,
             "api_key": api_key,
+            "api_base": self.api_base,
             **input_dict,
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
@@ -573,21 +573,24 @@ async def convert_message_to_openai_dict_new(
             content=await _convert_message_content(message.content),
         )
     elif isinstance(message, CompletionMessage):
+        tool_calls = [
+            OpenAIChatCompletionMessageToolCall(
+                id=tool.call_id,
+                function=OpenAIFunction(
+                    name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
+                    arguments=json.dumps(tool.arguments),
+                ),
+                type="function",
+            )
+            for tool in message.tool_calls
+        ]
+        params = {}
+        if tool_calls:
+            params = {"tool_calls": tool_calls}
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
             content=await _convert_message_content(message.content),
-            tool_calls=[
-                OpenAIChatCompletionMessageToolCall(
-                    id=tool.call_id,
-                    function=OpenAIFunction(
-                        name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
-                        arguments=json.dumps(tool.arguments),
-                    ),
-                    type="function",
-                )
-                for tool in message.tool_calls
-            ]
-            or None,
+            **params,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
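The refactor above converts tool calls first and only attaches tool_calls to the assistant message when the list is non-empty, rather than passing tool_calls=[...] or None. A standalone sketch of the resulting payload shape, using plain dicts instead of the OpenAI message classes from the diff:

import json


def to_openai_assistant_dict(content: str, tool_calls: list) -> dict:
    # Convert first, then attach only if any calls exist, so messages
    # without tool calls omit the key entirely (instead of tool_calls=None).
    converted = [
        {
            "id": call["call_id"],
            "type": "function",
            "function": {"name": call["tool_name"], "arguments": json.dumps(call["arguments"])},
        }
        for call in tool_calls
    ]
    params = {"tool_calls": converted} if converted else {}
    return {"role": "assistant", "content": content, **params}


print(to_openai_assistant_dict("hi", []))
# {'role': 'assistant', 'content': 'hi'} -- no tool_calls key at all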
@@ -801,7 +804,7 @@ def _convert_openai_logprobs(
     - token, logprob
 
     """
-    if not logprobs:
+    if not logprobs or not logprobs.content:
         return None
 
     return [
@@ -224,7 +224,9 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
     return formatter.tokenizer.decode(model_input.tokens)
 
 
-async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
+async def completion_request_to_prompt_model_input_info(
+    request: CompletionRequest,
+) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
@@ -302,8 +304,12 @@ def chat_completion_request_to_messages(
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
-    elif model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
-        # llama3.2 and llama3.3 models follow the same tool prompt format
+    elif model.model_family in (
+        ModelFamily.llama3_2,
+        ModelFamily.llama3_3,
+        ModelFamily.llama4,
+    ):
+        # llama3.2, llama3.3 and llama4 models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_2(request)
     else:
         messages = request.messages
@@ -471,7 +477,11 @@ def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
        return ToolPromptFormat.json
-    elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
+    elif llama_model.model_family in (
+        ModelFamily.llama3_2,
+        ModelFamily.llama3_3,
+        ModelFamily.llama4,
+    ):
         # llama3.2 and llama3.3 models follow the same tool prompt format
         return ToolPromptFormat.python_list
     else:
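Together with the previous hunk, this makes Llama 4 reuse the Llama 3.2/3.3 tool conventions: the llama_3_2 message augmentation and python_list as the default ToolPromptFormat. A simplified mapping illustrating that default; the enums below are stand-ins, not llama-stack imports, and the vision-model special case from the first branch is omitted:

from enum import Enum


class ModelFamily(Enum):  # stand-in for llama_stack's ModelFamily
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"
    llama4 = "llama4"


class ToolPromptFormat(Enum):  # stand-in for llama_stack's ToolPromptFormat
    json = "json"
    python_list = "python_list"


def default_tool_prompt_format(family: ModelFamily) -> ToolPromptFormat:
    # Llama 3.2, 3.3, and (with this commit) Llama 4 share python_list;
    # Llama 3.1 keeps the JSON tool prompt format.
    if family in (ModelFamily.llama3_2, ModelFamily.llama3_3, ModelFamily.llama4):
        return ToolPromptFormat.python_list
    return ToolPromptFormat.json


assert default_tool_prompt_format(ModelFamily.llama4) is ToolPromptFormat.python_list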