Merge branch 'main' into chore/standard-unsupported-model-err-msg-2517

Rohan Awhad committed 2025-06-26 10:53:28 -04:00
commit 92d934e476
196 changed files with 2335 additions and 1516 deletions

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Any, Protocol
from urllib.parse import urlparse
@ -225,7 +225,7 @@ def remote_provider_spec(
)
class HealthStatus(str, Enum):
class HealthStatus(StrEnum):
OK = "OK"
ERROR = "Error"
NOT_IMPLEMENTED = "Not Implemented"
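
The switch from (str, Enum) to StrEnum relies on Python 3.11's StrEnum, whose members behave as plain strings in str() and format() calls while still comparing equal to their values. A minimal illustration of the difference (class names are made up, not from this repository; requires Python 3.11+):

from enum import Enum, StrEnum

class OldStatus(str, Enum):
    OK = "OK"

class NewStatus(StrEnum):
    OK = "OK"

# Both compare equal to the raw string value...
assert OldStatus.OK == "OK" and NewStatus.OK == "OK"
# ...but only the StrEnum member renders as the bare value.
assert str(NewStatus.OK) == "OK"
assert str(OldStatus.OK) == "OldStatus.OK"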

@ -42,9 +42,10 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
WebSearchToolTypes,
)
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
@ -583,7 +584,7 @@ class OpenAIResponsesImpl:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
)
from llama_stack.apis.tools.tools import Tool
from llama_stack.apis.tools import Tool
mcp_tool_to_server = {}
@ -609,7 +610,7 @@ class OpenAIResponsesImpl:
# TODO: Handle other tool types
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type == "web_search":
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:

@ -208,7 +208,7 @@ class MetaReferenceEvalImpl(
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {scoring_fn_id: None for scoring_fn_id in scoring_functions}
scoring_functions_dict = dict.fromkeys(scoring_functions)
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
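
dict.fromkeys(scoring_functions) is the idiomatic spelling of the comprehension it replaces: it builds a dict whose keys come from the iterable and whose values all default to None. The same helper appears later in this commit with an explicit fill value. A small standalone check (the key names are illustrative):

scoring_functions = ["f1", "bleu"]
assert dict.fromkeys(scoring_functions) == {"f1": None, "bleu": None}
assert dict.fromkeys(scoring_functions, 1.0) == {"f1": 1.0, "bleu": 1.0}
# Note: a mutable fill value would be shared by every key, so the pattern is
# best kept to immutable defaults such as None, numbers, or True.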

@ -23,7 +23,7 @@ class LocalfsFilesImplConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"storage_dir": "${env.FILES_STORAGE_DIR:" + __distro_dir__ + "/files}",
"storage_dir": "${env.FILES_STORAGE_DIR:=" + __distro_dir__ + "/files}",
"metadata_store": SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="files_metadata.db",

@ -49,11 +49,11 @@ class MetaReferenceInferenceConfig(BaseModel):
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
**kwargs,
) -> dict[str, Any]:
return {

@ -44,10 +44,10 @@ class VLLMConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:4}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
"max_tokens": "${env.MAX_TOKENS:=4096}",
"max_model_len": "${env.MAX_MODEL_LEN:=4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
"enforce_eager": "${env.ENFORCE_EAGER:=False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
}

@ -17,5 +17,5 @@ class BraintrustScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"openai_api_key": "${env.OPENAI_API_KEY:}",
"openai_api_key": "${env.OPENAI_API_KEY:+}",
}
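
Across this commit the run-config templates move from a bare colon to bash-style operators: ${env.VAR:=default} supplies a fallback when the variable is unset, while ${env.VAR:+} marks a value that should only be filled in when the variable is actually set. A minimal resolver sketch that mimics those semantics; it is illustrative only and not the project's actual substitution code:

import os
import re

_ENV_VAR = re.compile(r"\$\{env\.(?P<name>\w+):(?P<op>[=+])(?P<value>[^}]*)\}")

def resolve(template: str) -> str:
    """Substitute ${env.NAME:=default} and ${env.NAME:+alt} patterns (sketch only)."""
    def _sub(match: re.Match) -> str:
        name, op, value = match.group("name", "op", "value")
        current = os.environ.get(name)
        if op == "=":
            # ":=" -- fall back to the literal default when the variable is unset
            return current if current is not None else value
        # ":+" -- use the variable's own value when it is set, otherwise leave it empty
        return current if current is not None else ""
    return _ENV_VAR.sub(_sub, template)

print(resolve("${env.OLLAMA_URL:=http://localhost:11434}"))  # prints the default if OLLAMA_URL is unset
print(resolve("${env.OPENAI_API_KEY:+}"))                    # prints nothing if OPENAI_API_KEY is unset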

@ -7,7 +7,7 @@ from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,

@ -6,7 +6,7 @@
import re
from typing import Any
from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, field_validator
@ -12,7 +12,7 @@ from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
class TelemetrySink(StrEnum):
OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite"
@ -20,12 +20,12 @@ class TelemetrySink(str, Enum):
class TelemetryConfig(BaseModel):
otel_trace_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
otel_trace_endpoint: str | None = Field(
default=None,
description="The OpenTelemetry collector endpoint URL for traces",
)
otel_metric_endpoint: str = Field(
default="http://localhost:4318/v1/metrics",
otel_metric_endpoint: str | None = Field(
default=None,
description="The OpenTelemetry collector endpoint URL for metrics",
)
service_name: str = Field(
@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:\u200b}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
"service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
"sinks": "${env.TELEMETRY_SINKS:=console,sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}

@ -87,12 +87,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
if TelemetrySink.OTEL_TRACE in self.config.sinks:
if self.config.otel_trace_endpoint is None:
raise ValueError("otel_trace_endpoint is required when OTEL_TRACE is enabled")
span_exporter = OTLPSpanExporter(
endpoint=self.config.otel_trace_endpoint,
)
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
if TelemetrySink.OTEL_METRIC in self.config.sinks:
if self.config.otel_metric_endpoint is None:
raise ValueError("otel_metric_endpoint is required when OTEL_METRIC is enabled")
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_metric_endpoint,

@ -81,6 +81,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
chunks = []
for doc in documents:
content = await content_from_doc(doc)
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
chunks.extend(
make_overlapped_chunks(
doc.document_id,
@ -157,8 +158,24 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
break
metadata_subset = {k: v for k, v in metadata.items() if k not in ["token_count", "metadata_token_count"]}
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_subset)
# Add useful keys from chunk_metadata to metadata and remove some from metadata
chunk_metadata_keys_to_include_from_context = [
"chunk_id",
"document_id",
"source",
]
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
metadata_for_context[k] = getattr(chunk.chunk_metadata, k)
for k in metadata:
if k not in metadata_keys_to_exclude_from_context:
metadata_for_context[k] = metadata[k]
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
picked.append(TextContentItem(text=text_content))
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))

@ -16,8 +16,7 @@ import numpy as np
from numpy.typing import NDArray
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,

@ -19,5 +19,5 @@ class QdrantVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
"path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db",
}

@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
}

@ -5,20 +5,18 @@
# the root directory of this source tree.
import asyncio
import hashlib
import json
import logging
import sqlite3
import struct
import uuid
from typing import Any
import numpy as np
import sqlite_vec
from numpy.typing import NDArray
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -66,7 +64,7 @@ def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return {doc_id: 1.0 for doc_id in scores}
return dict.fromkeys(scores, 1.0)
def _weighted_rerank(
@ -201,10 +199,7 @@ class SQLiteVecIndex(EmbeddingIndex):
batch_embeddings = embeddings[i : i + batch_size]
# Insert metadata
metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks
]
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
@ -218,7 +213,7 @@ class SQLiteVecIndex(EmbeddingIndex):
embedding_data = [
(
(
generate_chunk_id(chunk.metadata["document_id"], chunk.content),
chunk.chunk_id,
serialize_vector(emb.tolist()),
)
)
@ -230,10 +225,7 @@ class SQLiteVecIndex(EmbeddingIndex):
)
# Insert FTS content
fts_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
for chunk in batch_chunks
]
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
@ -381,13 +373,12 @@ class SQLiteVecIndex(EmbeddingIndex):
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using generate_chunk_id
# Convert responses to score dictionaries using chunk_id
vector_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
chunk.chunk_id: score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
@ -408,13 +399,7 @@ class SQLiteVecIndex(EmbeddingIndex):
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {}
for c in vector_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
for c in keyword_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
# Use the map to look up chunks by their IDs
chunks = []
@ -757,9 +742,3 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode()
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))

@ -70,7 +70,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama", "aiohttp"],
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama",
),

@ -67,7 +67,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm"],
pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",

@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::basic",
pip_packages=[],
pip_packages=["requests"],
module="llama_stack.providers.inline.scoring.basic",
config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
api_dependencies=[

@ -54,8 +54,8 @@ class NvidiaDatasetIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
"datasets_url": "${env.NVIDIA_DATASETS_URL:http://nemo.test}",
"api_key": "${env.NVIDIA_API_KEY:+}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
"datasets_url": "${env.NVIDIA_DATASETS_URL:=http://nemo.test}",
}

@ -66,7 +66,7 @@ class NvidiaDatasetIOAdapter:
Returns:
Dataset
"""
## add warnings for unsupported params
# add warnings for unsupported params
request_body = {
"name": dataset_def.identifier,
"namespace": self.config.dataset_namespace,

@ -25,5 +25,5 @@ class NVIDIAEvalConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}",
}

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)

@ -24,6 +24,12 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -33,14 +39,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)

@ -9,7 +9,7 @@ from typing import Any
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChoiceDelta,

@ -55,7 +55,7 @@ class NVIDIAConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
"url": "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:+}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:=True}",
}

@ -29,20 +29,18 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (

@ -10,6 +10,6 @@ from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config.url)
impl = OllamaInferenceAdapter(config)
await impl.initialize()
return impl

@ -13,7 +13,13 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
raise_on_connect_error: bool = True
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> dict[str, Any]:
return {"url": url}
def sample_run_config(
cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", raise_on_connect_error: bool = True, **kwargs
) -> dict[str, Any]:
return {
"url": url,
"raise_on_connect_error": raise_on_connect_error,
}

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,

@ -9,7 +9,6 @@ import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
from ollama import AsyncClient # type: ignore[attr-defined]
from openai import AsyncOpenAI
@ -33,15 +32,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
@ -49,6 +39,13 @@ from llama_stack.apis.inference.inference import (
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.exceptions import UnsupportedModelError
@ -58,6 +55,7 @@ from llama_stack.providers.datatypes import (
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -91,9 +89,10 @@ class OllamaInferenceAdapter(
InferenceProvider,
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.url = url
self.url = config.url
self.raise_on_connect_error = config.raise_on_connect_error
@property
def client(self) -> AsyncClient:
@ -104,8 +103,13 @@ class OllamaInferenceAdapter(
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
await self.health()
logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
health_response = await self.health()
if health_response["status"] == HealthStatus.ERROR:
if self.raise_on_connect_error:
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
else:
logger.warning("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
async def health(self) -> HealthResponse:
"""
@ -118,10 +122,8 @@ class OllamaInferenceAdapter(
try:
await self.client.ps()
return HealthResponse(status=HealthStatus.OK)
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
) from e
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def shutdown(self) -> None:
pass
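
Taken together, the Ollama changes above mean a stack can now start even when the Ollama server is down: the adapter is constructed from its config object, and initialize() only raises when raise_on_connect_error is left at its default of True. A hypothetical usage sketch, assuming llama-stack is installed and the classes shown above are importable:

import asyncio

from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.remote.inference.ollama.ollama import OllamaInferenceAdapter

async def main() -> None:
    config = OllamaImplConfig(url="http://localhost:11434", raise_on_connect_error=False)
    adapter = OllamaInferenceAdapter(config)
    # With raise_on_connect_error=False, a failed health check is logged as a
    # warning during initialize() instead of raising RuntimeError.
    await adapter.initialize()
    print(await adapter.health())

asyncio.run(main())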

@ -6,7 +6,7 @@
from dataclasses import dataclass
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)

@ -10,7 +10,7 @@ from typing import Any
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,

@ -19,7 +19,12 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -28,13 +33,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

@ -25,6 +25,6 @@ class RunpodImplConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"url": "${env.RUNPOD_URL:}",
"api_token": "${env.RUNPOD_API_TOKEN:}",
"url": "${env.RUNPOD_URL:+}",
"api_token": "${env.RUNPOD_API_TOKEN:+}",
}

@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:}",
"api_key": "${env.TOGETHER_API_KEY:+}",
}

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,

@ -23,7 +23,12 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -33,13 +38,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

@ -34,9 +34,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Check if it's a boolean string
if v.lower() in ("true", "false"):
return v.lower() == "true"
# Otherwise, treat it as a cert path
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
@ -54,7 +51,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
):
return {
"url": url,
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"api_token": "${env.VLLM_API_TOKEN:fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:true}",
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
}

@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import httpx
from openai import AsyncOpenAI
from openai import APIConnectionError, AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
@ -38,9 +38,13 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -49,12 +53,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
@ -461,7 +459,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
res = await client.models.list()
try:
res = await client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(

@ -40,7 +40,7 @@ class WatsonXConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:}",
"project_id": "${env.WATSONX_PROJECT_ID:}",
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:+}",
"project_id": "${env.WATSONX_PROJECT_ID:+}",
}

@ -18,10 +18,16 @@ from llama_stack.apis.inference import (
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
GreedySamplingStrategy,
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -29,14 +35,6 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
GreedySamplingStrategy,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)

@ -55,10 +55,10 @@ class NvidiaPostTrainingConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
"api_key": "${env.NVIDIA_API_KEY:+}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:=default}",
"project_id": "${env.NVIDIA_PROJECT_ID:=test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}",
}

@ -35,6 +35,6 @@ class NVIDIASafetyConfig(BaseModel):
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:self-check}",
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:=http://localhost:7331}",
"config_id": "${env.NVIDIA_GUARDRAILS_CONFIG_ID:=self-check}",
}

@ -22,6 +22,6 @@ class BraveSearchToolConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"api_key": "${env.BRAVE_SEARCH_API_KEY:}",
"api_key": "${env.BRAVE_SEARCH_API_KEY:+}",
"max_results": 3,
}

@ -22,6 +22,6 @@ class TavilySearchToolConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
return {
"api_key": "${env.TAVILY_SEARCH_API_KEY:}",
"api_key": "${env.TAVILY_SEARCH_API_KEY:+}",
"max_results": 3,
}

@ -17,5 +17,5 @@ class WolframAlphaToolConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:}",
"api_key": "${env.WOLFRAM_ALPHA_API_KEY:+}",
}

@ -22,8 +22,8 @@ class PGVectorVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(
cls,
host: str = "${env.PGVECTOR_HOST:localhost}",
port: int = "${env.PGVECTOR_PORT:5432}",
host: str = "${env.PGVECTOR_HOST:=localhost}",
port: int = "${env.PGVECTOR_PORT:=5432}",
db: str = "${env.PGVECTOR_DB}",
user: str = "${env.PGVECTOR_USER}",
password: str = "${env.PGVECTOR_PASSWORD}",

@ -70,8 +70,8 @@ class QdrantIndex(EmbeddingIndex):
)
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
for _i, (chunk, embedding) in enumerate(zip(chunks, embeddings, strict=False)):
chunk_id = chunk.chunk_id
points.append(
PointStruct(
id=convert_id(chunk_id),

@ -23,6 +23,13 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
@ -31,16 +38,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models.models import Model
from llama_stack.apis.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.exceptions import UnsupportedModelError
from llama_stack.log import get_logger

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.models import ModelType
from llama_stack.exceptions import UnsupportedModelError
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@ -35,7 +35,9 @@ def get_huggingface_repo(model_descriptor: str) -> str | None:
def build_hf_repo_model_entry(
provider_model_id: str, model_descriptor: str, additional_aliases: list[str] | None = None
provider_model_id: str,
model_descriptor: str,
additional_aliases: list[str] | None = None,
) -> ProviderModelEntry:
aliases = [
get_huggingface_repo(model_descriptor),

@ -95,27 +95,25 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
GreedySamplingStrategy,
Message,
SamplingParams,
SystemMessage,
TokenLogProbs,
ToolChoice,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.inference.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIEmbeddingData,
OpenAIMessageParam,
OpenAIResponseFormatParam,
SamplingParams,
SystemMessage,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.inference.inference import (
from llama_stack.apis.inference import (
OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.models.llama.datatypes import (
@ -1026,7 +1024,9 @@ def openai_messages_to_messages(
return converted_messages
def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam]):
def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam] | None):
if content is None:
return ""
if isinstance(content, str):
return content
elif isinstance(content, list):

@ -45,8 +45,8 @@ class RedisKVStoreConfig(CommonConfig):
return {
"type": "redis",
"namespace": None,
"host": "${env.REDIS_HOST:localhost}",
"port": "${env.REDIS_PORT:6379}",
"host": "${env.REDIS_HOST:=localhost}",
"port": "${env.REDIS_PORT:=6379}",
}
@ -66,7 +66,7 @@ class SqliteKVStoreConfig(CommonConfig):
return {
"type": "sqlite",
"namespace": None,
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
}
@ -84,12 +84,12 @@ class PostgresKVStoreConfig(CommonConfig):
return {
"type": "postgres",
"namespace": None,
"host": "${env.POSTGRES_HOST:localhost}",
"port": "${env.POSTGRES_PORT:5432}",
"db": "${env.POSTGRES_DB:llamastack}",
"user": "${env.POSTGRES_USER:llamastack}",
"password": "${env.POSTGRES_PASSWORD:llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:" + table_name + "}",
"host": "${env.POSTGRES_HOST:=localhost}",
"port": "${env.POSTGRES_PORT:=5432}",
"db": "${env.POSTGRES_DB:=llamastack}",
"user": "${env.POSTGRES_USER:=llamastack}",
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
"table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
}
@classmethod
@ -131,12 +131,12 @@ class MongoDBKVStoreConfig(CommonConfig):
return {
"type": "mongodb",
"namespace": None,
"host": "${env.MONGODB_HOST:localhost}",
"port": "${env.MONGODB_PORT:5432}",
"host": "${env.MONGODB_HOST:=localhost}",
"port": "${env.MONGODB_PORT:=5432}",
"db": "${env.MONGODB_DB}",
"user": "${env.MONGODB_USER}",
"password": "${env.MONGODB_PASSWORD}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}",
"collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
}

@ -12,8 +12,7 @@ import uuid
from abc import ABC, abstractmethod
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.files.files import OpenAIFileObject
from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,

@ -7,6 +7,7 @@ import base64
import io
import logging
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
@ -23,12 +24,13 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
log = logging.getLogger(__name__)
@ -148,6 +150,7 @@ async def content_from_doc(doc: RAGDocument) -> str:
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int, metadata: dict[str, Any]
) -> list[Chunk]:
default_tokenizer = "DEFAULT_TIKTOKEN_TOKENIZER"
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
try:
@ -161,16 +164,32 @@ def make_overlapped_chunks(
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
chunk_id = generate_chunk_id(chunk, text)
chunk_metadata = metadata.copy()
chunk_metadata["chunk_id"] = chunk_id
chunk_metadata["document_id"] = document_id
chunk_metadata["token_count"] = len(toks)
chunk_metadata["metadata_token_count"] = len(metadata_tokens)
backend_chunk_metadata = ChunkMetadata(
chunk_id=chunk_id,
document_id=document_id,
source=metadata.get("source", None),
created_timestamp=metadata.get("created_timestamp", int(time.time())),
updated_timestamp=int(time.time()),
chunk_window=f"{i}-{i + len(toks)}",
chunk_tokenizer=default_tokenizer,
chunk_embedding_model=None, # This will be set in `VectorDBWithIndex.insert_chunks`
content_token_count=len(toks),
metadata_token_count=len(metadata_tokens),
)
# chunk is a string
chunks.append(
Chunk(
content=chunk,
metadata=chunk_metadata,
chunk_metadata=backend_chunk_metadata,
)
)
@ -237,6 +256,9 @@ class VectorDBWithIndex:
for i, c in enumerate(chunks):
if c.embedding is None:
chunks_to_embed.append(c)
if c.chunk_metadata:
c.chunk_metadata.chunk_embedding_model = self.vector_db.embedding_model
c.chunk_metadata.chunk_embedding_dimension = self.vector_db.embedding_dimension
else:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)

@ -50,7 +50,7 @@ class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
return cls(
type="sqlite",
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
db_path="${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
)
@property
@ -78,11 +78,11 @@ class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
def sample_run_config(cls, **kwargs):
return cls(
type="postgres",
host="${env.POSTGRES_HOST:localhost}",
port="${env.POSTGRES_PORT:5432}",
db="${env.POSTGRES_DB:llamastack}",
user="${env.POSTGRES_USER:llamastack}",
password="${env.POSTGRES_PASSWORD:llamastack}",
host="${env.POSTGRES_HOST:=localhost}",
port="${env.POSTGRES_PORT:=5432}",
db="${env.POSTGRES_DB:=llamastack}",
user="${env.POSTGRES_USER:=llamastack}",
password="${env.POSTGRES_PASSWORD:=llamastack}",
)

@ -180,7 +180,7 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont
trace_id = generate_trace_id()
context = TraceContext(BACKGROUND_LOGGER, trace_id)
attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {})
attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | (attributes or {})
context.push_span(name, attributes)
CURRENT_TRACE_CONTEXT.set(context)

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import hashlib
import uuid
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
hash_input = f"{document_id}:{chunk_text}".encode()
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
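
Because the helper is a pure hash of its inputs, the same document ID and chunk text always produce the same ID, which is what lets the SQLite-vec and Qdrant code above key metadata, embedding, and FTS rows consistently by chunk_id. A quick illustration with made-up inputs:

cid = generate_chunk_id("doc-1", "hello world")
assert cid == generate_chunk_id("doc-1", "hello world")   # deterministic
assert cid != generate_chunk_id("doc-1", "hello world!")  # content-sensitive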