address feedback

This commit is contained in:
Dinesh Yeduguru 2024-12-04 09:25:24 -08:00
parent 32af1f9dd4
commit b8c395c264
15 changed files with 672 additions and 151 deletions

View file

@@ -23,7 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.distribution.tracing import trace_protocol, traced
from llama_stack.distribution.tracing import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.deployment_types import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
@@ -428,7 +428,6 @@ class Agents(Protocol):
) -> AgentCreateResponse: ...
@webmethod(route="/agents/turn/create")
@traced(input="messages")
async def create_agent_turn(
self,
agent_id: str,

View file

@@ -38,7 +38,7 @@ class DatasetIO(Protocol):
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/upload", method="POST")
async def upload_rows(
@webmethod(route="/datasetio/append-rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@@ -21,7 +21,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.distribution.tracing import trace_protocol, traced
from llama_stack.distribution.tracing import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
@@ -227,7 +227,6 @@ class Inference(Protocol):
model_store: ModelStore
@webmethod(route="/inference/completion")
@traced(input="content")
async def completion(
self,
model_id: str,
@@ -239,7 +238,6 @@ class Inference(Protocol):
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat-completion")
@traced(input="messages")
async def chat_completion(
self,
model_id: str,
@@ -257,7 +255,6 @@ class Inference(Protocol):
]: ...
@webmethod(route="/inference/embeddings")
@traced(input="contents")
async def embeddings(
self,
model_id: str,

View file

@@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.distribution.tracing import trace_protocol, traced
from llama_stack.distribution.tracing import trace_protocol
@json_schema_type
@@ -50,7 +50,6 @@ class Memory(Protocol):
# this will just block now until documents are inserted, but it should
# probably return a Job instance which can be polled for completion
@traced(input="documents")
@webmethod(route="/memory/insert")
async def insert_documents(
self,
@@ -60,7 +59,6 @@ class Memory(Protocol):
) -> None: ...
@webmethod(route="/memory/query")
@traced(input="query")
async def query_documents(
self,
bank_id: str,

View file

@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.distribution.tracing import trace_protocol, traced
from llama_stack.distribution.tracing import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
@@ -50,7 +50,6 @@ class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run-shield")
@traced(input="messages")
async def run_shield(
self,
shield_id: str,

View file

@@ -21,6 +21,9 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7
@json_schema_type
class SpanStatus(Enum):
@@ -147,57 +150,39 @@ class EvalTrace(BaseModel):
@json_schema_type
class MaterializedSpan(Span):
children: List["MaterializedSpan"] = Field(default_factory=list)
class SpanWithChildren(Span):
children: List["SpanWithChildren"] = Field(default_factory=list)
status: Optional[SpanStatus] = None
@json_schema_type
class QueryCondition(BaseModel):
key: str
op: str
op: Literal["eq", "ne", "gt", "lt"]
value: Any
class TraceStore(Protocol):
async def query_traces(
self,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
async def get_materialized_span(
self,
span_id: str,
attribute_keys_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> MaterializedSpan: ...
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/log-event")
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: ...
async def log_event(
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
) -> None: ...
@webmethod(route="/telemetry/query-traces", method="GET")
@webmethod(route="/telemetry/query-traces", method="POST")
async def query_traces(
self,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
@webmethod(route="/telemetry/get-materialized-span", method="GET")
async def get_materialized_span(
@webmethod(route="/telemetry/get-span-tree", method="POST")
async def get_span_tree(
self,
span_id: str,
attribute_keys_to_return: Optional[List[str]] = None,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> MaterializedSpan: ...
) -> SpanWithChildren: ...