mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
Merge branch 'refs/heads/main' into chroma
This commit is contained in:
commit
80dc2a6a78
45 changed files with 2288 additions and 291 deletions
|
@ -10,6 +10,16 @@
|
|||
# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)'
|
||||
|
||||
|
||||
class ResourceNotFoundError(ValueError):
|
||||
"""generic exception for a missing Llama Stack resource"""
|
||||
|
||||
def __init__(self, resource_name: str, resource_type: str, client_list: str) -> None:
|
||||
message = (
|
||||
f"{resource_type} '{resource_name}' not found. Use '{client_list}' to list available {resource_type}s."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UnsupportedModelError(ValueError):
|
||||
"""raised when model is not present in the list of supported models"""
|
||||
|
||||
|
@ -18,38 +28,32 @@ class UnsupportedModelError(ValueError):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelNotFoundError(ValueError):
|
||||
class ModelNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced model"""
|
||||
|
||||
def __init__(self, model_name: str) -> None:
|
||||
message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
|
||||
super().__init__(message)
|
||||
super().__init__(model_name, "Model", "client.models.list()")
|
||||
|
||||
|
||||
class VectorStoreNotFoundError(ValueError):
|
||||
class VectorStoreNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced vector store"""
|
||||
|
||||
def __init__(self, vector_store_name: str) -> None:
|
||||
message = f"Vector store '{vector_store_name}' not found. Use client.vector_dbs.list() to list available vector stores."
|
||||
super().__init__(message)
|
||||
super().__init__(vector_store_name, "Vector Store", "client.vector_dbs.list()")
|
||||
|
||||
|
||||
class DatasetNotFoundError(ValueError):
|
||||
class DatasetNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced dataset"""
|
||||
|
||||
def __init__(self, dataset_name: str) -> None:
|
||||
message = f"Dataset '{dataset_name}' not found. Use client.datasets.list() to list available datasets."
|
||||
super().__init__(message)
|
||||
super().__init__(dataset_name, "Dataset", "client.datasets.list()")
|
||||
|
||||
|
||||
class ToolGroupNotFoundError(ValueError):
|
||||
class ToolGroupNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced tool group"""
|
||||
|
||||
def __init__(self, toolgroup_name: str) -> None:
|
||||
message = (
|
||||
f"Tool group '{toolgroup_name}' not found. Use client.toolgroups.list() to list available tool groups."
|
||||
)
|
||||
super().__init__(message)
|
||||
super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")
|
||||
|
||||
|
||||
class SessionNotFoundError(ValueError):
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -15,6 +15,71 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
# OpenAI Categories to return in the response
|
||||
class OpenAICategories(StrEnum):
|
||||
"""
|
||||
Required set of categories in moderations api response
|
||||
"""
|
||||
|
||||
VIOLENCE = "violence"
|
||||
VIOLENCE_GRAPHIC = "violence/graphic"
|
||||
HARRASMENT = "harassment"
|
||||
HARRASMENT_THREATENING = "harassment/threatening"
|
||||
HATE = "hate"
|
||||
HATE_THREATENING = "hate/threatening"
|
||||
ILLICIT = "illicit"
|
||||
ILLICIT_VIOLENT = "illicit/violent"
|
||||
SEXUAL = "sexual"
|
||||
SEXUAL_MINORS = "sexual/minors"
|
||||
SELF_HARM = "self-harm"
|
||||
SELF_HARM_INTENT = "self-harm/intent"
|
||||
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObjectResults(BaseModel):
|
||||
"""A moderation object.
|
||||
:param flagged: Whether any of the below categories are flagged.
|
||||
:param categories: A list of the categories, and whether they are flagged or not.
|
||||
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
|
||||
:param category_scores: A list of the categories along with their scores as predicted by model.
|
||||
Required set of categories that need to be in response
|
||||
- violence
|
||||
- violence/graphic
|
||||
- harassment
|
||||
- harassment/threatening
|
||||
- hate
|
||||
- hate/threatening
|
||||
- illicit
|
||||
- illicit/violent
|
||||
- sexual
|
||||
- sexual/minors
|
||||
- self-harm
|
||||
- self-harm/intent
|
||||
- self-harm/instructions
|
||||
"""
|
||||
|
||||
flagged: bool
|
||||
categories: dict[str, bool] | None = None
|
||||
category_applied_input_types: dict[str, list[str]] | None = None
|
||||
category_scores: dict[str, float] | None = None
|
||||
user_message: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObject(BaseModel):
|
||||
"""A moderation object.
|
||||
:param id: The unique identifier for the moderation request.
|
||||
:param model: The model used to generate the moderation results.
|
||||
:param results: A list of moderation objects
|
||||
"""
|
||||
|
||||
id: str
|
||||
model: str
|
||||
results: list[ModerationObjectResults]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ViolationLevel(Enum):
|
||||
"""Severity level of a safety violation.
|
||||
|
@ -82,3 +147,13 @@ class Safety(Protocol):
|
|||
:returns: A RunShieldResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST")
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
:param model: The content moderation model you would like to use.
|
||||
:returns: A moderation object.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
|
@ -25,14 +26,21 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIEmbeddingsResponse,
|
||||
|
@ -55,7 +63,6 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
|||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -119,6 +126,7 @@ class InferenceRouter(Inference):
|
|||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
|
@ -132,7 +140,7 @@ class InferenceRouter(Inference):
|
|||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
timestamp=datetime.now(UTC),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model.model_id,
|
||||
|
@ -234,49 +242,26 @@ class InferenceRouter(Inference):
|
|||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
response = await provider.chat_completion(**params)
|
||||
completion_tokens = await self._count_tokens(
|
||||
[response.completion_message],
|
||||
tool_config.tool_prompt_format,
|
||||
response_stream = await provider.chat_completion(**params)
|
||||
return self.stream_tokens_and_compute_metrics(
|
||||
response=response_stream,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
response = await provider.chat_completion(**params)
|
||||
metrics = await self.count_tokens_and_compute_metrics(
|
||||
response=response,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
|
@ -332,39 +317,20 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
|
||||
prompt_tokens = await self._count_tokens(content)
|
||||
|
||||
response = await provider.completion(**params)
|
||||
if stream:
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
response = await provider.completion(**params)
|
||||
completion_tokens = await self._count_tokens(response.content)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
return self.stream_tokens_and_compute_metrics(
|
||||
response=response,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
metrics = await self.count_tokens_and_compute_metrics(
|
||||
response=response, prompt_tokens=prompt_tokens, model=model
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
|
@ -457,9 +423,29 @@ class InferenceRouter(Inference):
|
|||
prompt_logprobs=prompt_logprobs,
|
||||
suffix=suffix,
|
||||
)
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_completion(**params)
|
||||
if stream:
|
||||
return await provider.openai_completion(**params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
# response_stream = await provider.openai_completion(**params)
|
||||
|
||||
response = await provider.openai_completion(**params)
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
@ -537,18 +523,38 @@ class InferenceRouter(Inference):
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if self.store:
|
||||
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
|
||||
return response_stream
|
||||
else:
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
return response
|
||||
|
||||
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
|
||||
# We need to add metrics to each chunk and store the final completion
|
||||
return self.stream_tokens_and_compute_metrics_openai_chat(
|
||||
response=response_stream,
|
||||
model=model_obj,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
@ -625,3 +631,244 @@ class InferenceRouter(Inference):
|
|||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
async def stream_tokens_and_compute_metrics(
|
||||
self,
|
||||
response,
|
||||
prompt_tokens,
|
||||
model,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
completion_text = ""
|
||||
async for chunk in response:
|
||||
complete = False
|
||||
if hasattr(chunk, "event"): # only ChatCompletions have .event
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
else:
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
# if we are done receiving tokens
|
||||
if complete:
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for streaming completion metrics
|
||||
if self.telemetry:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model=model,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in [
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
else:
|
||||
# Fallback if no telemetry
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
yield chunk
|
||||
|
||||
async def count_tokens_and_compute_metrics(
|
||||
self,
|
||||
response: ChatCompletionResponse | CompletionResponse,
|
||||
prompt_tokens,
|
||||
model,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
):
|
||||
if isinstance(response, ChatCompletionResponse):
|
||||
content = [response.completion_message]
|
||||
else:
|
||||
content = response.content
|
||||
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for completion metrics
|
||||
if self.telemetry:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model=model,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
|
||||
|
||||
# Fallback if no telemetry
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def stream_tokens_and_compute_metrics_openai_chat(
|
||||
self,
|
||||
response: AsyncIterator[OpenAIChatCompletionChunk],
|
||||
model: Model,
|
||||
messages: list[OpenAIMessageParam] | None = None,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
|
||||
id = None
|
||||
created = None
|
||||
choices_data: dict[int, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
async for chunk in response:
|
||||
# Skip None chunks
|
||||
if chunk is None:
|
||||
continue
|
||||
|
||||
# Capture ID and created timestamp from first chunk
|
||||
if id is None and chunk.id:
|
||||
id = chunk.id
|
||||
if created is None and chunk.created:
|
||||
created = chunk.created
|
||||
|
||||
# Accumulate choice data for final assembly
|
||||
if chunk.choices:
|
||||
for choice_delta in chunk.choices:
|
||||
idx = choice_delta.index
|
||||
if idx not in choices_data:
|
||||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
||||
if choice_delta.delta:
|
||||
delta = choice_delta.delta
|
||||
if delta.content:
|
||||
current_choice_data["content_parts"].append(delta.content)
|
||||
if delta.tool_calls:
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
tc_idx = tool_call_delta.index
|
||||
if tc_idx not in current_choice_data["tool_calls_builder"]:
|
||||
current_choice_data["tool_calls_builder"][tc_idx] = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function_name_parts": [],
|
||||
"function_arguments_parts": [],
|
||||
}
|
||||
builder = current_choice_data["tool_calls_builder"][tc_idx]
|
||||
if tool_call_delta.id:
|
||||
builder["id"] = tool_call_delta.id
|
||||
if tool_call_delta.type:
|
||||
builder["type"] = tool_call_delta.type
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
builder["function_name_parts"].append(tool_call_delta.function.name)
|
||||
if tool_call_delta.function.arguments:
|
||||
builder["function_arguments_parts"].append(
|
||||
tool_call_delta.function.arguments
|
||||
)
|
||||
if choice_delta.finish_reason:
|
||||
current_choice_data["finish_reason"] = choice_delta.finish_reason
|
||||
if choice_delta.logprobs and choice_delta.logprobs.content:
|
||||
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
|
||||
|
||||
# Compute metrics on final chunk
|
||||
if chunk.choices and chunk.choices[0].finish_reason:
|
||||
completion_text = ""
|
||||
for choice_data in choices_data.values():
|
||||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
model=model,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
yield chunk
|
||||
finally:
|
||||
# Store the final assembled completion
|
||||
if id and self.store and messages:
|
||||
assembled_choices: list[OpenAIChoice] = []
|
||||
for choice_idx, choice_data in choices_data.items():
|
||||
content_str = "".join(choice_data["content_parts"])
|
||||
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
|
||||
if choice_data["tool_calls_builder"]:
|
||||
for tc_build_data in choice_data["tool_calls_builder"].values():
|
||||
if tc_build_data["id"]:
|
||||
func_name = "".join(tc_build_data["function_name_parts"])
|
||||
func_args = "".join(tc_build_data["function_arguments_parts"])
|
||||
assembled_tool_calls.append(
|
||||
OpenAIChatCompletionToolCall(
|
||||
id=tc_build_data["id"],
|
||||
type=tc_build_data["type"],
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=func_name, arguments=func_args
|
||||
),
|
||||
)
|
||||
)
|
||||
message = OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content_str if content_str else None,
|
||||
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
|
||||
)
|
||||
logprobs_content = choice_data["logprobs_content_parts"]
|
||||
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||
|
||||
assembled_choices.append(
|
||||
OpenAIChoice(
|
||||
finish_reason=choice_data["finish_reason"],
|
||||
index=choice_idx,
|
||||
message=message,
|
||||
logprobs=final_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
final_response = OpenAIChatCompletion(
|
||||
id=id,
|
||||
choices=assembled_choices,
|
||||
created=created or int(time.time()),
|
||||
model=model.identifier,
|
||||
object="chat.completion",
|
||||
)
|
||||
await self.store.store_chat_completion(final_response, messages)
|
||||
|
|
|
@ -10,6 +10,7 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
@ -60,3 +61,41 @@ class SafetyRouter(Safety):
|
|||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
async def get_shield_id(self, model: str) -> str:
|
||||
"""Get Shield id from model (provider_resource_id) of shield."""
|
||||
list_shields_response = await self.routing_table.list_shields()
|
||||
|
||||
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
|
||||
if not matches:
|
||||
raise ValueError(f"No shield associated with provider_resource id {model}")
|
||||
if len(matches) > 1:
|
||||
raise ValueError(f"Multiple shields associated with provider_resource id {model}")
|
||||
return matches[0]
|
||||
|
||||
shield_id = await get_shield_id(self, model)
|
||||
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
|
||||
response = await provider.run_moderation(
|
||||
input=input,
|
||||
model=model,
|
||||
)
|
||||
self._validate_required_categories_exist(response)
|
||||
|
||||
return response
|
||||
|
||||
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
|
||||
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
|
||||
required_categories = list(map(str, OpenAICategories))
|
||||
|
||||
categories = response.results[0].categories
|
||||
category_applied_input_types = response.results[0].category_applied_input_types
|
||||
category_scores = response.results[0].category_scores
|
||||
|
||||
for i in [categories, category_applied_input_types, category_scores]:
|
||||
if not set(required_categories).issubset(set(i.keys())):
|
||||
raise ValueError(
|
||||
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
|
||||
)
|
||||
|
|
|
@ -123,7 +123,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
config=dict(
|
||||
service_name="${env.OTEL_SERVICE_NAME:=\u200b}",
|
||||
sinks="${env.TELEMETRY_SINKS:=console,otel_trace}",
|
||||
otel_trace_endpoint="${env.OTEL_TRACE_ENDPOINT:=http://localhost:4318/v1/traces}",
|
||||
otel_exporter_otlp_endpoint="${env.OTEL_EXPORTER_OTLP_ENDPOINT:=http://localhost:4318/v1/traces}",
|
||||
),
|
||||
)
|
||||
],
|
||||
|
|
|
@ -55,7 +55,7 @@ providers:
|
|||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,otel_trace}
|
||||
otel_trace_endpoint: ${env.OTEL_TRACE_ENDPOINT:=http://localhost:4318/v1/traces}
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=http://localhost:4318/v1/traces}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from string import Template
|
||||
from typing import Any
|
||||
|
||||
|
@ -20,6 +22,7 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
|
@ -67,6 +70,31 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
|
|||
CAT_ELECTIONS: "S13",
|
||||
CAT_CODE_INTERPRETER_ABUSE: "S14",
|
||||
}
|
||||
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
|
||||
|
||||
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
|
||||
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HATE: [CAT_HATE],
|
||||
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
|
||||
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
|
||||
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
|
||||
# These are custom categories that are not in the OpenAI moderation categories
|
||||
"custom/defamation": [CAT_DEFAMATION],
|
||||
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
|
||||
"custom/privacy_violation": [CAT_PRIVACY],
|
||||
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
|
||||
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
|
||||
"custom/elections": [CAT_ELECTIONS],
|
||||
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||
|
@ -194,6 +222,34 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
return await impl.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
if isinstance(input, list):
|
||||
messages = input.copy()
|
||||
else:
|
||||
messages = [input]
|
||||
|
||||
# convert to user messages format with role
|
||||
messages = [UserMessage(content=m) for m in messages]
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
if model in LLAMA_GUARD_MODEL_IDS:
|
||||
# Use the mapped model for categories but the original model_id for inference
|
||||
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
|
||||
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
|
||||
else:
|
||||
# For unknown models, use default Llama Guard 3 8B categories
|
||||
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
|
||||
impl = LlamaGuardShield(
|
||||
model=model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
safety_categories=safety_categories,
|
||||
)
|
||||
|
||||
return await impl.run_moderation(messages)
|
||||
|
||||
|
||||
class LlamaGuardShield:
|
||||
def __init__(
|
||||
|
@ -340,3 +396,117 @@ class LlamaGuardShield:
|
|||
)
|
||||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
|
||||
if not messages:
|
||||
return self.create_moderation_object(self.model)
|
||||
|
||||
# TODO: Add Image based support for OpenAI Moderations
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_moderation_object(content)
|
||||
|
||||
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
|
||||
"""Create a ModerationObject for either safe or unsafe content.
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object.
|
||||
|
||||
Returns:
|
||||
ModerationObject with appropriate configuration
|
||||
"""
|
||||
# Set default values for safe case
|
||||
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
|
||||
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
|
||||
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
|
||||
flagged = False
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
||||
# Handle unsafe case
|
||||
if unsafe_code:
|
||||
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
|
||||
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
|
||||
if invalid_codes:
|
||||
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
|
||||
# just returning safe object, as we don't know what the invalid codes can map to
|
||||
return ModerationObject(
|
||||
id=f"modr-{uuid.uuid4()}",
|
||||
model=model,
|
||||
results=[
|
||||
ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
category_scores=category_scores,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Get OpenAI categories for the unsafe codes
|
||||
openai_categories = []
|
||||
for code in unsafe_code_list:
|
||||
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
|
||||
openai_categories.extend(
|
||||
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
|
||||
)
|
||||
|
||||
# Update categories for unsafe content
|
||||
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
category_applied_input_types = {
|
||||
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
|
||||
}
|
||||
flagged = True
|
||||
user_message = CANNED_RESPONSE_TEXT
|
||||
metadata = {"violation_type": unsafe_code_list}
|
||||
|
||||
return ModerationObject(
|
||||
id=f"modr-{uuid.uuid4()}",
|
||||
model=model,
|
||||
results=[
|
||||
ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
category_scores=category_scores,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
|
||||
"""Check if content is safe based on response and unsafe code."""
|
||||
if response.strip() == SAFE_RESPONSE:
|
||||
return True
|
||||
|
||||
if unsafe_code:
|
||||
unsafe_code_list = unsafe_code.split(",")
|
||||
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_moderation_object(self, response: str) -> ModerationObject:
|
||||
response = response.strip()
|
||||
if self.is_content_safe(response):
|
||||
return self.create_moderation_object(self.model)
|
||||
unsafe_code = self.check_unsafe_response(response)
|
||||
if not unsafe_code:
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
if self.is_content_safe(response, unsafe_code):
|
||||
return self.create_moderation_object(self.model)
|
||||
else:
|
||||
return self.create_moderation_object(self.model, unsafe_code)
|
||||
|
|
|
@ -28,9 +28,6 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
|
@ -67,7 +64,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
logger.info(f"/r[dim]{key}[/dim]: {value}")
|
||||
logger.info(f"[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
|
|
|
@ -4,10 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
|
@ -110,7 +113,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True))
|
||||
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
|
@ -126,9 +129,11 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
trace.get_tracer_provider().force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
logger.debug(f"DEBUG: log_event called with event type: {type(event).__name__}")
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
logger.debug("DEBUG: Routing MetricEvent to _log_metric")
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
|
@ -188,6 +193,38 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Always log to console if console sink is enabled (debug)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
|
||||
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
with self._lock:
|
||||
# Only try to add to span if we have a valid span_id
|
||||
if event.span_id:
|
||||
try:
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=f"metric.{event.metric}",
|
||||
attributes={
|
||||
"value": event.value,
|
||||
"unit": event.unit,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
# Invalid span_id or span not found, but we already logged to console above
|
||||
pass
|
||||
except Exception:
|
||||
# Lock acquisition failed
|
||||
logger.debug("Failed to acquire lock to add metric to span")
|
||||
|
||||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
|
||||
|
|
|
@ -1,129 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
|
||||
|
||||
async def stream_and_store_openai_completion(
|
||||
provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
|
||||
model: str,
|
||||
store: InferenceStore,
|
||||
input_messages: list[OpenAIMessageParam],
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""
|
||||
Wraps a provider's stream, yields chunks, and stores the full completion at the end.
|
||||
"""
|
||||
id = None
|
||||
created = None
|
||||
choices_data: dict[int, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
async for chunk in provider_stream:
|
||||
if id is None and chunk.id:
|
||||
id = chunk.id
|
||||
if created is None and chunk.created:
|
||||
created = chunk.created
|
||||
|
||||
if chunk.choices:
|
||||
for choice_delta in chunk.choices:
|
||||
idx = choice_delta.index
|
||||
if idx not in choices_data:
|
||||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
||||
if choice_delta.delta:
|
||||
delta = choice_delta.delta
|
||||
if delta.content:
|
||||
current_choice_data["content_parts"].append(delta.content)
|
||||
if delta.tool_calls:
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
tc_idx = tool_call_delta.index
|
||||
if tc_idx not in current_choice_data["tool_calls_builder"]:
|
||||
# Initialize with correct structure for _ToolCallBuilderData
|
||||
current_choice_data["tool_calls_builder"][tc_idx] = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function_name_parts": [],
|
||||
"function_arguments_parts": [],
|
||||
}
|
||||
builder = current_choice_data["tool_calls_builder"][tc_idx]
|
||||
if tool_call_delta.id:
|
||||
builder["id"] = tool_call_delta.id
|
||||
if tool_call_delta.type:
|
||||
builder["type"] = tool_call_delta.type
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
builder["function_name_parts"].append(tool_call_delta.function.name)
|
||||
if tool_call_delta.function.arguments:
|
||||
builder["function_arguments_parts"].append(tool_call_delta.function.arguments)
|
||||
if choice_delta.finish_reason:
|
||||
current_choice_data["finish_reason"] = choice_delta.finish_reason
|
||||
if choice_delta.logprobs and choice_delta.logprobs.content:
|
||||
# Ensure that we are extending with the correct type
|
||||
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
|
||||
yield chunk
|
||||
finally:
|
||||
if id:
|
||||
assembled_choices: list[OpenAIChoice] = []
|
||||
for choice_idx, choice_data in choices_data.items():
|
||||
content_str = "".join(choice_data["content_parts"])
|
||||
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
|
||||
if choice_data["tool_calls_builder"]:
|
||||
for tc_build_data in choice_data["tool_calls_builder"].values():
|
||||
if tc_build_data["id"]:
|
||||
func_name = "".join(tc_build_data["function_name_parts"])
|
||||
func_args = "".join(tc_build_data["function_arguments_parts"])
|
||||
assembled_tool_calls.append(
|
||||
OpenAIChatCompletionToolCall(
|
||||
id=tc_build_data["id"],
|
||||
type=tc_build_data["type"], # No or "function" needed, already set
|
||||
function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args),
|
||||
)
|
||||
)
|
||||
message = OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content_str if content_str else None,
|
||||
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
|
||||
)
|
||||
logprobs_content = choice_data["logprobs_content_parts"]
|
||||
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||
|
||||
assembled_choices.append(
|
||||
OpenAIChoice(
|
||||
finish_reason=choice_data["finish_reason"],
|
||||
index=choice_idx,
|
||||
message=message,
|
||||
logprobs=final_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
final_response = OpenAIChatCompletion(
|
||||
id=id,
|
||||
choices=assembled_choices,
|
||||
created=created or int(datetime.now(UTC).timestamp()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
)
|
||||
await store.store_chat_completion(final_response, input_messages)
|
|
@ -81,7 +81,7 @@ BACKGROUND_LOGGER = None
|
|||
|
||||
|
||||
class BackgroundLogger:
|
||||
def __init__(self, api: Telemetry, capacity: int = 1000):
|
||||
def __init__(self, api: Telemetry, capacity: int = 100000):
|
||||
self.api = api
|
||||
self.log_queue = queue.Queue(maxsize=capacity)
|
||||
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^11.18.2",
|
||||
"llama-stack-client": ""0.2.17",
|
||||
"llama-stack-client": "^0.2.17",
|
||||
"lucide-react": "^0.510.0",
|
||||
"next": "15.3.3",
|
||||
"next-auth": "^4.24.11",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue