llama-stack/llama_stack/apis/inference/inference.py
Dinesh Yeduguru fcd6449519
Telemetry API redesign (#525)
# What does this PR do?
Change the Telemetry API to be able to support different use cases like
returning traces for the UI and ability to export for Evals.
Other changes:
* Add a new trace_protocol decorator to decorate all our API methods so
that any call to them will automatically get traced across all impls.
* There is some issue with the decorator pattern of span creation when
using async generators, where there are multiple yields with in the same
context. I think its much more explicit by using the explicit context
manager pattern using with. I moved the span creations in agent instance
to be using with
* Inject session id at the turn level, which should quickly give us all
traces across turns for a given session

Addresses #509

## Test Plan
```
llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml
PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000


 curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \
-H 'Content-Type: application/json' \
-d '{
  "attribute_filters": [
    {
      "key": "session_id",
      "op": "eq",
      "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65" }],
  "limit": 100,
  "offset": 0,
  "order_by": ["start_time"]
}' | jq .
[
  {
    "trace_id": "6902f54b83b4b48be18a6f422b13e16f",
    "root_span_id": "5f37b85543afc15a",
    "start_time": "2024-12-04T08:08:30.501587",
    "end_time": "2024-12-04T08:08:36.026463"
  },
  {
    "trace_id": "92227dac84c0615ed741be393813fb5f",
    "root_span_id": "af7c5bb46665c2c8",
    "start_time": "2024-12-04T08:08:36.031170",
    "end_time": "2024-12-04T08:08:41.693301"
  },
  {
    "trace_id": "7d578a6edac62f204ab479fba82f77b6",
    "root_span_id": "1d935e3362676896",
    "start_time": "2024-12-04T08:08:41.695204",
    "end_time": "2024-12-04T08:08:47.228016"
  },
  {
    "trace_id": "dbd767d76991bc816f9f078907dc9ff2",
    "root_span_id": "f5a7ee76683b9602",
    "start_time": "2024-12-04T08:08:47.234578",
    "end_time": "2024-12-04T08:08:53.189412"
  }
]


curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \
-H 'Content-Type: application/json' \
-d '{ "span_id" : "6cceb4b48a156913", "max_depth": 2, "attributes_to_return": ["input"] }' | jq .
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   875  100   790  100    85  18462   1986 --:--:-- --:--:-- --:--:-- 20833
{
  "span_id": "6cceb4b48a156913",
  "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
  "parent_span_id": "892a66d726c7f990",
  "name": "retrieve_rag_context",
  "start_time": "2024-12-04T09:28:21.781995",
  "end_time": "2024-12-04T09:28:21.913352",
  "attributes": {
    "input": [
      "{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}",
      "{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}"
    ]
  },
  "children": [
    {
      "span_id": "1a2df181854064a8",
      "trace_id": "dafa796f6aaf925f511c04cd7c67fdda",
      "parent_span_id": "6cceb4b48a156913",
      "name": "MemoryRouter.query_documents",
      "start_time": "2024-12-04T09:28:21.787620",
      "end_time": "2024-12-04T09:28:21.906512",
      "attributes": {
        "input": null
      },
      "children": [],
      "status": "ok"
    }
  ],
  "status": "ok"
}

```

<img width="1677" alt="Screenshot 2024-12-04 at 9 42 56 AM"
src="https://github.com/user-attachments/assets/4d3cea93-05ce-415a-93d9-4b1628631bf8">
2024-12-04 11:22:45 -08:00

262 lines
6.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import (
AsyncIterator,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.distribution.tracing import trace_protocol
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
class LogProbConfig(BaseModel):
top_k: Optional[int] = 0
@json_schema_type
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
int4 = "int4"
@json_schema_type
class Fp8QuantizationConfig(BaseModel):
type: Literal[QuantizationType.fp8.value] = QuantizationType.fp8.value
@json_schema_type
class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
@json_schema_type
class Int4QuantizationConfig(BaseModel):
type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
Field(discriminator="type"),
]
@json_schema_type
class ChatCompletionResponseEventType(Enum):
start = "start"
complete = "complete"
progress = "progress"
@json_schema_type
class ToolCallParseStatus(Enum):
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
@json_schema_type
class ToolCallDelta(BaseModel):
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
@json_schema_type
class ChatCompletionResponseEvent(BaseModel):
"""Chat completion response event."""
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]] = None
stop_reason: Optional[StopReason] = None
class ResponseFormatType(Enum):
json_schema = "json_schema"
grammar = "grammar"
class JsonSchemaResponseFormat(BaseModel):
type: Literal[ResponseFormatType.json_schema.value] = (
ResponseFormatType.json_schema.value
)
json_schema: Dict[str, Any]
class GrammarResponseFormat(BaseModel):
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
bnf: Dict[str, Any]
ResponseFormat = Annotated[
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
Field(discriminator="type"),
]
@json_schema_type
class CompletionRequest(BaseModel):
model: str
content: InterleavedTextMedia
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class CompletionResponse(BaseModel):
"""Completion response."""
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
"""streamed completion response."""
delta: str
stop_reason: Optional[StopReason] = None
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchCompletionRequest(BaseModel):
model: str
content_batch: List[InterleavedTextMedia]
sampling_params: Optional[SamplingParams] = SamplingParams()
response_format: Optional[ResponseFormat] = None
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""
batch: List[CompletionResponse]
@json_schema_type
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
"""SSE-stream of these events."""
event: ChatCompletionResponseEvent
@json_schema_type
class ChatCompletionResponse(BaseModel):
"""Chat completion response."""
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]] = None
@json_schema_type
class BatchChatCompletionRequest(BaseModel):
model: str
messages_batch: List[List[Message]]
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@json_schema_type
class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
@runtime_checkable
@trace_protocol
class Inference(Protocol):
model_store: ModelStore
@webmethod(route="/inference/completion")
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat-completion")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]: ...
@webmethod(route="/inference/embeddings")
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...