Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-24 03:13:56 +00:00

Merge branch 'main' into vectordb_name

Commit 74b0ab69ed: 161 changed files with 1844 additions and 11065 deletions
@@ -74,7 +74,6 @@ from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
 from llama_stack.providers.utils.responses.responses_store import ResponsesStore
-from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools

 logger = get_logger(name=__name__, category="openai_responses")
@@ -627,6 +626,8 @@ class OpenAIResponsesImpl:
                 raise ValueError(f"Tool {tool_name} not found")
             chat_tools.append(make_openai_tool(tool_name, tool))
         elif input_tool.type == "mcp":
+            from llama_stack.providers.utils.tools.mcp import list_mcp_tools
+
             always_allowed = None
             never_allowed = None
             if input_tool.allowed_tools:
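The `always_allowed` / `never_allowed` locals set up here feed an allow/deny filter over the tools returned by `list_mcp_tools`. A hypothetical sketch of that filtering (the function name and exact semantics are assumptions, not this file's code):

```python
# Hypothetical sketch: apply an MCP allow/deny list to discovered tool names.
def filter_mcp_tools(
    tools: list[str],
    always_allowed: list[str] | None,
    never_allowed: list[str] | None,
) -> list[str]:
    if always_allowed is not None:
        tools = [t for t in tools if t in always_allowed]  # keep only allow-listed tools
    if never_allowed is not None:
        tools = [t for t in tools if t not in never_allowed]  # drop deny-listed tools
    return tools

print(filter_mcp_tools(["search", "fetch", "delete"], None, ["delete"]))
# ['search', 'fetch']
```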
@@ -760,7 +761,9 @@ class OpenAIResponsesImpl:
         error_exc = None
         result = None
         try:
-            if function.name in ctx.mcp_tool_to_server:
+            if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server:
+                from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
+
                 mcp_tool = ctx.mcp_tool_to_server[function.name]
                 result = await invoke_mcp_tool(
                     endpoint=mcp_tool.server_url,
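The rewritten condition matters when no MCP servers are configured and `ctx.mcp_tool_to_server` is `None` rather than an empty dict: `name in None` raises `TypeError`, while the added truthiness check short-circuits first. A minimal sketch (the `ToolContext` stand-in and field type are assumptions):

```python
from dataclasses import dataclass

@dataclass
class ToolContext:
    # None when no MCP tools were registered; the old code assumed a dict
    mcp_tool_to_server: dict[str, str] | None = None

def route(ctx: ToolContext, name: str) -> str:
    # `and` short-circuits, so the membership test never runs against None
    if ctx.mcp_tool_to_server and name in ctx.mcp_tool_to_server:
        return "mcp"
    return "builtin"

assert route(ToolContext(), "search") == "builtin"  # no TypeError on None
assert route(ToolContext({"search": "http://mcp.local"}), "search") == "mcp"
```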
@@ -93,12 +93,17 @@ LLAMA_GUARD_MODEL_IDS = {
     "meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
     CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
+    "meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
     CoreModelId.llama_guard_4_12b.value: "meta-llama/Llama-Guard-4-12B",
+    "meta-llama/Llama-Guard-4-12B": "meta-llama/Llama-Guard-4-12B",
 }

 MODEL_TO_SAFETY_CATEGORIES_MAP = {
     "meta-llama/Llama-Guard-3-8B": DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
     "meta-llama/Llama-Guard-3-1B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
     "meta-llama/Llama-Guard-3-11B-Vision": DEFAULT_LG_V3_SAFETY_CATEGORIES,
     # Llama Guard 4 uses the same categories as Llama Guard 3
     # source: https://github.com/meta-llama/PurpleLlama/blob/main/Llama-Guard4/12B/MODEL_CARD.md
     "meta-llama/Llama-Guard-4-12B": DEFAULT_LG_V3_SAFETY_CATEGORIES,
 }
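The identity entries make `LLAMA_GUARD_MODEL_IDS` accept either the `CoreModelId` value or the fully qualified HuggingFace name, with both resolving to the same checkpoint. A toy sketch of the dual-key lookup (ids abbreviated for illustration):

```python
# Stand-ins for CoreModelId.llama_guard_4_12b.value and the HF repo name.
CORE_ID = "Llama-Guard-4-12B"
HF_ID = "meta-llama/Llama-Guard-4-12B"

MODEL_IDS = {
    CORE_ID: HF_ID,  # existing CoreModelId-keyed entry
    HF_ID: HF_ID,    # identity entry added by this diff
}

assert MODEL_IDS[CORE_ID] == MODEL_IDS[HF_ID]  # either spelling resolves
```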
@@ -20,13 +20,9 @@ class TelemetrySink(StrEnum):


 class TelemetryConfig(BaseModel):
-    otel_trace_endpoint: str | None = Field(
-        default=None,
-        description="The OpenTelemetry collector endpoint URL for traces",
-    )
-    otel_metric_endpoint: str | None = Field(
-        default=None,
-        description="The OpenTelemetry collector endpoint URL for metrics",
-    )
+    otel_exporter_otlp_endpoint: str | None = Field(
+        default=None,
+        description="The OpenTelemetry collector endpoint URL (base URL for traces, metrics, and logs). If not set, the SDK will use OTEL_EXPORTER_OTLP_ENDPOINT environment variable.",
+    )
     service_name: str = Field(
         # service name is always the same, use zero-width space to avoid clutter

@@ -35,7 +31,7 @@ class TelemetryConfig(BaseModel):
     )
     sinks: list[TelemetrySink] = Field(
         default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
-        description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
+        description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console)",
     )
     sqlite_db_path: str = Field(
         default_factory=lambda: (RUNTIME_BASE_DIR / "trace_store.db").as_posix(),

@@ -55,4 +51,5 @@ class TelemetryConfig(BaseModel):
             "service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
             "sinks": "${env.TELEMETRY_SINKS:=console,sqlite}",
             "sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
+            "otel_exporter_otlp_endpoint": "${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}",
         }
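The two per-signal fields collapse into one optional base endpoint; when it is unset, the exporters fall back to the standard OTEL_EXPORTER_OTLP_ENDPOINT environment variable. A minimal sketch of that resolution order (the helper function is hypothetical, not part of the config class):

```python
import os
from pydantic import BaseModel, Field

class OTLPSettings(BaseModel):
    otel_exporter_otlp_endpoint: str | None = Field(default=None)

def effective_endpoint(cfg: OTLPSettings) -> str | None:
    # explicit config wins; otherwise defer to the SDK's standard env var
    return cfg.otel_exporter_otlp_endpoint or os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")

os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "http://otel-collector:4318"
print(effective_endpoint(OTLPSettings()))  # http://otel-collector:4318
print(effective_endpoint(OTLPSettings(otel_exporter_otlp_endpoint="http://other:4318")))
```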
@@ -86,24 +86,27 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
         provider = TracerProvider(resource=resource)
         trace.set_tracer_provider(provider)
         _TRACER_PROVIDER = provider
-        if TelemetrySink.OTEL_TRACE in self.config.sinks:
-            if self.config.otel_trace_endpoint is None:
-                raise ValueError("otel_trace_endpoint is required when OTEL_TRACE is enabled")
-            span_exporter = OTLPSpanExporter(
-                endpoint=self.config.otel_trace_endpoint,
-            )
-            span_processor = BatchSpanProcessor(span_exporter)
-            trace.get_tracer_provider().add_span_processor(span_processor)
-        if TelemetrySink.OTEL_METRIC in self.config.sinks:
-            if self.config.otel_metric_endpoint is None:
-                raise ValueError("otel_metric_endpoint is required when OTEL_METRIC is enabled")
-            metric_reader = PeriodicExportingMetricReader(
-                OTLPMetricExporter(
-                    endpoint=self.config.otel_metric_endpoint,
-                )
-            )
-            metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
-            metrics.set_meter_provider(metric_provider)
+        # Use single OTLP endpoint for all telemetry signals
+        if TelemetrySink.OTEL_TRACE in self.config.sinks or TelemetrySink.OTEL_METRIC in self.config.sinks:
+            if self.config.otel_exporter_otlp_endpoint is None:
+                raise ValueError(
+                    "otel_exporter_otlp_endpoint is required when OTEL_TRACE or OTEL_METRIC is enabled"
+                )
+
+            # Let OpenTelemetry SDK handle endpoint construction automatically
+            # The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
+            # https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
+            if TelemetrySink.OTEL_TRACE in self.config.sinks:
+                span_exporter = OTLPSpanExporter()
+                span_processor = BatchSpanProcessor(span_exporter)
+                trace.get_tracer_provider().add_span_processor(span_processor)
+
+            if TelemetrySink.OTEL_METRIC in self.config.sinks:
+                metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
+                metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
+                metrics.set_meter_provider(metric_provider)

         if TelemetrySink.SQLITE in self.config.sinks:
             trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
         if TelemetrySink.CONSOLE in self.config.sinks:
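Constructing `OTLPSpanExporter()` and `OTLPMetricExporter()` with no endpoint delegates URL construction to the SDK: for the OTLP/HTTP transport it appends a per-signal path to the base endpoint, while the gRPC transport uses the base address directly. A sketch of the derivation described in the linked spec:

```python
# Per-signal URL derivation for OTLP/HTTP when only the base endpoint is set,
# per https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
def otlp_http_url(base: str, signal: str) -> str:
    return f"{base.rstrip('/')}/v1/{signal}"

base = "http://otel-collector:4318"
assert otlp_http_url(base, "traces") == "http://otel-collector:4318/v1/traces"
assert otlp_http_url(base, "metrics") == "http://otel-collector:4318/v1/metrics"
```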
@@ -23,7 +23,7 @@ def available_providers() -> list[ProviderSpec]:
             "pillow",
             "pandas",
             "scikit-learn",
-            "mcp",
+            "mcp>=1.8.1",
         ]
         + kvstore_dependencies(),  # TODO make this dynamic based on the kvstore config
         module="llama_stack.providers.inline.agents.meta_reference",
@@ -85,7 +85,7 @@ def available_providers() -> list[ProviderSpec]:
             adapter_type="model-context-protocol",
             module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
             config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
-            pip_packages=["mcp"],
+            pip_packages=["mcp>=1.8.1"],
             provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
             description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
         ),
@@ -520,7 +520,7 @@ Please refer to the inline provider documentation.
         Api.vector_io,
         AdapterSpec(
             adapter_type="milvus",
-            pip_packages=["pymilvus[marshmallow<3.13.0]"],
+            pip_packages=["pymilvus>=2.4.10"],
             module="llama_stack.providers.remote.vector_io.milvus",
             config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
             description="""
@@ -633,7 +633,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::milvus",
-        pip_packages=["pymilvus"],
+        pip_packages=["pymilvus>=2.4.10"],
         module="llama_stack.providers.inline.vector_io.milvus",
         config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
         api_dependencies=[Api.inference],
@@ -26,8 +26,8 @@ class CerebrasImplConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
         return {
             "base_url": DEFAULT_BASE_URL,
-            "api_key": "${env.CEREBRAS_API_KEY}",
+            "api_key": api_key,
         }
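The pattern applied here (and to the Passthrough, TGI, and Ollama configs below) lifts the env-var placeholder out of the method body into an overridable keyword default, so distribution templates can substitute a concrete value without editing the method. A condensed sketch:

```python
from typing import Any

class SampleConfig:
    @classmethod
    def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
        # the placeholder is now just a default, not hard-coded in the body
        return {"api_key": api_key}

print(SampleConfig.sample_run_config())                   # {'api_key': '${env.CEREBRAS_API_KEY}'}
print(SampleConfig.sample_run_config(api_key="sk-demo"))  # {'api_key': 'sk-demo'}
```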
@@ -13,13 +13,9 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"

 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
-    raise_on_connect_error: bool = True

     @classmethod
-    def sample_run_config(
-        cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", raise_on_connect_error: bool = True, **kwargs
-    ) -> dict[str, Any]:
+    def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
         return {
             "url": url,
-            "raise_on_connect_error": raise_on_connect_error,
         }
@@ -84,7 +84,7 @@ MODEL_ENTRIES = [
         CoreModelId.llama_guard_3_1b.value,
     ),
     ProviderModelEntry(
-        provider_model_id="all-minilm:latest",
+        provider_model_id="all-minilm:l6-v2",
+        aliases=["all-minilm"],
         model_type=ModelType.embedding,
         metadata={
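Pinning the embedding model to the versioned tag `all-minilm:l6-v2` while keeping `all-minilm` as an alias suggests a lookup that resolves aliases to the canonical provider id. A hypothetical sketch of such resolution (not the `ModelRegistryHelper` implementation):

```python
# Hypothetical alias table derived from the ProviderModelEntry above.
ALIASES = {"all-minilm": "all-minilm:l6-v2"}

def resolve_model(model_id: str) -> str:
    return ALIASES.get(model_id, model_id)

assert resolve_model("all-minilm") == "all-minilm:l6-v2"        # alias resolves
assert resolve_model("all-minilm:l6-v2") == "all-minilm:l6-v2"  # canonical passes through
```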
@@ -94,7 +94,6 @@ class OllamaInferenceAdapter(
     def __init__(self, config: OllamaImplConfig) -> None:
         self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
         self.url = config.url
-        self.raise_on_connect_error = config.raise_on_connect_error

     @property
     def client(self) -> AsyncClient:

@@ -108,10 +107,7 @@ class OllamaInferenceAdapter(
         logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
         health_response = await self.health()
         if health_response["status"] == HealthStatus.ERROR:
-            if self.raise_on_connect_error:
-                raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
-            else:
-                logger.warning("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+            raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")

     async def health(self) -> HealthResponse:
         """
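With `raise_on_connect_error` gone, initialization now fails fast whenever the health check reports an error, instead of sometimes only warning. A minimal sketch of the resulting behaviour (`HealthStatus` stubbed as a string):

```python
def check_startup(health_status: str) -> None:
    # after this diff there is no warn-and-continue branch left
    if health_status == "ERROR":
        raise RuntimeError(
            "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
        )

check_startup("OK")  # returns silently when the server is reachable
```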
@@ -24,8 +24,10 @@ class PassthroughImplConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
+    def sample_run_config(
+        cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
+    ) -> dict[str, Any]:
         return {
-            "url": "${env.PASSTHROUGH_URL}",
-            "api_key": "${env.PASSTHROUGH_API_KEY}",
+            "url": url,
+            "api_key": api_key,
         }
@@ -26,5 +26,5 @@ class RunpodImplConfig(BaseModel):
     def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
             "url": "${env.RUNPOD_URL:=}",
-            "api_token": "${env.RUNPOD_API_TOKEN:=}",
+            "api_token": "${env.RUNPOD_API_TOKEN}",
         }
@@ -17,7 +17,11 @@ class TGIImplConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, url: str = "${env.TGI_URL}", **kwargs):
+    def sample_run_config(
+        cls,
+        url: str = "${env.TGI_URL}",
+        **kwargs,
+    ):
         return {
             "url": url,
         }
@@ -327,7 +327,6 @@ class InferenceEndpointAdapter(_HfAdapter):
         # Get the inference endpoint details
         api = HfApi(token=config.api_token.get_secret_value())
         endpoint = api.get_inference_endpoint(config.endpoint_name)

         # Wait for the endpoint to be ready (if not already)
-        endpoint.wait(timeout=60)
@@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
             "url": "https://api.together.xyz/v1",
-            "api_key": "${env.TOGETHER_API_KEY:=}",
+            "api_key": "${env.TOGETHER_API_KEY}",
         }
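The Together and RunPod changes both drop the `:=` empty default, turning `${env.VAR:=}` into `${env.VAR}`. Under llama-stack's run-config substitution this means the variable must be set at startup rather than silently resolving to an empty string. A sketch of those placeholder semantics as assumed here (the expander is illustrative, not the stack's parser):

```python
import os
import re

# "${env.VAR}" must be set; "${env.VAR:=default}" falls back to the default.
PATTERN = re.compile(r"\$\{env\.(\w+)(?::=([^}]*))?\}")

def expand(value: str) -> str:
    def sub(m: re.Match) -> str:
        var, default = m.group(1), m.group(2)
        if var in os.environ:
            return os.environ[var]
        if default is not None:
            return default
        raise KeyError(f"environment variable {var} is required")
    return PATTERN.sub(sub, value)

print(expand("${env.TOGETHER_API_KEY:=}") or "<empty default>")  # old form: empty string
os.environ["TOGETHER_API_KEY"] = "tk-demo"
print(expand("${env.TOGETHER_API_KEY}"))  # new form: requires the variable
```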
@@ -92,7 +92,20 @@ def content_from_data_and_mime_type(data: bytes | str, mime_type: str | None, en
     mime_category = mime_type.split("/")[0] if mime_type else None
     if mime_category == "text":
         # For text-based files (including CSV, MD)
-        return data.decode(encoding)
+        encodings_to_try = [encoding]
+        if encoding != "utf-8":
+            encodings_to_try.append("utf-8")
+        first_exception = None
+        for encoding in encodings_to_try:
+            try:
+                return data.decode(encoding)
+            except UnicodeDecodeError as e:
+                if first_exception is None:
+                    first_exception = e
+                log.warning(f"Decoding failed with {encoding}: {e}")
+        # raise the original exception; if we got here there was at least 1 exception
+        log.error(f"Could not decode data as any of {encodings_to_try}")
+        raise first_exception

     elif mime_type == "application/pdf":
         return parse_pdf(data)
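The new decode path tries the caller's encoding first, falls back to utf-8, and re-raises the first `UnicodeDecodeError` only after every candidate fails. A self-contained condensation of that logic:

```python
# Condensed from the hunk above: attempt each candidate encoding in order,
# remember the first failure, and re-raise it if nothing decodes.
def decode_text(data: bytes, encoding: str = "utf-8") -> str:
    encodings_to_try = [encoding]
    if encoding != "utf-8":
        encodings_to_try.append("utf-8")
    first_exception: UnicodeDecodeError | None = None
    for enc in encodings_to_try:
        try:
            return data.decode(enc)
        except UnicodeDecodeError as e:
            first_exception = first_exception or e
    raise first_exception  # at least one attempt failed if we reach this line

print(decode_text("héllo".encode("utf-8"), encoding="ascii"))  # ascii fails, utf-8 succeeds
```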
@@ -164,7 +177,8 @@ def make_overlapped_chunks(
     for i in range(0, len(tokens), window_len - overlap_len):
         toks = tokens[i : i + window_len]
         chunk = tokenizer.decode(toks)
-        chunk_id = generate_chunk_id(chunk, text)
+        chunk_window = f"{i}-{i + len(toks)}"
+        chunk_id = generate_chunk_id(chunk, text, chunk_window)
         chunk_metadata = metadata.copy()
         chunk_metadata["chunk_id"] = chunk_id
         chunk_metadata["document_id"] = document_id
@@ -177,7 +191,7 @@ def make_overlapped_chunks(
             source=metadata.get("source", None),
             created_timestamp=metadata.get("created_timestamp", int(time.time())),
             updated_timestamp=int(time.time()),
-            chunk_window=f"{i}-{i + len(toks)}",
+            chunk_window=chunk_window,
             chunk_tokenizer=default_tokenizer,
             chunk_embedding_model=None,  # This will be set in `VectorDBWithIndex.insert_chunks`
             content_token_count=len(toks),
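Computing `chunk_window` once and threading it into both `generate_chunk_id` and the chunk metadata keeps the id and the recorded token span consistent. A toy walk-through of the overlapped windowing itself:

```python
# Toy tokens: windows of 4 advance by window_len - overlap_len = 2, so each
# chunk shares 2 tokens with its neighbor and chunk_window records its span.
tokens = list(range(10))
window_len, overlap_len = 4, 2

for i in range(0, len(tokens), window_len - overlap_len):
    toks = tokens[i : i + window_len]
    chunk_window = f"{i}-{i + len(toks)}"
    print(chunk_window, toks)
# 0-4 [0, 1, 2, 3]
# 2-6 [2, 3, 4, 5]
# 4-8 [4, 5, 6, 7]
# 6-10 [6, 7, 8, 9]
# 8-10 [8, 9]
```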
@@ -8,7 +8,7 @@ import hashlib
 import uuid


-def generate_chunk_id(document_id: str, chunk_text: str) -> str:
+def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | None = None) -> str:
     """
     Generate a unique chunk ID using a hash of the document ID and chunk text.
@@ -16,4 +16,6 @@ def generate_chunk_id(document_id: str, chunk_text: str) -> str:
     Adding usedforsecurity=False for compatibility with FIPS environments.
     """
     hash_input = f"{document_id}:{chunk_text}".encode()
+    if chunk_window:
+        hash_input += f":{chunk_window}".encode()
     return str(uuid.UUID(hashlib.md5(hash_input, usedforsecurity=False).hexdigest()))
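Assembled from the two hunks above, the extended id scheme: without `chunk_window`, two chunks with identical text in the same document collide; including the token span makes them distinct while staying deterministic.

```python
import hashlib
import uuid

def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | None = None) -> str:
    hash_input = f"{document_id}:{chunk_text}".encode()
    if chunk_window:
        hash_input += f":{chunk_window}".encode()
    # md5 here is an id, not a security boundary; usedforsecurity=False keeps FIPS happy
    return str(uuid.UUID(hashlib.md5(hash_input, usedforsecurity=False).hexdigest()))

a = generate_chunk_id("doc1", "repeated text", chunk_window="0-64")
b = generate_chunk_id("doc1", "repeated text", chunk_window="64-128")
assert a != b  # same text, different windows, different ids
assert a == generate_chunk_id("doc1", "repeated text", chunk_window="0-64")  # deterministic
```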