Merge branch 'main' into inference_refactor

This commit is contained in:
Botao Chen 2024-12-16 16:47:57 -08:00
commit 6a51e2268d
117 changed files with 12698 additions and 2589 deletions

View file

@ -78,7 +78,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return None
async def register_model(self, model: Model) -> Model:
if model.model_type == ModelType.embedding_model:
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:

View file

@ -7,7 +7,7 @@
from typing import List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithChildren
from llama_stack.apis.telemetry import QueryCondition, Span
class TelemetryDatasetMixin:
@ -53,19 +53,18 @@ class TelemetryDatasetMixin:
spans = []
for trace in traces:
span_tree = await self.get_span_tree(
spans_by_id = await self.get_span_tree(
span_id=trace.root_span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)
def extract_spans(span: SpanWithChildren) -> List[Span]:
result = []
for span in spans_by_id.values():
if span.attributes and all(
attr in span.attributes and span.attributes[attr] is not None
for attr in attributes_to_return
):
result.append(
spans.append(
Span(
trace_id=trace.root_span_id,
span_id=span.span_id,
@ -77,11 +76,4 @@ class TelemetryDatasetMixin:
)
)
for child in span.children:
result.extend(extract_spans(child))
return result
spans.extend(extract_spans(span_tree))
return spans

View file

@ -6,11 +6,11 @@
import json
from datetime import datetime
from typing import List, Optional, Protocol
from typing import Dict, List, Optional, Protocol
import aiosqlite
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
from llama_stack.apis.telemetry import QueryCondition, SpanWithStatus, Trace
class TraceStore(Protocol):
@ -27,7 +27,7 @@ class TraceStore(Protocol):
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren: ...
) -> Dict[str, SpanWithStatus]: ...
class SQLiteTraceStore(TraceStore):
@ -114,7 +114,7 @@ class SQLiteTraceStore(TraceStore):
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
) -> Dict[str, SpanWithStatus]:
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:
@ -143,6 +143,7 @@ class SQLiteTraceStore(TraceStore):
ORDER BY depth, start_time
"""
spans_by_id = {}
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
@ -151,12 +152,8 @@ class SQLiteTraceStore(TraceStore):
if not rows:
raise ValueError(f"Span {span_id} not found")
# Build span tree
spans_by_id = {}
root_span = None
for row in rows:
span = SpanWithChildren(
span = SpanWithStatus(
span_id=row["span_id"],
trace_id=row["trace_id"],
parent_span_id=row["parent_span_id"],
@ -165,14 +162,8 @@ class SQLiteTraceStore(TraceStore):
end_time=datetime.fromisoformat(row["end_time"]),
attributes=json.loads(row["filtered_attributes"]),
status=row["status"].lower(),
children=[],
)
spans_by_id[span.span_id] = span
if span.span_id == span_id:
root_span = span
elif span.parent_span_id in spans_by_id:
spans_by_id[span.parent_span_id].children.append(span)
return root_span
return spans_by_id

View file

@ -41,8 +41,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
"""
def trace_method(method: Callable) -> Callable:
from llama_stack.providers.utils.telemetry import tracing
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
@ -77,6 +75,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
async def async_gen_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> AsyncGenerator:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
@ -92,6 +92,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
@ -107,6 +109,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)