Merge branch 'main' into vectordb_name

This commit is contained in:
Francisco Arceo 2025-07-09 20:53:46 -04:00 committed by GitHub
commit 36ca9543a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 2282 additions and 1644 deletions

View file

@ -24,8 +24,8 @@ class ShieldRunnerMixin:
def __init__(
self,
safety_api: Safety,
input_shields: list[str] = None,
output_shields: list[str] = None,
input_shields: list[str] | None = None,
output_shields: list[str] | None = None,
):
self.safety_api = safety_api
self.input_shields = input_shields
@ -37,6 +37,7 @@ class ShieldRunnerMixin:
return await self.safety_api.run_shield(
shield_id=identifier,
messages=messages,
params={},
)
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])

View file

@ -39,7 +39,7 @@ class MetaReferenceInferenceConfig(BaseModel):
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")

View file

@ -123,7 +123,8 @@ class TorchtunePostTrainingImpl:
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob: ...
) -> PostTrainingJob:
raise NotImplementedError()
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(

View file

@ -146,10 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
raise ValueError(
f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
)
# Allow any model to be registered as a shield
# The model will be validated during runtime when making inference calls
pass
async def run_shield(
self,
@ -167,11 +166,25 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
# Use the inference API's model resolution instead of hardcoded mappings
# This allows the shield to work with any registered model
model_id = shield.provider_resource_id
# Determine safety categories based on the model type
# For known Llama Guard models, use specific categories
if model_id in LLAMA_GUARD_MODEL_IDS:
# Use the mapped model for categories but the original model_id for inference
mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
else:
# For unknown models, use default Llama Guard 3 8B categories
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
impl = LlamaGuardShield(
model=model,
model=model_id,
inference_api=self.inference_api,
excluded_categories=self.config.excluded_categories,
safety_categories=safety_categories,
)
return await impl.run(messages)
@ -183,20 +196,21 @@ class LlamaGuardShield:
model: str,
inference_api: Inference,
excluded_categories: list[str] | None = None,
safety_categories: list[str] | None = None,
):
if excluded_categories is None:
excluded_categories = []
if safety_categories is None:
safety_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
raise ValueError(f"Unsupported model: {model}")
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.safety_categories = safety_categories
def check_unsafe_response(self, response: str) -> str | None:
match = re.match(r"^unsafe\n(.*)$", response)
@ -214,7 +228,7 @@ class LlamaGuardShield:
final_categories = []
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
all_categories = self.safety_categories
for cat in all_categories:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:

View file

@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
"anthropic/claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="anthropic/voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="anthropic/voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta.llama3-1-8b-instruct-v1:0",
@ -22,4 +26,4 @@ MODEL_ENTRIES = [
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://inference-docs.cerebras.ai/models
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1-8b",
@ -18,4 +21,8 @@ MODEL_ENTRIES = [
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
]
build_hf_repo_model_entry(
"llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
model_entries = [
SAFETY_MODELS_ENTRIES = []
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
@ -56,7 +59,7 @@ model_entries = [
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
]
] + SAFETY_MODELS_ENTRIES
class DatabricksInferenceAdapter(
@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries)
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
async def initialize(self) -> None:

View file

@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
@ -40,14 +51,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-scout-instruct-basic",
CoreModelId.llama4_scout_17b_16e_instruct.value,
@ -64,4 +67,4 @@ MODEL_ENTRIES = [
"context_length": 8192,
},
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
"gemini/gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id="gemini/text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),
]
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="gemini/text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry,
)
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"groq/llama3-8b-8192",
@ -51,4 +53,4 @@ MODEL_ENTRIES = [
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama3-8b-instruct",
@ -99,4 +102,4 @@ MODEL_ENTRIES = [
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]
] + SAFETY_MODELS_ENTRIES

View file

@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
ProviderModelEntry(
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -11,7 +11,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import RunpodImplConfig
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}
SAFETY_MODELS_ENTRIES = []
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
class RunpodInferenceAdapter(
ModelRegistryHelper,
@ -61,25 +71,25 @@ class RunpodInferenceAdapter(
self,
model: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -129,10 +139,10 @@ class RunpodInferenceAdapter(
async def embeddings(
self,
model: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-8B-Instruct",
@ -46,8 +54,4 @@ MODEL_ENTRIES = [
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@ -40,14 +50,6 @@ MODEL_ENTRIES = [
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding,
@ -78,4 +80,4 @@ MODEL_ENTRIES = [
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
],
),
]
] + SAFETY_MODELS_ENTRIES

View file

@ -14,6 +14,6 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize()
return impl

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -16,6 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.

View file

@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
]
if additional_aliases:
aliases.extend(additional_aliases)
aliases = [alias for alias in aliases if alias is not None]
return ProviderModelEntry(
provider_model_id=provider_model_id,
aliases=aliases,
@ -90,7 +91,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
if provider_resource_id:
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences
if provider_resource_id != supported_model_id: # be idempotent, only reject differences
raise ValueError(
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
)