forked from phoenix-oss/llama-stack-mirror
feat: add api.llama provider, llama-guard-4 model (#2058)
This PR adds a llama-stack inference provider for `api.llama.com`, and adds model entries for Llama-Guard-4 and the updated Prompt-Guard models.
parent 934446ddb4
commit 4d0bfbf984
21 changed files with 1526 additions and 47 deletions
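Note (not part of the diff): since the new provider simply points llama-stack at an OpenAI-compatible endpoint, the service itself can be sanity-checked with the plain openai client. A minimal sketch, assuming LLAMA_API_KEY is exported, using the default base URL from LlamaCompatConfig below and one of the model ids registered by this commit:

# Sketch only: direct call against api.llama.com's OpenAI-compatible endpoint.
import os
from openai import OpenAI

client = OpenAI(
    base_url="https://api.llama.com/compat/v1/",  # default from LlamaCompatConfig
    api_key=os.environ["LLAMA_API_KEY"],          # assumed to be set
)
resp = client.chat.completions.create(
    model="Llama-4-Scout-17B-16E-Instruct-FP8",   # one of the MODEL_ENTRIES added here
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(resp.choices[0].message.content)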
docs/getting_started_llama_api.ipynb (new file, 907 lines)
File diff suppressed because one or more lines are too long
@@ -460,15 +460,17 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
     from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model

     from .model.safety_models import (
-        prompt_guard_download_info,
-        prompt_guard_model_sku,
+        prompt_guard_download_info_map,
+        prompt_guard_model_sku_map,
     )

-    prompt_guard = prompt_guard_model_sku()
+    prompt_guard_model_sku_map = prompt_guard_model_sku_map()
+    prompt_guard_download_info_map = prompt_guard_download_info_map()

     for model_id in model_ids:
-        if model_id == prompt_guard.model_id:
-            model = prompt_guard
-            info = prompt_guard_download_info()
+        if model_id in prompt_guard_model_sku_map.keys():
+            model = prompt_guard_model_sku_map[model_id]
+            info = prompt_guard_download_info_map[model_id]
         else:
             model = resolve_model(model_id)
             if model is None:
@@ -36,11 +36,11 @@ class ModelDescribe(Subcommand):
         )

     def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
-        from .safety_models import prompt_guard_model_sku
+        from .safety_models import prompt_guard_model_sku_map

-        prompt_guard = prompt_guard_model_sku()
-        if args.model_id == prompt_guard.model_id:
-            model = prompt_guard
+        prompt_guard_model_map = prompt_guard_model_sku_map()
+        if args.model_id in prompt_guard_model_map.keys():
+            model = prompt_guard_model_map[args.model_id]
         else:
             model = resolve_model(args.model_id)
@@ -84,7 +84,7 @@ class ModelList(Subcommand):
         )

     def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
-        from .safety_models import prompt_guard_model_sku
+        from .safety_models import prompt_guard_model_skus

         if args.downloaded:
             return _run_model_list_downloaded_cmd()
@@ -96,7 +96,7 @@ class ModelList(Subcommand):
         ]

         rows = []
-        for model in all_registered_models() + [prompt_guard_model_sku()]:
+        for model in all_registered_models() + prompt_guard_model_skus():
             if not args.show_all and not model.is_featured:
                 continue
@@ -42,11 +42,12 @@ class ModelRemove(Subcommand):
         )

     def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
-        from .safety_models import prompt_guard_model_sku
+        from .safety_models import prompt_guard_model_sku_map

-        prompt_guard = prompt_guard_model_sku()
-        if args.model == prompt_guard.model_id:
-            model = prompt_guard
+        prompt_guard_model_map = prompt_guard_model_sku_map()
+
+        if args.model in prompt_guard_model_map.keys():
+            model = prompt_guard_model_map[args.model]
         else:
             model = resolve_model(args.model)
@@ -15,11 +15,11 @@ from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
 class PromptGuardModel(BaseModel):
     """Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""

-    model_id: str = "Prompt-Guard-86M"
+    model_id: str
+    huggingface_repo: str
     description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
     is_featured: bool = False
-    huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
-    max_seq_length: int = 2048
+    max_seq_length: int = 512
     is_instruct_model: bool = False
     quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
     arch_args: Dict[str, Any] = Field(default_factory=dict)
@@ -30,13 +30,28 @@ class PromptGuardModel(BaseModel):
     model_config = ConfigDict(protected_namespaces=())


-def prompt_guard_model_sku():
-    return PromptGuardModel()
+def prompt_guard_model_skus():
+    return [
+        PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
+        PromptGuardModel(
+            model_id="Llama-Prompt-Guard-2-86M",
+            huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
+        ),
+        PromptGuardModel(
+            model_id="Llama-Prompt-Guard-2-22M",
+            huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
+        ),
+    ]


-def prompt_guard_download_info():
-    return LlamaDownloadInfo(
-        folder="Prompt-Guard",
+def prompt_guard_model_sku_map() -> Dict[str, Any]:
+    return {model.model_id: model for model in prompt_guard_model_skus()}
+
+
+def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]:
+    return {
+        model.model_id: LlamaDownloadInfo(
+            folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
         files=[
             "model.safetensors",
             "special_tokens_map.json",
@@ -45,3 +60,5 @@ def prompt_guard_download_info():
         ],
         pth_size=1,
     )
+        for model in prompt_guard_model_skus()
+    }
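Note (not part of the diff): the two map helpers are meant to be used together, exactly as the CLI hunks above do. A minimal sketch, assuming both functions are imported from the same safety_models module:

# Sketch: resolve a Prompt Guard SKU and its download info through the new maps.
sku_map = prompt_guard_model_sku_map()
info_map = prompt_guard_download_info_map()

model_id = "Llama-Prompt-Guard-2-22M"
if model_id in sku_map:
    model = sku_map[model_id]
    info = info_map[model_id]
    print(model.huggingface_repo)  # meta-llama/Llama-Prompt-Guard-2-22M
    print(info.folder)             # Llama-Prompt-Guard-2-22M (only the legacy 86M SKU keeps "Prompt-Guard")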
@@ -792,6 +792,13 @@ def llama3_3_instruct_models() -> List[Model]:
 @lru_cache
 def safety_models() -> List[Model]:
     return [
+        Model(
+            core_model_id=CoreModelId.llama_guard_4_12b,
+            description="Llama Guard v4 12b system safety model",
+            huggingface_repo="meta-llama/Llama-Guard-4-12B",
+            arch_args={},
+            pth_file_count=1,
+        ),
         Model(
             core_model_id=CoreModelId.llama_guard_3_11b_vision,
             description="Llama Guard v3 11b vision system safety model",
@@ -81,6 +81,7 @@ class CoreModelId(Enum):
     llama_guard_2_8b = "Llama-Guard-2-8B"
     llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
     llama_guard_3_1b = "Llama-Guard-3-1B"
+    llama_guard_4_12b = "Llama-Guard-4-12B"


 def is_multimodal(model_id) -> bool:
@@ -148,6 +149,7 @@ def model_family(model_id) -> ModelFamily:
         CoreModelId.llama_guard_2_8b,
         CoreModelId.llama_guard_3_11b_vision,
         CoreModelId.llama_guard_3_1b,
+        CoreModelId.llama_guard_4_12b,
     ]:
         return ModelFamily.safety
     else:
@@ -225,5 +227,7 @@ class Model(BaseModel):
             CoreModelId.llama_guard_3_1b,
         ]:
             return 131072
+        elif self.core_model_id == CoreModelId.llama_guard_4_12b:
+            return 8192
         else:
             raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
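Note (not part of the diff): taken together, the three hunks above mean the new id is classified in the safety family and reports an 8K context window. A small sketch, assuming these helpers live in llama_stack.models.llama.sku_types as the other imports in this diff suggest:

# Sketch: expected classification of the new Llama-Guard-4 entry.
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, model_family

assert CoreModelId.llama_guard_4_12b.value == "Llama-Guard-4-12B"
assert model_family(CoreModelId.llama_guard_4_12b) == ModelFamily.safety
# and the max_seq_len branch added above reports 8192 tokens for this model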
@@ -227,6 +227,16 @@ def available_providers() -> List[ProviderSpec]:
                 provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="llama-openai-compat",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.llama_openai_compat",
+                config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
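Note (not part of the diff): the AdapterSpec entry above is pure data; the stack resolves the module and config_class strings at runtime. A rough sketch of that kind of resolution (this is not llama-stack's actual loader code):

# Sketch: how dotted-path registry strings can be turned into concrete objects.
import importlib

provider_module = importlib.import_module("llama_stack.providers.remote.inference.llama_openai_compat")

config_class_path = "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig"
mod_path, _, cls_name = config_class_path.rpartition(".")
LlamaCompatConfig = getattr(importlib.import_module(mod_path), cls_name)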
@@ -0,0 +1,17 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import LlamaCompatConfig


async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
    # import dynamically so the import is used only when it is needed
    from .llama import LlamaCompatInferenceAdapter

    adapter = LlamaCompatInferenceAdapter(config)
    return adapter
@@ -0,0 +1,38 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class LlamaProviderDataValidator(BaseModel):
    llama_api_key: Optional[str] = Field(
        default=None,
        description="API key for api.llama models",
    )


@json_schema_type
class LlamaCompatConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="The Llama API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.llama.com/compat/v1/",
        description="The URL for the Llama API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.llama.com/compat/v1/",
            "api_key": api_key,
        }
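Note (not part of the diff): sample_run_config is what the llama_api template below consumes; instantiating the config directly shows the defaults. A minimal sketch:

# Sketch: defaults of the new provider config.
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig

cfg = LlamaCompatConfig(api_key="my-key")
print(cfg.openai_compat_api_base)  # https://api.llama.com/compat/v1/
print(LlamaCompatConfig.sample_run_config())
# {'openai_compat_api_base': 'https://api.llama.com/compat/v1/', 'api_key': '${env.LLAMA_API_KEY}'}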
@@ -0,0 +1,34 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.llama_openai_compat.config import (
    LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
    LiteLLMOpenAIMixin,
)

from .models import MODEL_ENTRIES


class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: LlamaCompatConfig

    def __init__(self, config: LlamaCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="llama_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
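Note (not part of the diff): the adapter is only a thin parameterization of LiteLLMOpenAIMixin, and the provider's __init__ factory above builds it. A minimal sketch of wiring it up by hand:

# Sketch: constructing the adapter the same way get_adapter_impl does.
import asyncio

from llama_stack.providers.remote.inference.llama_openai_compat import get_adapter_impl
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig

async def main():
    adapter = await get_adapter_impl(LlamaCompatConfig(api_key="my-key"), {})
    await adapter.initialize()

asyncio.run(main())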
@@ -0,0 +1,25 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    build_hf_repo_model_entry,
)

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "Llama-3.3-70B-Instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-4-Scout-17B-16E-Instruct-FP8",
        CoreModelId.llama4_scout_17b_16e_instruct.value,
    ),
    build_hf_repo_model_entry(
        "Llama-4-Maverick-17B-128E-Instruct-FP8",
        CoreModelId.llama4_maverick_17b_128e_instruct.value,
    ),
]
@@ -90,6 +90,9 @@ class LiteLLMOpenAIMixin(
             raise ValueError(f"Unsupported model: {model.provider_resource_id}")
         return model

+    def get_litellm_model_name(self, model_id: str) -> str:
+        return "openai/" + model_id if self.is_openai_compat else model_id
+
     async def completion(
         self,
         model_id: str,
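Note (not part of the diff): the new helper centralizes the "openai/" prefix that LiteLLM uses to route requests through its generic OpenAI-compatible path. A standalone sketch of the same expression, outside the mixin:

# Sketch: prefix only in openai-compat mode, leave native provider ids untouched.
def litellm_model_name(model_id: str, is_openai_compat: bool) -> str:
    return "openai/" + model_id if is_openai_compat else model_id

assert litellm_model_name("Llama-4-Scout-17B-16E-Instruct-FP8", True) == "openai/Llama-4-Scout-17B-16E-Instruct-FP8"
assert litellm_model_name("some-native-model-id", False) == "some-native-model-id"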
@@ -130,8 +133,7 @@ class LiteLLMOpenAIMixin(
         )

         params = await self._get_params(request)
-        if self.is_openai_compat:
-            params["model"] = "openai/" + params["model"]
+        params["model"] = self.get_litellm_model_name(params["model"])

         logger.debug(f"params to litellm (openai compat): {params}")
         # unfortunately, we need to use synchronous litellm.completion here because litellm
|
||||||
else request.tool_config.tool_choice
|
else request.tool_config.tool_choice
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": request.model,
|
||||||
|
"api_key": self.get_api_key(),
|
||||||
|
"api_base": self.api_base,
|
||||||
|
**input_dict,
|
||||||
|
"stream": request.stream,
|
||||||
|
**get_sampling_options(request.sampling_params),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_api_key(self) -> str:
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
key_field = self.provider_data_api_key_field
|
key_field = self.provider_data_api_key_field
|
||||||
if provider_data and getattr(provider_data, key_field, None):
|
if provider_data and getattr(provider_data, key_field, None):
|
||||||
api_key = getattr(provider_data, key_field)
|
api_key = getattr(provider_data, key_field)
|
||||||
else:
|
else:
|
||||||
api_key = self.api_key_from_config
|
api_key = self.api_key_from_config
|
||||||
|
return api_key
|
||||||
return {
|
|
||||||
"model": request.model,
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_base": self.api_base,
|
|
||||||
**input_dict,
|
|
||||||
"stream": request.stream,
|
|
||||||
**get_sampling_options(request.sampling_params),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
|
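Note (not part of the diff): the refactor splits request-parameter assembly from key lookup, and get_api_key prefers a per-request key carried in provider data over the static config value. A standalone sketch of that precedence:

# Sketch: the fallback order implemented by get_api_key, shown outside the mixin.
def pick_api_key(provider_data, key_field, api_key_from_config):
    if provider_data and getattr(provider_data, key_field, None):
        return getattr(provider_data, key_field)  # per-request key wins
    return api_key_from_config                    # otherwise fall back to config

class _RequestProviderData:
    llama_api_key = "key-from-request"

assert pick_api_key(_RequestProviderData(), "llama_api_key", "key-from-config") == "key-from-request"
assert pick_api_key(None, "llama_api_key", "key-from-config") == "key-from-config"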
@@ -247,7 +251,7 @@ class LiteLLMOpenAIMixin(
         model = await self.model_store.get_model(model_id)

         response = litellm.embedding(
-            model=model.provider_resource_id,
+            model=self.get_litellm_model_name(model.provider_resource_id),
             input=[interleaved_content_as_str(content) for content in contents],
         )
@@ -278,7 +282,7 @@ class LiteLLMOpenAIMixin(
     ) -> OpenAICompletion:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
             prompt=prompt,
             best_of=best_of,
             echo=echo,
@@ -297,6 +301,8 @@ class LiteLLMOpenAIMixin(
             user=user,
             guided_choice=guided_choice,
             prompt_logprobs=prompt_logprobs,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
         )
         return await litellm.atext_completion(**params)
@@ -328,7 +334,7 @@ class LiteLLMOpenAIMixin(
     ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
         model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
             messages=messages,
             frequency_penalty=frequency_penalty,
             function_call=function_call,
@@ -351,6 +357,8 @@ class LiteLLMOpenAIMixin(
             top_logprobs=top_logprobs,
             top_p=top_p,
             user=user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
         )
         return await litellm.acompletion(**params)
@@ -638,10 +638,13 @@ async def convert_message_to_openai_dict_new(
             )
             for tool in message.tool_calls
         ]
+        params = {}
+        if tool_calls:
+            params["tool_calls"] = tool_calls
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
             content=await _convert_message_content(message.content),
-            tool_calls=tool_calls or None,
+            **params,
         )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
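Note (not part of the diff): with this change the assistant message only receives a tool_calls argument when there are tool calls, instead of always passing tool_calls=None; presumably this keeps the serialized message free of an explicit null field. A tiny sketch of the difference:

# Sketch: old vs. new keyword assembly for a message without tool calls.
tool_calls = []

old_kwargs = {"role": "assistant", "content": "hi", "tool_calls": tool_calls or None}
# -> {'role': 'assistant', 'content': 'hi', 'tool_calls': None}

params = {}
if tool_calls:
    params["tool_calls"] = tool_calls
new_kwargs = {"role": "assistant", "content": "hi", **params}
# -> {'role': 'assistant', 'content': 'hi'}   (the key is omitted entirely)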
@@ -344,6 +344,45 @@
     "sentence-transformers --no-deps",
     "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
   ],
+  "llama_api": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "chromadb-client",
+    "datasets",
+    "emoji",
+    "fastapi",
+    "fire",
+    "httpx",
+    "langdetect",
+    "litellm",
+    "matplotlib",
+    "mcp",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pymongo",
+    "pypdf",
+    "pythainlp",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "sqlite-vec",
+    "tqdm",
+    "transformers",
+    "tree_sitter",
+    "uvicorn",
+    "sentence-transformers --no-deps",
+    "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+  ],
   "meta-reference-gpu": [
     "accelerate",
     "aiosqlite",
llama_stack/templates/llama_api/__init__.py (new file, 7 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .llama_api import get_distribution_template  # noqa: F401
llama_stack/templates/llama_api/build.yaml (new file, 33 lines)

version: '2'
distribution_spec:
  description: Distribution for running e2e tests in CI
  providers:
    inference:
    - remote::llama-openai-compat
    - inline::sentence-transformers
    vector_io:
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: conda
llama_stack/templates/llama_api/llama_api.py (new file, 159 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Tuple

from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
    ModelInput,
    Provider,
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
    SQLiteVectorIOConfig,
)
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
    LlamaCompatConfig,
)
from llama_stack.providers.remote.inference.llama_openai_compat.models import (
    MODEL_ENTRIES as LLLAMA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
    PGVectorVectorIOConfig,
)
from llama_stack.templates.template import (
    DistributionTemplate,
    RunConfigSettings,
    get_model_registry,
)


def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
    # in this template, we allow each API key to be optional
    providers = [
        (
            "llama-openai-compat",
            LLLAMA_MODEL_ENTRIES,
            LlamaCompatConfig.sample_run_config(api_key="${env.LLAMA_API_KEY:}"),
        ),
    ]
    inference_providers = []
    available_models = {}
    for provider_id, model_entries, config in providers:
        inference_providers.append(
            Provider(
                provider_id=provider_id,
                provider_type=f"remote::{provider_id}",
                config=config,
            )
        )
        available_models[provider_id] = model_entries
    return inference_providers, available_models


def get_distribution_template() -> DistributionTemplate:
    inference_providers, available_models = get_inference_providers()
    providers = {
        "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "tool_runtime": [
            "remote::brave-search",
            "remote::tavily-search",
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
        ],
    }
    name = "llama_api"

    vector_io_providers = [
        Provider(
            provider_id="sqlite-vec",
            provider_type="inline::sqlite-vec",
            config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
        ),
        Provider(
            provider_id="${env.ENABLE_CHROMADB+chromadb}",
            provider_type="remote::chromadb",
            config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
        ),
        Provider(
            provider_id="${env.ENABLE_PGVECTOR+pgvector}",
            provider_type="remote::pgvector",
            config=PGVectorVectorIOConfig.sample_run_config(
                db="${env.PGVECTOR_DB:}",
                user="${env.PGVECTOR_USER:}",
                password="${env.PGVECTOR_PASSWORD:}",
            ),
        ),
    ]
    embedding_provider = Provider(
        provider_id="sentence-transformers",
        provider_type="inline::sentence-transformers",
        config=SentenceTransformersInferenceConfig.sample_run_config(),
    )

    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::rag",
            provider_id="rag-runtime",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
    ]
    embedding_model = ModelInput(
        model_id="all-MiniLM-L6-v2",
        provider_id=embedding_provider.provider_id,
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 384,
        },
    )

    default_models = get_model_registry(available_models)
    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",
        description="Distribution for running e2e tests in CI",
        container_image=None,
        template_path=None,
        providers=providers,
        available_models_by_provider=available_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": inference_providers + [embedding_provider],
                    "vector_io": vector_io_providers,
                },
                default_models=default_models + [embedding_model],
                default_tool_groups=default_tool_groups,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
            ),
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "8321",
                "Port for the Llama Stack distribution server",
            ),
        },
    )
llama_stack/templates/llama_api/run.yaml (new file, 167 lines)

version: '2'
image_name: llama_api
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: llama-openai-compat
    provider_type: remote::llama-openai-compat
    config:
      openai_compat_api_base: https://api.llama.com/compat/v1/
      api_key: ${env.LLAMA_API_KEY:}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/sqlite_vec.db
  - provider_id: ${env.ENABLE_CHROMADB+chromadb}
    provider_type: remote::chromadb
    config:
      url: ${env.CHROMADB_URL:}
  - provider_id: ${env.ENABLE_PGVECTOR+pgvector}
    provider_type: remote::pgvector
    config:
      host: ${env.PGVECTOR_HOST:localhost}
      port: ${env.PGVECTOR_PORT:5432}
      db: ${env.PGVECTOR_DB:}
      user: ${env.PGVECTOR_USER:}
      password: ${env.PGVECTOR_PASSWORD:}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/llama_api/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/meta_reference_eval.db
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db
models:
- metadata: {}
  model_id: Llama-3.3-70B-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-3.3-70B-Instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-3.3-70B-Instruct
  model_type: llm
- metadata: {}
  model_id: Llama-4-Scout-17B-16E-Instruct-FP8
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Scout-17B-16E-Instruct-FP8
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Scout-17B-16E-Instruct-FP8
  model_type: llm
- metadata: {}
  model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321
@@ -321,6 +321,7 @@ exclude = [
     "^llama_stack/strong_typing/serializer\\.py$",
     "^llama_stack/templates/dev/dev\\.py$",
     "^llama_stack/templates/groq/groq\\.py$",
+    "^llama_stack/templates/llama_api/llama_api\\.py$",
     "^llama_stack/templates/sambanova/sambanova\\.py$",
     "^llama_stack/templates/template\\.py$",
 ]