mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 08:11:59 +00:00
Merge branch 'main' into patch-metadata
This commit is contained in:
commit
f0a142f5a8
21 changed files with 1405 additions and 887 deletions
|
|
@ -820,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
|
|||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
||||
input_messages: list[OpenAIMessageParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIChatCompletionResponse(BaseModel):
|
||||
data: list[OpenAICompletionWithInputMessages]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
last_id: str
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
class Order(Enum):
|
||||
asc = "asc"
|
||||
desc = "desc"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Inference(Protocol):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
class InferenceProvider(Protocol):
|
||||
"""
|
||||
This protocol defines the interface that should be implemented by all inference providers.
|
||||
"""
|
||||
|
||||
API_NAMESPACE: str = "Inference"
|
||||
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
|
|
@ -1062,3 +1079,39 @@ class Inference(Protocol):
|
|||
:returns: An OpenAIChatCompletion.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET")
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 20,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""List all chat completions.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
:param model: The model to filter by.
|
||||
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
|
||||
:returns: A ListOpenAIChatCompletionResponse.
|
||||
"""
|
||||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
|
||||
:param completion_id: ID of the chat completion.
|
||||
:returns: A OpenAICompletionWithInputMessages.
|
||||
"""
|
||||
raise NotImplementedError("Get chat completion is not implemented")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import shutil
|
|||
import sys
|
||||
import textwrap
|
||||
from functools import lru_cache
|
||||
from importlib.abc import Traversable
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
|
@ -250,11 +251,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
sys.exit(1)
|
||||
|
||||
if args.run:
|
||||
run_config = Path(run_config)
|
||||
config_dict = yaml.safe_load(run_config.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(str(config.external_providers_dir)):
|
||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
if not os.path.exists(config.external_providers_dir):
|
||||
os.makedirs(config.external_providers_dir, exist_ok=True)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
||||
run_command(run_args)
|
||||
|
|
@ -264,7 +264,7 @@ def _generate_run_config(
|
|||
build_config: BuildConfig,
|
||||
build_dir: Path,
|
||||
image_name: str,
|
||||
) -> str:
|
||||
) -> Path:
|
||||
"""
|
||||
Generate a run.yaml template file for user to edit from a build.yaml file
|
||||
"""
|
||||
|
|
@ -343,7 +343,7 @@ def _run_stack_build_command_from_build_config(
|
|||
image_name: str | None = None,
|
||||
template_name: str | None = None,
|
||||
config_path: str | None = None,
|
||||
) -> str:
|
||||
) -> Path | Traversable:
|
||||
image_name = image_name or build_config.image_name
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
|
|
|
|||
|
|
@ -340,8 +340,17 @@ class BuildConfig(BaseModel):
|
|||
default=None,
|
||||
description="Name of the distribution to build",
|
||||
)
|
||||
external_providers_dir: str | None = Field(
|
||||
external_providers_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
|
||||
"pip_packages MUST contain the provider package name.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
return Path(v)
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -226,6 +226,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
distribution_spec=DistributionSpec(
|
||||
providers=provider_types,
|
||||
),
|
||||
external_providers_dir=self.config.external_providers_dir,
|
||||
)
|
||||
print_pip_install_help(build_config)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
|
|||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
|
|
@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
|
|||
}
|
||||
|
||||
|
||||
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||
return {
|
||||
**api_protocol_map(),
|
||||
Api.inference: InferenceProvider,
|
||||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
|
|
@ -302,9 +309,6 @@ async def instantiate_provider(
|
|||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
|
@ -342,6 +346,8 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check()
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
|
|
|
|||
|
|
@ -280,7 +280,18 @@ class TracingMiddleware:
|
|||
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
|
|
@ -370,14 +381,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check for deprecated argument usage
|
||||
if "--config" in sys.argv:
|
||||
warnings.warn(
|
||||
"The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
log_line = ""
|
||||
if args.config:
|
||||
# if the user provided a config file, use it, even if template was specified
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
|
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
|
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
|
|||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
|
|||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
Event,
|
||||
|
|
@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
|
|||
)
|
||||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
|
|
@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
# Extract these W3C trace context attributes so they are not written to
|
||||
# underlying storage, as we just need them to propagate the trace context.
|
||||
traceparent = event.attributes.pop("traceparent", None)
|
||||
tracestate = event.attributes.pop("tracestate", None)
|
||||
if traceparent:
|
||||
# If we have a traceparent header value, we're not the root span.
|
||||
for root_attribute in ROOT_SPAN_MARKERS:
|
||||
event.attributes.pop(root_attribute, None)
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
|
|
@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
else:
|
||||
event.attributes["__root_span__"] = "true"
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
"tracestate": tracestate,
|
||||
}
|
||||
context = TraceContextTextMapPropagator().extract(carrier=carrier)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import CerebrasCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .cerebras import CerebrasCompatInferenceAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import FireworksCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .fireworks import FireworksCompatInferenceAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import GroqCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqCompatInferenceAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import LlamaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .llama import LlamaCompatInferenceAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
|
@ -82,7 +82,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, url: str) -> None:
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import SambaNovaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .sambanova import SambaNovaCompatInferenceAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import InferenceProvider
|
||||
|
||||
from .config import TogetherCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> Inference:
|
||||
async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .together import TogetherCompatInferenceAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
|
@ -59,7 +59,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class LiteLLMOpenAIMixin(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
NeedsRequestProviderData,
|
||||
):
|
||||
# TODO: avoid exposing the litellm specific model names to the user.
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ logger = get_logger(__name__, category="core")
|
|||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
|
||||
def trace_id_to_str(trace_id: int) -> str:
|
||||
"""Convenience trace ID formatting method
|
||||
|
|
@ -178,7 +180,8 @@ async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceCont
|
|||
|
||||
trace_id = generate_trace_id()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
attributes = {marker: True for marker in ROOT_SPAN_MARKERS} | (attributes or {})
|
||||
context.push_span(name, attributes)
|
||||
|
||||
CURRENT_TRACE_CONTEXT.set(context)
|
||||
return context
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue