feat: add api.llama provider, llama-guard-4 model (#2058)

This PR adds a llama-stack inference provider for `api.llama.com`, along with model entries for Llama-Guard-4 and updated Prompt-Guard models.
Ashwin Bharambe, 2025-04-29 10:07:41 -07:00 (committed by GitHub)
commit 4d0bfbf984 (parent 934446ddb4)
21 changed files with 1526 additions and 47 deletions

llama_stack/providers/registry/inference.py

@@ -227,6 +227,16 @@ def available_providers() -> List[ProviderSpec]:
                 provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="llama-openai-compat",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.llama_openai_compat",
+                config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
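For orientation, a small sketch of what this registry entry resolves to. The import path and the printed value are assumptions based on how remote provider specs are named elsewhere in the codebase, not part of this diff:

```python
# Illustrative only: a remote provider spec derives its provider_type from the
# adapter_type, so this entry should register as "remote::llama-openai-compat".
from llama_stack.providers.datatypes import AdapterSpec, Api, remote_provider_spec

spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_type="llama-openai-compat",
        pip_packages=["litellm"],
        module="llama_stack.providers.remote.inference.llama_openai_compat",
        config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
    ),
)
print(spec.provider_type)  # expected: remote::llama-openai-compat
```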

llama_stack/providers/remote/inference/llama_openai_compat/__init__.py (new file)

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import LlamaCompatConfig


async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
    # import dynamically so the import is used only when it is needed
    from .llama import LlamaCompatInferenceAdapter

    adapter = LlamaCompatInferenceAdapter(config)
    return adapter
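As a quick orientation, here is a minimal sketch of exercising this factory by hand. In practice the stack instantiates it from the provider spec above; the explicit call and the placeholder key are only illustrative:

```python
import asyncio

from llama_stack.providers.remote.inference.llama_openai_compat import get_adapter_impl
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig


async def main() -> None:
    # Placeholder key; normally supplied via config or per-request provider data.
    config = LlamaCompatConfig(api_key="example-key")
    adapter = await get_adapter_impl(config, {})
    await adapter.initialize()


asyncio.run(main())
```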

llama_stack/providers/remote/inference/llama_openai_compat/config.py (new file)

@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class LlamaProviderDataValidator(BaseModel):
    llama_api_key: Optional[str] = Field(
        default=None,
        description="API key for api.llama models",
    )


@json_schema_type
class LlamaCompatConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="The Llama API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.llama.com/compat/v1/",
        description="The URL for the Llama API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.llama.com/compat/v1/",
            "api_key": api_key,
        }
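A small sketch of how this config behaves. The literal key is a placeholder, and the `${env.LLAMA_API_KEY}` substitution is handled by the stack's config loading, not by this class:

```python
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig

# Template used when generating run configs; the env placeholder is resolved later.
print(LlamaCompatConfig.sample_run_config())
# {'openai_compat_api_base': 'https://api.llama.com/compat/v1/', 'api_key': '${env.LLAMA_API_KEY}'}

# Direct construction with a placeholder key; the base URL defaults to the Llama API.
cfg = LlamaCompatConfig(api_key="example-key")
assert cfg.openai_compat_api_base == "https://api.llama.com/compat/v1/"
```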

llama_stack/providers/remote/inference/llama_openai_compat/llama.py (new file)

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.llama_openai_compat.config import (
    LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
    LiteLLMOpenAIMixin,
)

from .models import MODEL_ENTRIES


class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: LlamaCompatConfig

    def __init__(self, config: LlamaCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="llama_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
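A short sketch of how the adapter is wired, assuming the mixin treats the provided `openai_compat_api_base` as its litellm `api_base`; the key value is a placeholder:

```python
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter

config = LlamaCompatConfig(api_key="example-key")  # placeholder key
adapter = LlamaCompatInferenceAdapter(config)

# Requests are routed to the OpenAI-compatible Llama API endpoint; when a request
# carries provider data, the per-request key is read from the "llama_api_key"
# field configured above.
print(config.openai_compat_api_base)  # https://api.llama.com/compat/v1/
```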

llama_stack/providers/remote/inference/llama_openai_compat/models.py (new file)

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "Llama-3.3-70B-Instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-4-Scout-17B-16E-Instruct-FP8",
        CoreModelId.llama4_scout_17b_16e_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-4-Maverick-17B-128E-Instruct-FP8",
        CoreModelId.llama4_maverick_17b_128e_instruct.value,
    ),
]
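Each entry maps the provider-side model id to a core Llama model descriptor so the stack can register the model under familiar aliases. The attribute names in this sketch (`provider_model_id`, `aliases`) are assumptions about the entry objects, for illustration only:

```python
from llama_stack.providers.remote.inference.llama_openai_compat.models import MODEL_ENTRIES

for entry in MODEL_ENTRIES:
    # e.g. provider id "Llama-3.3-70B-Instruct" aliased to the core model / HF repo name
    print(entry.provider_model_id, entry.aliases)
```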

llama_stack/providers/utils/inference/litellm_openai_mixin.py

@@ -90,6 +90,9 @@ class LiteLLMOpenAIMixin(
             raise ValueError(f"Unsupported model: {model.provider_resource_id}")
         return model

+    def get_litellm_model_name(self, model_id: str) -> str:
+        return "openai/" + model_id if self.is_openai_compat else model_id
+
     async def completion(
         self,
         model_id: str,
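The new helper centralizes the litellm model-name prefixing that was previously done inline. A standalone restatement of the rule, simplified from the method above:

```python
def litellm_model_name(model_id: str, is_openai_compat: bool) -> str:
    # litellm treats "openai/<model>" as a request to its OpenAI-compatible
    # client, which is how calls reach endpoints like api.llama.com/compat/v1/.
    return "openai/" + model_id if is_openai_compat else model_id


assert litellm_model_name("Llama-3.3-70B-Instruct", True) == "openai/Llama-3.3-70B-Instruct"
assert litellm_model_name("Llama-3.3-70B-Instruct", False) == "Llama-3.3-70B-Instruct"
```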
@@ -130,8 +133,7 @@
         )

         params = await self._get_params(request)
-        if self.is_openai_compat:
-            params["model"] = "openai/" + params["model"]
+        params["model"] = self.get_litellm_model_name(params["model"])
         logger.debug(f"params to litellm (openai compat): {params}")

         # unfortunately, we need to use synchronous litellm.completion here because litellm
@@ -220,21 +222,23 @@
             else request.tool_config.tool_choice
         )

+        return {
+            "model": request.model,
+            "api_key": self.get_api_key(),
+            "api_base": self.api_base,
+            **input_dict,
+            "stream": request.stream,
+            **get_sampling_options(request.sampling_params),
+        }
+
+    def get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
         if provider_data and getattr(provider_data, key_field, None):
             api_key = getattr(provider_data, key_field)
         else:
             api_key = self.api_key_from_config
-
-        return {
-            "model": request.model,
-            "api_key": api_key,
-            "api_base": self.api_base,
-            **input_dict,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
+        return api_key

     async def embeddings(
         self,
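The refactor splits key resolution out of `_get_params` into `get_api_key`, keeping the precedence the same: a key carried in request provider data wins over the key from static config. A simplified standalone restatement:

```python
from typing import Optional


def resolve_api_key(provider_data_key: Optional[str], config_key: Optional[str]) -> Optional[str]:
    # Per-request key (from provider data) takes precedence over the configured key.
    return provider_data_key if provider_data_key else config_key


assert resolve_api_key("per-request-key", "config-key") == "per-request-key"
assert resolve_api_key(None, "config-key") == "config-key"
```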
@@ -247,7 +251,7 @@
         model = await self.model_store.get_model(model_id)

         response = litellm.embedding(
-            model=model.provider_resource_id,
+            model=self.get_litellm_model_name(model.provider_resource_id),
             input=[interleaved_content_as_str(content) for content in contents],
         )
@@ -278,7 +282,7 @@
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
             prompt=prompt,
             best_of=best_of,
             echo=echo,
@@ -297,6 +301,8 @@
             user=user,
             guided_choice=guided_choice,
             prompt_logprobs=prompt_logprobs,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
         )

         return await litellm.atext_completion(**params)
@@ -328,7 +334,7 @@
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
             messages=messages,
             frequency_penalty=frequency_penalty,
             function_call=function_call,
@@ -351,6 +357,8 @@
             top_logprobs=top_logprobs,
             top_p=top_p,
             user=user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
         )

         return await litellm.acompletion(**params)
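The OpenAI-compatible completion paths now pass `api_key` and `api_base` explicitly instead of relying on litellm picking them up from the environment. A hedged sketch of the equivalent direct litellm call; the model name, key, and URL are placeholders:

```python
import litellm

response = litellm.completion(
    model="openai/Llama-3.3-70B-Instruct",  # "openai/" prefix selects litellm's OpenAI-compatible client
    messages=[{"role": "user", "content": "Say hello."}],
    api_key="example-key",  # placeholder; normally resolved by get_api_key()
    api_base="https://api.llama.com/compat/v1/",
)
print(response.choices[0].message.content)
```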

llama_stack/providers/utils/inference/openai_compat.py

@@ -638,10 +638,13 @@ async def convert_message_to_openai_dict_new(
             )
             for tool in message.tool_calls
         ]
+        params = {}
+        if tool_calls:
+            params["tool_calls"] = tool_calls
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
             content=await _convert_message_content(message.content),
-            tool_calls=tool_calls or None,
+            **params,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
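The assistant-message conversion now omits the `tool_calls` field entirely when there are no tool calls, rather than setting it to None on the message model (where an explicitly set None can end up serialized as `"tool_calls": null`). A simplified stand-in showing the intended shape of the serialized message:

```python
from typing import Any, Dict, List


def build_assistant_message(content: str, tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Only include "tool_calls" when there is at least one tool call.
    params: Dict[str, Any] = {}
    if tool_calls:
        params["tool_calls"] = tool_calls
    return {"role": "assistant", "content": content, **params}


assert "tool_calls" not in build_assistant_message("hi", [])
assert build_assistant_message("hi", [{"id": "call_1"}])["tool_calls"] == [{"id": "call_1"}]
```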