mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-19 16:39:39 +00:00
Merge branch 'main' into inference_refactor
This commit is contained in:
commit
6a51e2268d
117 changed files with 12698 additions and 2589 deletions
|
|
@ -78,7 +78,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
return None
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithChildren
|
||||
from llama_stack.apis.telemetry import QueryCondition, Span
|
||||
|
||||
|
||||
class TelemetryDatasetMixin:
|
||||
|
|
@ -53,19 +53,18 @@ class TelemetryDatasetMixin:
|
|||
spans = []
|
||||
|
||||
for trace in traces:
|
||||
span_tree = await self.get_span_tree(
|
||||
spans_by_id = await self.get_span_tree(
|
||||
span_id=trace.root_span_id,
|
||||
attributes_to_return=attributes_to_return,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
||||
def extract_spans(span: SpanWithChildren) -> List[Span]:
|
||||
result = []
|
||||
for span in spans_by_id.values():
|
||||
if span.attributes and all(
|
||||
attr in span.attributes and span.attributes[attr] is not None
|
||||
for attr in attributes_to_return
|
||||
):
|
||||
result.append(
|
||||
spans.append(
|
||||
Span(
|
||||
trace_id=trace.root_span_id,
|
||||
span_id=span.span_id,
|
||||
|
|
@ -77,11 +76,4 @@ class TelemetryDatasetMixin:
|
|||
)
|
||||
)
|
||||
|
||||
for child in span.children:
|
||||
result.extend(extract_spans(child))
|
||||
|
||||
return result
|
||||
|
||||
spans.extend(extract_spans(span_tree))
|
||||
|
||||
return spans
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
|
||||
from llama_stack.apis.telemetry import QueryCondition, SpanWithStatus, Trace
|
||||
|
||||
|
||||
class TraceStore(Protocol):
|
||||
|
|
@ -27,7 +27,7 @@ class TraceStore(Protocol):
|
|||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren: ...
|
||||
) -> Dict[str, SpanWithStatus]: ...
|
||||
|
||||
|
||||
class SQLiteTraceStore(TraceStore):
|
||||
|
|
@ -114,7 +114,7 @@ class SQLiteTraceStore(TraceStore):
|
|||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren:
|
||||
) -> Dict[str, SpanWithStatus]:
|
||||
# Build the attributes selection
|
||||
attributes_select = "s.attributes"
|
||||
if attributes_to_return:
|
||||
|
|
@ -143,6 +143,7 @@ class SQLiteTraceStore(TraceStore):
|
|||
ORDER BY depth, start_time
|
||||
"""
|
||||
|
||||
spans_by_id = {}
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
|
||||
|
|
@ -151,12 +152,8 @@ class SQLiteTraceStore(TraceStore):
|
|||
if not rows:
|
||||
raise ValueError(f"Span {span_id} not found")
|
||||
|
||||
# Build span tree
|
||||
spans_by_id = {}
|
||||
root_span = None
|
||||
|
||||
for row in rows:
|
||||
span = SpanWithChildren(
|
||||
span = SpanWithStatus(
|
||||
span_id=row["span_id"],
|
||||
trace_id=row["trace_id"],
|
||||
parent_span_id=row["parent_span_id"],
|
||||
|
|
@ -165,14 +162,8 @@ class SQLiteTraceStore(TraceStore):
|
|||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
attributes=json.loads(row["filtered_attributes"]),
|
||||
status=row["status"].lower(),
|
||||
children=[],
|
||||
)
|
||||
|
||||
spans_by_id[span.span_id] = span
|
||||
|
||||
if span.span_id == span_id:
|
||||
root_span = span
|
||||
elif span.parent_span_id in spans_by_id:
|
||||
spans_by_id[span.parent_span_id].children.append(span)
|
||||
|
||||
return root_span
|
||||
return spans_by_id
|
||||
|
|
|
|||
|
|
@ -41,8 +41,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
"""
|
||||
|
||||
def trace_method(method: Callable) -> Callable:
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
is_async = asyncio.iscoroutinefunction(method)
|
||||
is_async_gen = inspect.isasyncgenfunction(method)
|
||||
|
||||
|
|
@ -77,6 +75,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
async def async_gen_wrapper(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
) -> AsyncGenerator:
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(
|
||||
self, *args, **kwargs
|
||||
)
|
||||
|
|
@ -92,6 +92,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
|
||||
@wraps(method)
|
||||
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(
|
||||
self, *args, **kwargs
|
||||
)
|
||||
|
|
@ -107,6 +109,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
|
|||
|
||||
@wraps(method)
|
||||
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(
|
||||
self, *args, **kwargs
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue