Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 20:14:13 +00:00)

feat: Updating files/content response to return additional fields

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Parent commit: e12524af85
This commit: a19c16428f

143 changed files with 6907 additions and 15104 deletions
@@ -10,6 +10,16 @@
# 3. All classes should propagate the message to the inherited __init__ via 'super().__init__(message)'


class ResourceNotFoundError(ValueError):
    """generic exception for a missing Llama Stack resource"""

    def __init__(self, resource_name: str, resource_type: str, client_list: str) -> None:
        message = (
            f"{resource_type} '{resource_name}' not found. Use '{client_list}' to list available {resource_type}s."
        )
        super().__init__(message)


class UnsupportedModelError(ValueError):
    """raised when model is not present in the list of supported models"""


@@ -18,38 +28,32 @@ class UnsupportedModelError(ValueError):
        super().__init__(message)


class ModelNotFoundError(ValueError):
class ModelNotFoundError(ResourceNotFoundError):
    """raised when Llama Stack cannot find a referenced model"""

    def __init__(self, model_name: str) -> None:
        message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
        super().__init__(message)
        super().__init__(model_name, "Model", "client.models.list()")


class VectorStoreNotFoundError(ValueError):
class VectorStoreNotFoundError(ResourceNotFoundError):
    """raised when Llama Stack cannot find a referenced vector store"""

    def __init__(self, vector_store_name: str) -> None:
        message = f"Vector store '{vector_store_name}' not found. Use client.vector_dbs.list() to list available vector stores."
        super().__init__(message)
        super().__init__(vector_store_name, "Vector Store", "client.vector_dbs.list()")


class DatasetNotFoundError(ValueError):
class DatasetNotFoundError(ResourceNotFoundError):
    """raised when Llama Stack cannot find a referenced dataset"""

    def __init__(self, dataset_name: str) -> None:
        message = f"Dataset '{dataset_name}' not found. Use client.datasets.list() to list available datasets."
        super().__init__(message)
        super().__init__(dataset_name, "Dataset", "client.datasets.list()")


class ToolGroupNotFoundError(ValueError):
class ToolGroupNotFoundError(ResourceNotFoundError):
    """raised when Llama Stack cannot find a referenced tool group"""

    def __init__(self, toolgroup_name: str) -> None:
        message = (
            f"Tool group '{toolgroup_name}' not found. Use client.toolgroups.list() to list available tool groups."
        )
        super().__init__(message)
        super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")


class SessionNotFoundError(ValueError):
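The hunks above collapse the per-resource error messages into a single `ResourceNotFoundError` base class. A minimal, self-contained sketch of the resulting pattern (it reproduces two classes from this diff so it runs standalone):

```python
class ResourceNotFoundError(ValueError):
    """Generic exception for a missing Llama Stack resource."""

    def __init__(self, resource_name: str, resource_type: str, client_list: str) -> None:
        message = (
            f"{resource_type} '{resource_name}' not found. "
            f"Use '{client_list}' to list available {resource_type}s."
        )
        super().__init__(message)


class ModelNotFoundError(ResourceNotFoundError):
    """Raised when Llama Stack cannot find a referenced model."""

    def __init__(self, model_name: str) -> None:
        super().__init__(model_name, "Model", "client.models.list()")


# The subclass now produces the same message the old class built by hand.
try:
    raise ModelNotFoundError("llama-3.1-8b")
except ResourceNotFoundError as err:
    print(err)  # Model 'llama-3.1-8b' not found. Use 'client.models.list()' to list available Models.
```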
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -15,6 +15,71 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
# OpenAI Categories to return in the response
|
||||
class OpenAICategories(StrEnum):
|
||||
"""
|
||||
Required set of categories in the moderations API response
|
||||
"""
|
||||
|
||||
VIOLENCE = "violence"
|
||||
VIOLENCE_GRAPHIC = "violence/graphic"
|
||||
HARRASMENT = "harassment"
|
||||
HARRASMENT_THREATENING = "harassment/threatening"
|
||||
HATE = "hate"
|
||||
HATE_THREATENING = "hate/threatening"
|
||||
ILLICIT = "illicit"
|
||||
ILLICIT_VIOLENT = "illicit/violent"
|
||||
SEXUAL = "sexual"
|
||||
SEXUAL_MINORS = "sexual/minors"
|
||||
SELF_HARM = "self-harm"
|
||||
SELF_HARM_INTENT = "self-harm/intent"
|
||||
SELF_HARM_INSTRUCTIONS = "self-harm/instructions"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObjectResults(BaseModel):
|
||||
"""A moderation object.
|
||||
:param flagged: Whether any of the below categories are flagged.
|
||||
:param categories: A list of the categories, and whether they are flagged or not.
|
||||
:param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to.
|
||||
:param category_scores: A list of the categories along with their scores as predicted by the model.
|
||||
Required set of categories that need to be in response
|
||||
- violence
|
||||
- violence/graphic
|
||||
- harassment
|
||||
- harassment/threatening
|
||||
- hate
|
||||
- hate/threatening
|
||||
- illicit
|
||||
- illicit/violent
|
||||
- sexual
|
||||
- sexual/minors
|
||||
- self-harm
|
||||
- self-harm/intent
|
||||
- self-harm/instructions
|
||||
"""
|
||||
|
||||
flagged: bool
|
||||
categories: dict[str, bool] | None = None
|
||||
category_applied_input_types: dict[str, list[str]] | None = None
|
||||
category_scores: dict[str, float] | None = None
|
||||
user_message: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModerationObject(BaseModel):
|
||||
"""A moderation object.
|
||||
:param id: The unique identifier for the moderation request.
|
||||
:param model: The model used to generate the moderation results.
|
||||
:param results: A list of moderation objects
|
||||
"""
|
||||
|
||||
id: str
|
||||
model: str
|
||||
results: list[ModerationObjectResults]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ViolationLevel(Enum):
|
||||
"""Severity level of a safety violation.
|
||||
|
@ -82,3 +147,13 @@ class Safety(Protocol):
|
|||
:returns: A RunShieldResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST")
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
|
||||
:param model: The content moderation model you would like to use.
|
||||
:returns: A moderation object.
|
||||
"""
|
||||
...
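Since `ModerationObject` and `ModerationObjectResults` are plain Pydantic models, the response shape defined above can be illustrated directly. A hedged sketch (the import path is the one used by the router changes later in this diff; the id and model name are made up):

```python
from llama_stack.apis.safety.safety import (
    ModerationObject,
    ModerationObjectResults,
    OpenAICategories,
)

# Build a "nothing flagged" result that covers every required OpenAI category.
categories = dict.fromkeys(OpenAICategories, False)
category_scores = dict.fromkeys(OpenAICategories, 0.0)
category_applied_input_types = {key: [] for key in OpenAICategories}

moderation = ModerationObject(
    id="modr-example",                 # illustrative id
    model="example-moderation-model",  # illustrative model name
    results=[
        ModerationObjectResults(
            flagged=False,
            categories=categories,
            category_applied_input_types=category_applied_input_types,
            category_scores=category_scores,
        )
    ],
)
print(moderation.results[0].flagged)  # False
```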
@ -226,10 +226,18 @@ class VectorStoreContent(BaseModel):
|
|||
|
||||
:param type: Content type, currently only "text" is supported
|
||||
:param text: The actual text content
|
||||
:param embedding: (Optional) Embedding vector for the content, if available
|
||||
:param created_timestamp: (Optional) Timestamp when the content was created
|
||||
:param metadata: (Optional) Metadata associated with the content, such as source, author, etc.
|
||||
:param chunk_metadata: (Optional) Metadata associated with the chunk, such as document ID, source, etc.
|
||||
"""
|
||||
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
embedding: list[float] | None = None
|
||||
created_timestamp: int | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
chunk_metadata: ChunkMetadata | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
|
@ -25,14 +26,21 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIChoiceLogprobs,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionWithInputMessages,
|
||||
OpenAIEmbeddingsResponse,
|
||||
|
@ -55,7 +63,6 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
|||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -119,6 +126,7 @@ class InferenceRouter(Inference):
|
|||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
|
@ -132,7 +140,7 @@ class InferenceRouter(Inference):
|
|||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=time.time(),
|
||||
timestamp=datetime.now(UTC),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": model.model_id,
|
||||
|
@ -234,49 +242,26 @@ class InferenceRouter(Inference):
|
|||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
||||
|
||||
if stream:
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.chat_completion(**params):
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
response = await provider.chat_completion(**params)
|
||||
completion_tokens = await self._count_tokens(
|
||||
[response.completion_message],
|
||||
tool_config.tool_prompt_format,
|
||||
response_stream = await provider.chat_completion(**params)
|
||||
return self.stream_tokens_and_compute_metrics(
|
||||
response=response_stream,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
response = await provider.chat_completion(**params)
|
||||
metrics = await self.count_tokens_and_compute_metrics(
|
||||
response=response,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
tool_prompt_format=tool_config.tool_prompt_format,
|
||||
)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
|
@ -332,39 +317,20 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
|
||||
prompt_tokens = await self._count_tokens(content)
|
||||
|
||||
response = await provider.completion(**params)
|
||||
if stream:
|
||||
|
||||
async def stream_generator():
|
||||
completion_text = ""
|
||||
async for chunk in await provider.completion(**params):
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
||||
yield chunk
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
response = await provider.completion(**params)
|
||||
completion_tokens = await self._count_tokens(response.content)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
metrics = await self._compute_and_log_token_usage(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
return self.stream_tokens_and_compute_metrics(
|
||||
response=response,
|
||||
prompt_tokens=prompt_tokens,
|
||||
model=model,
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
metrics = await self.count_tokens_and_compute_metrics(
|
||||
response=response, prompt_tokens=prompt_tokens, model=model
|
||||
)
|
||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
|
@ -457,9 +423,29 @@ class InferenceRouter(Inference):
|
|||
prompt_logprobs=prompt_logprobs,
|
||||
suffix=suffix,
|
||||
)
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_completion(**params)
|
||||
if stream:
|
||||
return await provider.openai_completion(**params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True because we do not
# return an AsyncIterator; our tests expect a stream of chunks that we cannot currently intercept.
|
||||
# response_stream = await provider.openai_completion(**params)
|
||||
|
||||
response = await provider.openai_completion(**params)
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
@ -537,18 +523,38 @@ class InferenceRouter(Inference):
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
response_stream = await provider.openai_chat_completion(**params)
|
||||
if self.store:
|
||||
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
|
||||
return response_stream
|
||||
else:
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
return response
|
||||
|
||||
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
|
||||
# We need to add metrics to each chunk and store the final completion
|
||||
return self.stream_tokens_and_compute_metrics_openai_chat(
|
||||
response=response_stream,
|
||||
model=model_obj,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
response = await self._nonstream_openai_chat_completion(provider, params)
|
||||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
@ -625,3 +631,244 @@ class InferenceRouter(Inference):
|
|||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
async def stream_tokens_and_compute_metrics(
|
||||
self,
|
||||
response,
|
||||
prompt_tokens,
|
||||
model,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
completion_text = ""
|
||||
async for chunk in response:
|
||||
complete = False
|
||||
if hasattr(chunk, "event"): # only ChatCompletions have .event
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
else:
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
# if we are done receiving tokens
|
||||
if complete:
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for streaming completion metrics
|
||||
if self.telemetry:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model=model,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in [
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
else:
|
||||
# Fallback if no telemetry
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
yield chunk
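The helper above generalizes the two inline `stream_generator` closures deleted earlier in this diff: yield every chunk unchanged, accumulate the text, and emit token metrics once the stream completes. A dependency-free sketch of that wrapping pattern (the names here are illustrative, not part of the router):

```python
import asyncio
from collections.abc import AsyncIterator


async def wrap_with_metrics(stream: AsyncIterator[str], count_tokens) -> AsyncIterator[str]:
    """Yield chunks unchanged, then report a token count once the stream is done."""
    completion_text = ""
    async for chunk in stream:
        completion_text += chunk
        yield chunk
    print(f"completion_tokens={count_tokens(completion_text)}")


async def main() -> None:
    async def fake_stream() -> AsyncIterator[str]:
        for piece in ("Hello", ", ", "world", "!"):
            yield piece

    async for piece in wrap_with_metrics(fake_stream(), lambda text: len(text.split())):
        print(piece, end="")
    print()


asyncio.run(main())
```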
|
||||
|
||||
async def count_tokens_and_compute_metrics(
|
||||
self,
|
||||
response: ChatCompletionResponse | CompletionResponse,
|
||||
prompt_tokens,
|
||||
model,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
):
|
||||
if isinstance(response, ChatCompletionResponse):
|
||||
content = [response.completion_message]
|
||||
else:
|
||||
content = response.content
|
||||
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for completion metrics
|
||||
if self.telemetry:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model=model,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
|
||||
|
||||
# Fallback if no telemetry
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
model,
|
||||
)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def stream_tokens_and_compute_metrics_openai_chat(
|
||||
self,
|
||||
response: AsyncIterator[OpenAIChatCompletionChunk],
|
||||
model: Model,
|
||||
messages: list[OpenAIMessageParam] | None = None,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
|
||||
id = None
|
||||
created = None
|
||||
choices_data: dict[int, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
async for chunk in response:
|
||||
# Skip None chunks
|
||||
if chunk is None:
|
||||
continue
|
||||
|
||||
# Capture ID and created timestamp from first chunk
|
||||
if id is None and chunk.id:
|
||||
id = chunk.id
|
||||
if created is None and chunk.created:
|
||||
created = chunk.created
|
||||
|
||||
# Accumulate choice data for final assembly
|
||||
if chunk.choices:
|
||||
for choice_delta in chunk.choices:
|
||||
idx = choice_delta.index
|
||||
if idx not in choices_data:
|
||||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
||||
if choice_delta.delta:
|
||||
delta = choice_delta.delta
|
||||
if delta.content:
|
||||
current_choice_data["content_parts"].append(delta.content)
|
||||
if delta.tool_calls:
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
tc_idx = tool_call_delta.index
|
||||
if tc_idx not in current_choice_data["tool_calls_builder"]:
|
||||
current_choice_data["tool_calls_builder"][tc_idx] = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function_name_parts": [],
|
||||
"function_arguments_parts": [],
|
||||
}
|
||||
builder = current_choice_data["tool_calls_builder"][tc_idx]
|
||||
if tool_call_delta.id:
|
||||
builder["id"] = tool_call_delta.id
|
||||
if tool_call_delta.type:
|
||||
builder["type"] = tool_call_delta.type
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
builder["function_name_parts"].append(tool_call_delta.function.name)
|
||||
if tool_call_delta.function.arguments:
|
||||
builder["function_arguments_parts"].append(
|
||||
tool_call_delta.function.arguments
|
||||
)
|
||||
if choice_delta.finish_reason:
|
||||
current_choice_data["finish_reason"] = choice_delta.finish_reason
|
||||
if choice_delta.logprobs and choice_delta.logprobs.content:
|
||||
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
|
||||
|
||||
# Compute metrics on final chunk
|
||||
if chunk.choices and chunk.choices[0].finish_reason:
|
||||
completion_text = ""
|
||||
for choice_data in choices_data.values():
|
||||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
model=model,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
|
||||
yield chunk
|
||||
finally:
|
||||
# Store the final assembled completion
|
||||
if id and self.store and messages:
|
||||
assembled_choices: list[OpenAIChoice] = []
|
||||
for choice_idx, choice_data in choices_data.items():
|
||||
content_str = "".join(choice_data["content_parts"])
|
||||
assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
|
||||
if choice_data["tool_calls_builder"]:
|
||||
for tc_build_data in choice_data["tool_calls_builder"].values():
|
||||
if tc_build_data["id"]:
|
||||
func_name = "".join(tc_build_data["function_name_parts"])
|
||||
func_args = "".join(tc_build_data["function_arguments_parts"])
|
||||
assembled_tool_calls.append(
|
||||
OpenAIChatCompletionToolCall(
|
||||
id=tc_build_data["id"],
|
||||
type=tc_build_data["type"],
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=func_name, arguments=func_args
|
||||
),
|
||||
)
|
||||
)
|
||||
message = OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=content_str if content_str else None,
|
||||
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
|
||||
)
|
||||
logprobs_content = choice_data["logprobs_content_parts"]
|
||||
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
|
||||
|
||||
assembled_choices.append(
|
||||
OpenAIChoice(
|
||||
finish_reason=choice_data["finish_reason"],
|
||||
index=choice_idx,
|
||||
message=message,
|
||||
logprobs=final_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
final_response = OpenAIChatCompletion(
|
||||
id=id,
|
||||
choices=assembled_choices,
|
||||
created=created or int(time.time()),
|
||||
model=model.identifier,
|
||||
object="chat.completion",
|
||||
)
|
||||
await self.store.store_chat_completion(final_response, messages)
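The `finally` block reassembles a complete `OpenAIChatCompletion` from the accumulated per-choice deltas before storing it. The core of that reassembly is just per-index string concatenation, for example:

```python
# Plain-data illustration of the per-choice reassembly performed above.
choices_data = {
    0: {"content_parts": ["The ", "capital ", "of ", "France ", "is ", "Paris."], "finish_reason": "stop"},
}
for idx, data in choices_data.items():
    print(idx, data["finish_reason"], "".join(data["content_parts"]))
# 0 stop The capital of France is Paris.
```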
@ -10,6 +10,7 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
)
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
@ -60,3 +61,41 @@ class SafetyRouter(Safety):
|
|||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
async def get_shield_id(self, model: str) -> str:
|
||||
"""Get Shield id from model (provider_resource_id) of shield."""
|
||||
list_shields_response = await self.routing_table.list_shields()
|
||||
|
||||
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
|
||||
if not matches:
|
||||
raise ValueError(f"No shield associated with provider_resource id {model}")
|
||||
if len(matches) > 1:
|
||||
raise ValueError(f"Multiple shields associated with provider_resource id {model}")
|
||||
return matches[0]
|
||||
|
||||
shield_id = await get_shield_id(self, model)
|
||||
logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
|
||||
response = await provider.run_moderation(
|
||||
input=input,
|
||||
model=model,
|
||||
)
|
||||
self._validate_required_categories_exist(response)
|
||||
|
||||
return response
|
||||
|
||||
def _validate_required_categories_exist(self, response: ModerationObject) -> None:
|
||||
"""Validate the ProviderImpl response contains the required Open AI moderations categories."""
|
||||
required_categories = list(map(str, OpenAICategories))
|
||||
|
||||
categories = response.results[0].categories
|
||||
category_applied_input_types = response.results[0].category_applied_input_types
|
||||
category_scores = response.results[0].category_scores
|
||||
|
||||
for i in [categories, category_applied_input_types, category_scores]:
|
||||
if not set(required_categories).issubset(set(i.keys())):
|
||||
raise ValueError(
|
||||
f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}"
|
||||
)
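The shield lookup above is a simple filter over the routing table's shields, matched on `provider_resource_id`; the required-category list comes straight from the `OpenAICategories` StrEnum via `list(map(str, OpenAICategories))`. A plain-data sketch of the lookup (shield names here are illustrative):

```python
shields = [
    {"identifier": "llama-guard-shield", "provider_resource_id": "example/llama-guard-model"},
    {"identifier": "code-shield", "provider_resource_id": "example/code-scanner"},
]
model = "example/llama-guard-model"

matches = [s["identifier"] for s in shields if s["provider_resource_id"] == model]
if not matches:
    raise ValueError(f"No shield associated with provider_resource id {model}")
if len(matches) > 1:
    raise ValueError(f"Multiple shields associated with provider_resource id {model}")
print(matches[0])  # llama-guard-shield
```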
@ -154,6 +154,7 @@ providers:
|
|||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -129,7 +129,7 @@ docker run \
|
|||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
llama stack build --distro nvidia --image-type venv
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
|
|
|
@ -123,7 +123,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
config=dict(
|
||||
service_name="${env.OTEL_SERVICE_NAME:=\u200b}",
|
||||
sinks="${env.TELEMETRY_SINKS:=console,otel_trace}",
|
||||
otel_trace_endpoint="${env.OTEL_TRACE_ENDPOINT:=http://localhost:4318/v1/traces}",
|
||||
otel_exporter_otlp_endpoint="${env.OTEL_EXPORTER_OTLP_ENDPOINT:=http://localhost:4318/v1/traces}",
|
||||
),
|
||||
)
|
||||
],
|
||||
|
|
|
@ -55,7 +55,7 @@ providers:
|
|||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,otel_trace}
|
||||
otel_trace_endpoint: ${env.OTEL_TRACE_ENDPOINT:=http://localhost:4318/v1/traces}
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=http://localhost:4318/v1/traces}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -154,6 +154,7 @@ providers:
|
|||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/starter/dpo_output
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -5,8 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
|
@ -44,6 +42,8 @@ class PandasDataframeDataset:
|
|||
if self.dataset_def.source.type == "uri":
|
||||
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
|
||||
elif self.dataset_def.source.type == "rows":
|
||||
import pandas
|
||||
|
||||
self.df = pandas.DataFrame(self.dataset_def.source.rows)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||
|
@ -103,6 +103,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
return paginate_records(records, start_index, limit)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
import pandas
|
||||
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
dataset_impl = PandasDataframeDataset(dataset_def)
|
||||
await dataset_impl.load()
|
||||
|
|
|
@ -71,8 +71,13 @@ class HuggingFacePostTrainingConfig(BaseModel):
|
|||
dpo_beta: float = 0.1
|
||||
use_reference_model: bool = True
|
||||
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
|
||||
dpo_output_dir: str = "./checkpoints/dpo"
|
||||
dpo_output_dir: str
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
|
||||
return {
|
||||
"checkpoint_format": "huggingface",
|
||||
"distributed_backend": None,
|
||||
"device": "cpu",
|
||||
"dpo_output_dir": __distro_dir__ + "/dpo_output",
|
||||
}
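With `dpo_output_dir` now a required field derived from the distribution directory, the sample config can be produced directly from the config class (import path as it appears elsewhere in this diff; the distro directory below is just an example):

```python
from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)

# dpo_output_dir is now derived from the distribution directory rather than
# defaulting to "./checkpoints/dpo".
cfg = HuggingFacePostTrainingConfig.sample_run_config("~/.llama/distributions/starter")
print(cfg["dpo_output_dir"])  # ~/.llama/distributions/starter/dpo_output
```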
@ -22,15 +22,8 @@ from llama_stack.apis.post_training import (
|
|||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
|
||||
HFDPOAlignmentSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
|
@ -85,6 +78,10 @@ class HuggingFacePostTrainingImpl:
|
|||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting HF finetuning")
|
||||
|
||||
recipe = HFFinetuningSingleDevice(
|
||||
|
@ -124,6 +121,10 @@ class HuggingFacePostTrainingImpl:
|
|||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
|
||||
HFDPOAlignmentSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting HF DPO alignment")
|
||||
|
||||
recipe = HFDPOAlignmentSingleDevice(
|
||||
|
@ -168,7 +169,6 @@ class HuggingFacePostTrainingImpl:
|
|||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
|
@ -195,16 +195,13 @@ class HuggingFacePostTrainingImpl:
|
|||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
|
|
|
@ -23,12 +23,8 @@ from llama_stack.apis.post_training import (
|
|||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
|
@ -84,6 +80,10 @@ class TorchtunePostTrainingImpl:
|
|||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
|
||||
on_log_message_cb("Starting Lora finetuning")
|
||||
|
||||
recipe = LoraFinetuningSingleDevice(
|
||||
|
@ -144,7 +144,6 @@ class TorchtunePostTrainingImpl:
|
|||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
|
@ -171,11 +170,9 @@ class TorchtunePostTrainingImpl:
|
|||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from string import Template
|
||||
from typing import Any
|
||||
|
||||
|
@ -20,6 +22,7 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
|
@ -67,6 +70,31 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
|
|||
CAT_ELECTIONS: "S13",
|
||||
CAT_CODE_INTERPRETER_ABUSE: "S14",
|
||||
}
|
||||
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
|
||||
|
||||
OPENAI_TO_LLAMA_CATEGORIES_MAP = {
|
||||
OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.HATE: [CAT_HATE],
|
||||
OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES],
|
||||
OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS],
|
||||
OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT],
|
||||
OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION],
|
||||
OpenAICategories.SELF_HARM: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM],
|
||||
OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE],
|
||||
# These are custom categories that are not in the OpenAI moderation categories
|
||||
"custom/defamation": [CAT_DEFAMATION],
|
||||
"custom/specialized_advice": [CAT_SPECIALIZED_ADVICE],
|
||||
"custom/privacy_violation": [CAT_PRIVACY],
|
||||
"custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY],
|
||||
"custom/weapons": [CAT_INDISCRIMINATE_WEAPONS],
|
||||
"custom/elections": [CAT_ELECTIONS],
|
||||
"custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE],
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
|
||||
|
@ -194,6 +222,34 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
return await impl.run(messages)
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
if isinstance(input, list):
|
||||
messages = input.copy()
|
||||
else:
|
||||
messages = [input]
|
||||
|
||||
# convert to user messages format with role
|
||||
messages = [UserMessage(content=m) for m in messages]
|
||||
|
||||
# Determine safety categories based on the model type
|
||||
# For known Llama Guard models, use specific categories
|
||||
if model in LLAMA_GUARD_MODEL_IDS:
|
||||
# Use the mapped model for categories but the original model_id for inference
|
||||
mapped_model = LLAMA_GUARD_MODEL_IDS[model]
|
||||
safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
|
||||
else:
|
||||
# For unknown models, use default Llama Guard 3 8B categories
|
||||
safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
|
||||
|
||||
impl = LlamaGuardShield(
|
||||
model=model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=self.config.excluded_categories,
|
||||
safety_categories=safety_categories,
|
||||
)
|
||||
|
||||
return await impl.run_moderation(messages)
|
||||
|
||||
|
||||
class LlamaGuardShield:
|
||||
def __init__(
|
||||
|
@ -340,3 +396,117 @@ class LlamaGuardShield:
|
|||
)
|
||||
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
async def run_moderation(self, messages: list[Message]) -> ModerationObject:
|
||||
if not messages:
|
||||
return self.create_moderation_object(self.model)
|
||||
|
||||
# TODO: Add Image based support for OpenAI Moderations
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
)
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_moderation_object(content)
|
||||
|
||||
def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject:
|
||||
"""Create a ModerationObject for either safe or unsafe content.
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
unsafe_code: Optional comma-separated list of safety codes. If None, a safe (unflagged) object is created.
|
||||
|
||||
Returns:
|
||||
ModerationObject with appropriate configuration
|
||||
"""
|
||||
# Set default values for safe case
|
||||
categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False)
|
||||
category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0)
|
||||
category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()}
|
||||
flagged = False
|
||||
user_message = None
|
||||
metadata = {}
|
||||
|
||||
# Handle unsafe case
|
||||
if unsafe_code:
|
||||
unsafe_code_list = [code.strip() for code in unsafe_code.split(",")]
|
||||
invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP]
|
||||
if invalid_codes:
|
||||
logging.warning(f"Invalid safety codes returned: {invalid_codes}")
|
||||
# return a safe object, since we don't know what the invalid codes map to
|
||||
return ModerationObject(
|
||||
id=f"modr-{uuid.uuid4()}",
|
||||
model=model,
|
||||
results=[
|
||||
ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
category_scores=category_scores,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Get OpenAI categories for the unsafe codes
|
||||
openai_categories = []
|
||||
for code in unsafe_code_list:
|
||||
llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code]
|
||||
openai_categories.extend(
|
||||
k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l
|
||||
)
|
||||
|
||||
# Update categories for unsafe content
|
||||
categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP}
|
||||
category_applied_input_types = {
|
||||
k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP
|
||||
}
|
||||
flagged = True
|
||||
user_message = CANNED_RESPONSE_TEXT
|
||||
metadata = {"violation_type": unsafe_code_list}
|
||||
|
||||
return ModerationObject(
|
||||
id=f"modr-{uuid.uuid4()}",
|
||||
model=model,
|
||||
results=[
|
||||
ModerationObjectResults(
|
||||
flagged=flagged,
|
||||
categories=categories,
|
||||
category_applied_input_types=category_applied_input_types,
|
||||
category_scores=category_scores,
|
||||
user_message=user_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
|
||||
"""Check if content is safe based on response and unsafe code."""
|
||||
if response.strip() == SAFE_RESPONSE:
|
||||
return True
|
||||
|
||||
if unsafe_code:
|
||||
unsafe_code_list = unsafe_code.split(",")
|
||||
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_moderation_object(self, response: str) -> ModerationObject:
|
||||
response = response.strip()
|
||||
if self.is_content_safe(response):
|
||||
return self.create_moderation_object(self.model)
|
||||
unsafe_code = self.check_unsafe_response(response)
|
||||
if not unsafe_code:
|
||||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
if self.is_content_safe(response, unsafe_code):
|
||||
return self.create_moderation_object(self.model)
|
||||
else:
|
||||
return self.create_moderation_object(self.model, unsafe_code)
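The translation from Llama Guard's `S*` codes into flagged OpenAI categories is a reverse lookup through the two module-level maps shown earlier in this file. A plain-data sketch of that lookup (the category strings are illustrative stand-ins for the `CAT_*` constants):

```python
safety_code_to_category = {"S1": "Violent Crimes"}
openai_to_llama = {
    "violence": ["Violent Crimes"],
    "illicit/violent": ["Violent Crimes", "Indiscriminate Weapons"],
    "hate": ["Hate"],
}

unsafe_code_list = ["S1"]
flagged_categories: set[str] = set()
for code in unsafe_code_list:
    llama_guard_category = safety_code_to_category[code]
    flagged_categories.update(
        openai_cat for openai_cat, llama_cats in openai_to_llama.items() if llama_guard_category in llama_cats
    )

print(sorted(flagged_categories))  # ['illicit/violent', 'violence']
```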
@ -28,9 +28,6 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
|
@ -67,7 +64,7 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
logger.info(f"/r[dim]{key}[/dim]: {value}")
|
||||
logger.info(f"[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
|
|
|
@ -4,10 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
|
@ -110,7 +113,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path))
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True))
|
||||
|
||||
if TelemetrySink.OTEL_METRIC in self.config.sinks:
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
|
@ -126,9 +129,11 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
trace.get_tracer_provider().force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
logger.debug(f"DEBUG: log_event called with event type: {type(event).__name__}")
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
logger.debug("DEBUG: Routing MetricEvent to _log_metric")
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
|
@ -188,6 +193,38 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Always log to console if console sink is enabled (debug)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
|
||||
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
with self._lock:
|
||||
# Only try to add to span if we have a valid span_id
|
||||
if event.span_id:
|
||||
try:
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=f"metric.{event.metric}",
|
||||
attributes={
|
||||
"value": event.value,
|
||||
"unit": event.unit,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
# Invalid span_id or span not found, but we already logged to console above
|
||||
pass
|
||||
except Exception:
|
||||
# Lock acquisition failed
|
||||
logger.debug("Failed to acquire lock to add metric to span")
|
||||
|
||||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
if isinstance(event.value, int):
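The span-event path in `_log_metric` relies only on public OpenTelemetry APIs. A standalone sketch of recording a metric as a span event (tracer setup and attribute values are illustrative):

```python
import time

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("chat_completion") as span:
    # Mirrors the span.add_event(...) call above: value/unit plus metric attributes.
    span.add_event(
        name="metric.completion_tokens",
        attributes={"value": 42, "unit": "tokens", "model_id": "example-model"},
        timestamp=time.time_ns(),
    )
```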
@ -6,8 +6,6 @@
|
|||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import datasets as hf_datasets
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
|
@ -73,6 +71,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
import datasets as hf_datasets
|
||||
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
@ -81,6 +81,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
return paginate_records(records, start_index, limit)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
import datasets as hf_datasets
|
||||
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
path, params = parse_hf_params(dataset_def)
|
||||
loaded_dataset = hf_datasets.load_dataset(path, **params)
|
||||
|
|
|
@ -13,7 +13,9 @@ LLM_MODEL_IDS = [
|
|||
"gemini-1.5-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
]
|
||||
|
||||
|
|
|
@ -42,8 +42,8 @@ client.initialize()
|
|||
### Create Completion
|
||||
|
||||
```python
|
||||
response = client.completion(
|
||||
model_id="meta-llama/Llama-3.1-8b-Instruct",
|
||||
response = client.inference.completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
content="Complete the sentence using one word: Roses are red, violets are :",
|
||||
stream=False,
|
||||
sampling_params={
|
||||
|
@ -56,8 +56,8 @@ print(f"Response: {response.content}")
|
|||
### Create Chat Completion
|
||||
|
||||
```python
|
||||
response = client.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8b-Instruct",
|
||||
response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
|
@ -78,8 +78,10 @@ print(f"Response: {response.completion_message.content}")
|
|||
|
||||
### Create Embeddings
|
||||
```python
|
||||
response = client.embeddings(
|
||||
model_id="meta-llama/Llama-3.1-8b-Instruct", contents=["foo", "bar", "baz"]
|
||||
response = client.inference.embeddings(
|
||||
model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
contents=["What is the capital of France?"],
|
||||
task_type="query",
|
||||
)
|
||||
print(f"Embeddings: {response.embeddings}")
|
||||
```
|
||||
```
|
|
@ -112,7 +112,8 @@ class OllamaInferenceAdapter(
|
|||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
if self._openai_client is None:
|
||||
self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
|
||||
url = self.config.url.rstrip("/")
|
||||
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
|
||||
return self._openai_client
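The added `rstrip("/")` guards against a configured Ollama URL with a trailing slash, which would otherwise produce a doubled slash in the OpenAI-compatible base URL:

```python
# Before: "http://localhost:11434/" + "/v1" -> "http://localhost:11434//v1"
url = "http://localhost:11434/".rstrip("/")
print(f"{url}/v1")  # http://localhost:11434/v1
```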
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -10,7 +10,7 @@ import os
|
|||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
||||
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.files.files import Files
|
||||
|
@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
@@ -238,7 +239,53 @@ class MilvusIndex(EmbeddingIndex):
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        raise NotImplementedError("Hybrid search is not supported in Milvus")
        """
        Hybrid search using Milvus's native hybrid search capabilities.

        This implementation uses Milvus's hybrid_search method which combines
        vector search and BM25 search with configurable reranking strategies.
        """
        search_requests = []

        # nprobe: Controls search accuracy vs performance trade-off
        # 10 balances these trade-offs for RAG applications
        search_requests.append(
            AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k)
        )

        # drop_ratio_search: Filters low-importance terms to improve search performance
        # 0.2 balances noise reduction with recall
        search_requests.append(
            AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k)
        )

        if reranker_type == RERANKER_TYPE_WEIGHTED:
            alpha = (reranker_params or {}).get("alpha", 0.5)
            rerank = WeightedRanker(alpha, 1 - alpha)
        else:
            impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
            rerank = RRFRanker(impact_factor)

        search_res = await asyncio.to_thread(
            self.client.hybrid_search,
            collection_name=self.collection_name,
            reqs=search_requests,
            ranker=rerank,
            limit=k,
            output_fields=["chunk_content"],
        )

        chunks = []
        scores = []
        for res in search_res[0]:
            chunk = Chunk(**res["entity"]["chunk_content"])
            chunks.append(chunk)
            scores.append(res["distance"])

        filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
        filtered_scores = [score for score in scores if score >= score_threshold]

        return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)

    async def delete_chunk(self, chunk_id: str) -> None:
        """Remove a chunk from the Milvus collection."""
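As an aside, here is a minimal sketch (not part of the diff) of how the two reranker branches above map reranker_params onto pymilvus ranker objects. It assumes pymilvus is installed, and the literal string "weighted" stands in for the RERANKER_TYPE_WEIGHTED constant; RRFRanker's k is cast to int since the library expects an integer impact factor.

from typing import Any

from pymilvus import RRFRanker, WeightedRanker


def build_milvus_ranker(reranker_type: str, reranker_params: dict[str, Any] | None = None):
    # Mirrors the branch in the hunk above: "weighted" blends the dense and BM25
    # scores via alpha / (1 - alpha); anything else falls back to reciprocal rank fusion.
    if reranker_type == "weighted":  # assumed value of RERANKER_TYPE_WEIGHTED
        alpha = (reranker_params or {}).get("alpha", 0.5)
        return WeightedRanker(alpha, 1 - alpha)
    impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
    return RRFRanker(int(impact_factor))


# Example: an RRF ranker with the default impact factor used above.
ranker = build_milvus_ranker("rrf", {"impact_factor": 60.0})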
@@ -9,12 +9,12 @@ import base64
import io
from urllib.parse import unquote

import pandas

from llama_stack.providers.utils.memory.vector_store import parse_data_url


async def get_dataframe_from_uri(uri: str):
    import pandas

    df = None
    if uri.endswith(".csv"):
        # Moving to its own thread to avoid io from blocking the eventloop
@@ -1,129 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any

from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIChoiceLogprobs,
    OpenAIMessageParam,
)
from llama_stack.providers.utils.inference.inference_store import InferenceStore


async def stream_and_store_openai_completion(
    provider_stream: AsyncIterator[OpenAIChatCompletionChunk],
    model: str,
    store: InferenceStore,
    input_messages: list[OpenAIMessageParam],
) -> AsyncIterator[OpenAIChatCompletionChunk]:
    """
    Wraps a provider's stream, yields chunks, and stores the full completion at the end.
    """
    id = None
    created = None
    choices_data: dict[int, dict[str, Any]] = {}

    try:
        async for chunk in provider_stream:
            if id is None and chunk.id:
                id = chunk.id
            if created is None and chunk.created:
                created = chunk.created

            if chunk.choices:
                for choice_delta in chunk.choices:
                    idx = choice_delta.index
                    if idx not in choices_data:
                        choices_data[idx] = {
                            "content_parts": [],
                            "tool_calls_builder": {},
                            "finish_reason": None,
                            "logprobs_content_parts": [],
                        }
                    current_choice_data = choices_data[idx]

                    if choice_delta.delta:
                        delta = choice_delta.delta
                        if delta.content:
                            current_choice_data["content_parts"].append(delta.content)
                        if delta.tool_calls:
                            for tool_call_delta in delta.tool_calls:
                                tc_idx = tool_call_delta.index
                                if tc_idx not in current_choice_data["tool_calls_builder"]:
                                    # Initialize with correct structure for _ToolCallBuilderData
                                    current_choice_data["tool_calls_builder"][tc_idx] = {
                                        "id": None,
                                        "type": "function",
                                        "function_name_parts": [],
                                        "function_arguments_parts": [],
                                    }
                                builder = current_choice_data["tool_calls_builder"][tc_idx]
                                if tool_call_delta.id:
                                    builder["id"] = tool_call_delta.id
                                if tool_call_delta.type:
                                    builder["type"] = tool_call_delta.type
                                if tool_call_delta.function:
                                    if tool_call_delta.function.name:
                                        builder["function_name_parts"].append(tool_call_delta.function.name)
                                    if tool_call_delta.function.arguments:
                                        builder["function_arguments_parts"].append(tool_call_delta.function.arguments)
                    if choice_delta.finish_reason:
                        current_choice_data["finish_reason"] = choice_delta.finish_reason
                    if choice_delta.logprobs and choice_delta.logprobs.content:
                        # Ensure that we are extending with the correct type
                        current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
            yield chunk
    finally:
        if id:
            assembled_choices: list[OpenAIChoice] = []
            for choice_idx, choice_data in choices_data.items():
                content_str = "".join(choice_data["content_parts"])
                assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
                if choice_data["tool_calls_builder"]:
                    for tc_build_data in choice_data["tool_calls_builder"].values():
                        if tc_build_data["id"]:
                            func_name = "".join(tc_build_data["function_name_parts"])
                            func_args = "".join(tc_build_data["function_arguments_parts"])
                            assembled_tool_calls.append(
                                OpenAIChatCompletionToolCall(
                                    id=tc_build_data["id"],
                                    type=tc_build_data["type"],  # No or "function" needed, already set
                                    function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args),
                                )
                            )
                message = OpenAIAssistantMessageParam(
                    role="assistant",
                    content=content_str if content_str else None,
                    tool_calls=assembled_tool_calls if assembled_tool_calls else None,
                )
                logprobs_content = choice_data["logprobs_content_parts"]
                final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None

                assembled_choices.append(
                    OpenAIChoice(
                        finish_reason=choice_data["finish_reason"],
                        index=choice_idx,
                        message=message,
                        logprobs=final_logprobs,
                    )
                )

            final_response = OpenAIChatCompletion(
                id=id,
                choices=assembled_choices,
                created=created or int(datetime.now(UTC).timestamp()),
                model=model,
                object="chat.completion",
            )
            await store.store_chat_completion(final_response, input_messages)
@@ -18,6 +18,7 @@ from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    Chunk,
    ChunkMetadata,
    QueryChunksResponse,
    SearchRankingOptions,
    VectorStoreChunkingStrategy,
@@ -516,31 +517,68 @@ class OpenAIVectorStoreMixin(ABC):
            raise ValueError(f"Unsupported filter type: {filter_type}")

    def _chunk_to_vector_store_content(self, chunk: Chunk) -> list[VectorStoreContent]:
        created_ts = None
        if chunk.chunk_metadata is not None:
            created_ts = getattr(chunk.chunk_metadata, "created_timestamp", None)

        metadata_dict = {}
        if chunk.chunk_metadata:
            if hasattr(chunk.chunk_metadata, "model_dump"):
                metadata_dict = chunk.chunk_metadata.model_dump()
            else:
                metadata_dict = vars(chunk.chunk_metadata)

        user_metadata = chunk.metadata or {}
        base_meta = {**metadata_dict, **user_metadata}

        # content is InterleavedContent
        if isinstance(chunk.content, str):
            content = [
                VectorStoreContent(
                    type="text",
                    text=chunk.content,
                    embedding=chunk.embedding,
                    created_timestamp=created_ts,
                    metadata=user_metadata,
                    chunk_metadata=ChunkMetadata(**base_meta) if base_meta else None,
                )
            ]
        elif isinstance(chunk.content, list):
            # TODO: Add support for other types of content
            content = [
                VectorStoreContent(
                    type="text",
                    text=item.text,
                )
                for item in chunk.content
                if item.type == "text"
            ]
            content = []
            for item in chunk.content:
                if hasattr(item, "type") and item.type == "text":
                    item_meta = {**base_meta}
                    item_user_meta = getattr(item, "metadata", {}) or {}
                    if item_user_meta:
                        item_meta.update(item_user_meta)

                    content.append(
                        VectorStoreContent(
                            type="text",
                            text=item.text,
                            embedding=getattr(item, "embedding", None),
                            created_timestamp=created_ts,
                            metadata=item_user_meta,
                            chunk_metadata=ChunkMetadata(**item_meta) if item_meta else None,
                        )
                    )
        else:
            if chunk.content.type != "text":
                raise ValueError(f"Unsupported content type: {chunk.content.type}")
            content_item = chunk.content
            if content_item.type != "text":
                raise ValueError(f"Unsupported content type: {content_item.type}")

            item_user_meta = getattr(content_item, "metadata", {}) or {}
            combined_meta = {**base_meta, **item_user_meta}

            content = [
                VectorStoreContent(
                    type="text",
                    text=chunk.content.text,
                    text=content_item.text,
                    embedding=getattr(content_item, "embedding", None),
                    created_timestamp=created_ts,
                    metadata=item_user_meta,
                    chunk_metadata=ChunkMetadata(**combined_meta) if combined_meta else None,
                )
            ]
        return content
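For clarity, a small illustrative sketch (not part of the diff) of the metadata precedence the new mapping above relies on: fields taken from chunk_metadata are merged first and user-supplied chunk metadata overrides them on conflicts, which is exactly what base_meta = {**metadata_dict, **user_metadata} encodes. The dictionary values below are made-up examples.

def merge_chunk_metadata(chunk_metadata_fields: dict, user_metadata: dict) -> dict:
    # Later keys win: user-provided metadata overrides stored chunk_metadata fields.
    return {**chunk_metadata_fields, **user_metadata}


# Example: the user's "source" value takes precedence over the stored one.
merged = merge_chunk_metadata(
    {"chunk_id": "abc123", "source": "ingest-pipeline", "created_timestamp": 1700000000},
    {"source": "user-upload"},
)
assert merged["source"] == "user-upload"
assert merged["chunk_id"] == "abc123"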
@@ -302,23 +302,25 @@ class VectorDBWithIndex:
        mode = params.get("mode")
        score_threshold = params.get("score_threshold", 0.0)

        # Get ranker configuration
        ranker = params.get("ranker")
        if ranker is None:
            # Default to RRF with impact_factor=60.0
            reranker_type = RERANKER_TYPE_RRF
            reranker_params = {"impact_factor": 60.0}
        else:
            reranker_type = ranker.type
            reranker_params = (
                {"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
            )
            strategy = ranker.get("strategy", "rrf")
            if strategy == "weighted":
                weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
                reranker_type = RERANKER_TYPE_WEIGHTED
                reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
            else:
                reranker_type = RERANKER_TYPE_RRF
                k_value = ranker.get("params", {}).get("k", 60.0)
                reranker_params = {"impact_factor": k_value}

        query_string = interleaved_content_as_str(query)
        if mode == "keyword":
            return await self.index.query_keyword(query_string, k, score_threshold)

        # Calculate embeddings for both vector and hybrid modes
        embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
        if mode == "hybrid":
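To make the accepted shape concrete, here is a hedged, standalone sketch (not part of the diff) of the ranker dictionary the new parsing above expects and how it resolves to a reranker type and parameters. The literal strings "rrf" and "weighted" stand in for the RERANKER_TYPE_* constants; the example params dict is hypothetical.

def resolve_reranker(ranker: dict | None) -> tuple[str, dict]:
    # Mirrors the parsing above: default to RRF with impact_factor=60.0.
    if ranker is None:
        return "rrf", {"impact_factor": 60.0}
    strategy = ranker.get("strategy", "rrf")
    if strategy == "weighted":
        weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
        return "weighted", {"alpha": weights[0] if len(weights) > 0 else 0.5}
    return "rrf", {"impact_factor": ranker.get("params", {}).get("k", 60.0)}


# Example query params a caller might pass for hybrid search:
params = {
    "mode": "hybrid",
    "score_threshold": 0.0,
    "ranker": {"strategy": "weighted", "params": {"weights": [0.7, 0.3]}},
}
assert resolve_reranker(params["ranker"]) == ("weighted", {"alpha": 0.7})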
@@ -9,7 +9,9 @@ import contextvars
import logging
import queue
import random
import sys
import threading
import time
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps

@@ -30,6 +32,16 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value

logger = get_logger(__name__, category="core")

# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
if not _fallback_logger.handlers:
    _fallback_logger.propagate = False
    _fallback_logger.setLevel(logging.ERROR)
    _fallback_handler = logging.StreamHandler(sys.stderr)
    _fallback_handler.setLevel(logging.ERROR)
    _fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
    _fallback_logger.addHandler(_fallback_handler)


INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
@@ -79,19 +91,32 @@ def generate_trace_id() -> str:
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
BACKGROUND_LOGGER = None

LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0


class BackgroundLogger:
    def __init__(self, api: Telemetry, capacity: int = 1000):
    def __init__(self, api: Telemetry, capacity: int = 100000):
        self.api = api
        self.log_queue = queue.Queue(maxsize=capacity)
        self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
        self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
        self.worker_thread.start()
        self._last_queue_full_log_time: float = 0.0
        self._dropped_since_last_notice: int = 0

    def log_event(self, event):
        try:
            self.log_queue.put_nowait(event)
        except queue.Full:
            logger.error("Log queue is full, dropping event")
            # Aggregate drops and emit at most once per interval via fallback logger
            self._dropped_since_last_notice += 1
            current_time = time.time()
            if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
                _fallback_logger.error(
                    "Log queue is full; dropped %d events since last notice",
                    self._dropped_since_last_notice,
                )
                self._last_queue_full_log_time = current_time
                self._dropped_since_last_notice = 0

    def _process_logs(self):
        while True:
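The queue-full handling above replaces a per-event error log with an aggregate notice emitted at most once per interval. A standalone, illustrative sketch of that throttling pattern (not part of the diff), using only the standard library:

import logging
import time


class DropCounter:
    """Counts dropped events and logs a summary at most once per interval."""

    def __init__(self, interval_seconds: float = 60.0):
        self.interval = interval_seconds
        self.dropped = 0
        self.last_notice = 0.0

    def record_drop(self, logger: logging.Logger) -> None:
        self.dropped += 1
        now = time.time()
        if now - self.last_notice >= self.interval:
            logger.error("Queue is full; dropped %d events since last notice", self.dropped)
            self.last_notice = now
            self.dropped = 0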
@@ -0,0 +1,383 @@
"use client";

import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import { ContentsAPI, VectorStoreContentItem } from "@/lib/contents-api";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Edit, Save, X, Trash2 } from "lucide-react";
import {
  DetailLoadingView,
  DetailErrorView,
  DetailNotFoundView,
  DetailLayout,
  PropertiesCard,
  PropertyItem,
} from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";

export default function ContentDetailPage() {
  const params = useParams();
  const router = useRouter();
  const vectorStoreId = params.id as string;
  const fileId = params.fileId as string;
  const contentId = params.contentId as string;
  const client = useAuthClient();

  const getTextFromContent = (content: any): string => {
    if (typeof content === 'string') {
      return content;
    } else if (content && content.type === 'text') {
      return content.text;
    }
    return '';
  };

  const [store, setStore] = useState<VectorStore | null>(null);
  const [file, setFile] = useState<VectorStoreFile | null>(null);
  const [content, setContent] = useState<VectorStoreContentItem | null>(null);
  const [isLoading, setIsLoading] = useState(true);
  const [error, setError] = useState<Error | null>(null);
  const [isEditing, setIsEditing] = useState(false);
  const [editedContent, setEditedContent] = useState("");
  const [editedMetadata, setEditedMetadata] = useState<Record<string, any>>({});
  const [isEditingEmbedding, setIsEditingEmbedding] = useState(false);
  const [editedEmbedding, setEditedEmbedding] = useState<number[]>([]);

  useEffect(() => {
    if (!vectorStoreId || !fileId || !contentId) return;

    const fetchData = async () => {
      setIsLoading(true);
      setError(null);
      try {
        const [storeResponse, fileResponse] = await Promise.all([
          client.vectorStores.retrieve(vectorStoreId),
          client.vectorStores.files.retrieve(vectorStoreId, fileId),
        ]);

        setStore(storeResponse as VectorStore);
        setFile(fileResponse as VectorStoreFile);

        const contentsAPI = new ContentsAPI(client);
        const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId);
        const targetContent = contentsResponse.data.find(c => c.id === contentId);

        if (targetContent) {
          setContent(targetContent);
          setEditedContent(getTextFromContent(targetContent.content));
          setEditedMetadata({ ...targetContent.metadata });
          setEditedEmbedding(targetContent.embedding || []);
        } else {
          throw new Error(`Content ${contentId} not found`);
        }
      } catch (err) {
        setError(err instanceof Error ? err : new Error("Failed to load content."));
      } finally {
        setIsLoading(false);
      }
    };
    fetchData();
  }, [vectorStoreId, fileId, contentId, client]);

  const handleSave = async () => {
    if (!content) return;

    try {
      const updates: { content?: string; metadata?: Record<string, any> } = {};

      if (editedContent !== getTextFromContent(content.content)) {
        updates.content = editedContent;
      }

      if (JSON.stringify(editedMetadata) !== JSON.stringify(content.metadata)) {
        updates.metadata = editedMetadata;
      }

      if (Object.keys(updates).length > 0) {
        const contentsAPI = new ContentsAPI(client);
        const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates);
        setContent(updatedContent);
      }

      setIsEditing(false);
    } catch (err) {
      console.error('Failed to update content:', err);
    }
  };

  const handleDelete = async () => {
    if (!confirm('Are you sure you want to delete this content?')) return;

    try {
      const contentsAPI = new ContentsAPI(client);
      await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
      router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`);
    } catch (err) {
      console.error('Failed to delete content:', err);
    }
  };

  const handleCancel = () => {
    setEditedContent(content ? getTextFromContent(content.content) : "");
    setEditedMetadata({ ...content?.metadata });
    setEditedEmbedding(content?.embedding || []);
    setIsEditing(false);
    setIsEditingEmbedding(false);
  };

  const title = `Content: ${contentId}`;

  const breadcrumbSegments: BreadcrumbSegment[] = [
    { label: "Vector Stores", href: "/logs/vector-stores" },
    { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
    { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
    { label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
    { label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` },
    { label: contentId },
  ];

  if (error) {
    return <DetailErrorView title={title} id={contentId} error={error} />;
  }
  if (isLoading) {
    return <DetailLoadingView title={title} />;
  }
  if (!content) {
    return <DetailNotFoundView title={title} id={contentId} />;
  }

  const mainContent = (
    <>
      <Card>
        <CardHeader className="flex flex-row items-center justify-between">
          <CardTitle>Content</CardTitle>
          <div className="flex gap-2">
            {isEditing ? (
              <>
                <Button size="sm" onClick={handleSave}>
                  <Save className="h-4 w-4 mr-1" />
                  Save
                </Button>
                <Button size="sm" variant="outline" onClick={handleCancel}>
                  <X className="h-4 w-4 mr-1" />
                  Cancel
                </Button>
              </>
            ) : (
              <>
                <Button size="sm" onClick={() => setIsEditing(true)}>
                  <Edit className="h-4 w-4 mr-1" />
                  Edit
                </Button>
                <Button size="sm" variant="destructive" onClick={handleDelete}>
                  <Trash2 className="h-4 w-4 mr-1" />
                  Delete
                </Button>
              </>
            )}
          </div>
        </CardHeader>
        <CardContent>
          {isEditing ? (
            <textarea
              value={editedContent}
              onChange={(e) => setEditedContent(e.target.value)}
              className="w-full h-64 p-3 border rounded-md resize-none font-mono text-sm"
              placeholder="Enter content..."
            />
          ) : (
            <div className="p-3 bg-gray-50 dark:bg-gray-800 rounded-md">
              <pre className="whitespace-pre-wrap font-mono text-sm text-gray-900 dark:text-gray-100">
                {getTextFromContent(content.content)}
              </pre>
            </div>
          )}
        </CardContent>
      </Card>

      <Card>
        <CardHeader className="flex flex-row items-center justify-between">
          <CardTitle>Content Embedding</CardTitle>
          <div className="flex gap-2">
            {isEditingEmbedding ? (
              <>
                <Button size="sm" onClick={() => {
                  setIsEditingEmbedding(false);
                }}>
                  <Save className="h-4 w-4 mr-1" />
                  Save
                </Button>
                <Button size="sm" variant="outline" onClick={() => {
                  setEditedEmbedding(content?.embedding || []);
                  setIsEditingEmbedding(false);
                }}>
                  <X className="h-4 w-4 mr-1" />
                  Cancel
                </Button>
              </>
            ) : (
              <Button size="sm" onClick={() => setIsEditingEmbedding(true)}>
                <Edit className="h-4 w-4 mr-1" />
                Edit
              </Button>
            )}
          </div>
        </CardHeader>
        <CardContent>
          {content?.embedding && content.embedding.length > 0 ? (
            isEditingEmbedding ? (
              <div className="space-y-2">
                <p className="text-sm text-gray-600 dark:text-gray-400">
                  Embedding ({editedEmbedding.length}D vector):
                </p>
                <textarea
                  value={JSON.stringify(editedEmbedding, null, 2)}
                  onChange={(e) => {
                    try {
                      const parsed = JSON.parse(e.target.value);
                      if (Array.isArray(parsed) && parsed.every(v => typeof v === 'number')) {
                        setEditedEmbedding(parsed);
                      }
                    } catch {
                    }
                  }}
                  className="w-full h-32 p-3 border rounded-md resize-none font-mono text-xs"
                  placeholder="Enter embedding as JSON array..."
                />
              </div>
            ) : (
              <div className="space-y-2">
                <div className="flex items-center gap-2">
                  <span className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-2 py-1">
                    {content.embedding.length}D vector
                  </span>
                </div>
                <div className="p-3 bg-gray-50 dark:bg-gray-800 rounded-md max-h-32 overflow-y-auto">
                  <pre className="whitespace-pre-wrap font-mono text-xs text-gray-900 dark:text-gray-100">
                    [{content.embedding.slice(0, 20).map(v => v.toFixed(6)).join(', ')}
                    {content.embedding.length > 20 ? `\n... and ${content.embedding.length - 20} more values` : ''}]
                  </pre>
                </div>
              </div>
            )
          ) : (
            <p className="text-gray-500 italic text-sm">
              No embedding available for this content.
            </p>
          )}
        </CardContent>
      </Card>

      <Card>
        <CardHeader>
          <CardTitle>Metadata</CardTitle>
        </CardHeader>
        <CardContent>
          {isEditing ? (
            <div className="space-y-2">
              {Object.entries(editedMetadata).map(([key, value]) => (
                <div key={key} className="flex gap-2">
                  <Input
                    value={key}
                    onChange={(e) => {
                      const newMetadata = { ...editedMetadata };
                      delete newMetadata[key];
                      newMetadata[e.target.value] = value;
                      setEditedMetadata(newMetadata);
                    }}
                    placeholder="Key"
                    className="flex-1"
                  />
                  <Input
                    value={typeof value === 'string' ? value : JSON.stringify(value)}
                    onChange={(e) => {
                      setEditedMetadata({
                        ...editedMetadata,
                        [key]: e.target.value
                      });
                    }}
                    placeholder="Value"
                    className="flex-1"
                  />
                </div>
              ))}
              <Button
                size="sm"
                variant="outline"
                onClick={() => {
                  setEditedMetadata({
                    ...editedMetadata,
                    ['']: ''
                  });
                }}
              >
                Add Field
              </Button>
            </div>
          ) : (
            <div className="space-y-2">
              {Object.entries(content.metadata).map(([key, value]) => (
                <div key={key} className="flex justify-between py-1">
                  <span className="font-medium text-gray-600">{key}:</span>
                  <span className="font-mono text-sm">
                    {typeof value === 'string' ? value : JSON.stringify(value)}
                  </span>
                </div>
              ))}
            </div>
          )}
        </CardContent>
      </Card>
    </>
  );

  const sidebar = (
    <PropertiesCard>
      <PropertyItem label="Content ID" value={contentId} />
      <PropertyItem label="File ID" value={fileId} />
      <PropertyItem label="Vector Store ID" value={vectorStoreId} />
      <PropertyItem label="Object Type" value={content.object} />
      <PropertyItem
        label="Created"
        value={new Date(content.created_timestamp * 1000).toLocaleString()}
      />
      <PropertyItem
        label="Content Length"
        value={`${getTextFromContent(content.content).length} chars`}
      />
      {content.metadata.chunk_window && (
        <PropertyItem
          label="Position"
          value={content.metadata.chunk_window}
        />
      )}
      {file && (
        <>
          <PropertyItem label="File Status" value={file.status} />
          <PropertyItem label="File Usage" value={`${file.usage_bytes} bytes`} />
        </>
      )}
      {store && (
        <>
          <PropertyItem label="Store Name" value={store.name || ""} />
          <PropertyItem
            label="Provider ID"
            value={(store.metadata.provider_id as string) || ""}
          />
        </>
      )}
    </PropertiesCard>
  );

  return (
    <>
      <PageBreadcrumb segments={breadcrumbSegments} />
      <DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
    </>
  );
}
@@ -0,0 +1,297 @@
"use client";

import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import { ContentsAPI, VectorStoreContentItem } from "@/lib/contents-api";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import { Button } from "@/components/ui/button";
import { Edit, Trash2, Eye } from "lucide-react";
import {
  DetailLoadingView,
  DetailErrorView,
  DetailNotFoundView,
  DetailLayout,
  PropertiesCard,
  PropertyItem,
} from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";
import {
  Table,
  TableBody,
  TableCaption,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from "@/components/ui/table";

export default function ContentsListPage() {
  const params = useParams();
  const router = useRouter();
  const vectorStoreId = params.id as string;
  const fileId = params.fileId as string;
  const client = useAuthClient();

  const getTextFromContent = (content: any): string => {
    if (typeof content === 'string') {
      return content;
    } else if (content && content.type === 'text') {
      return content.text;
    }
    return '';
  };

  const [store, setStore] = useState<VectorStore | null>(null);
  const [file, setFile] = useState<VectorStoreFile | null>(null);
  const [contents, setContents] = useState<VectorStoreContentItem[]>([]);
  const [isLoadingStore, setIsLoadingStore] = useState(true);
  const [isLoadingFile, setIsLoadingFile] = useState(true);
  const [isLoadingContents, setIsLoadingContents] = useState(true);
  const [errorStore, setErrorStore] = useState<Error | null>(null);
  const [errorFile, setErrorFile] = useState<Error | null>(null);
  const [errorContents, setErrorContents] = useState<Error | null>(null);

  useEffect(() => {
    if (!vectorStoreId) return;

    const fetchStore = async () => {
      setIsLoadingStore(true);
      setErrorStore(null);
      try {
        const response = await client.vectorStores.retrieve(vectorStoreId);
        setStore(response as VectorStore);
      } catch (err) {
        setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store."));
      } finally {
        setIsLoadingStore(false);
      }
    };
    fetchStore();
  }, [vectorStoreId, client]);

  useEffect(() => {
    if (!vectorStoreId || !fileId) return;

    const fetchFile = async () => {
      setIsLoadingFile(true);
      setErrorFile(null);
      try {
        const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId);
        setFile(response as VectorStoreFile);
      } catch (err) {
        setErrorFile(err instanceof Error ? err : new Error("Failed to load file."));
      } finally {
        setIsLoadingFile(false);
      }
    };
    fetchFile();
  }, [vectorStoreId, fileId, client]);

  useEffect(() => {
    if (!vectorStoreId || !fileId) return;

    const fetchContents = async () => {
      setIsLoadingContents(true);
      setErrorContents(null);
      try {
        const contentsAPI = new ContentsAPI(client);
        const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId, { limit: 100 });
        setContents(contentsResponse.data);
      } catch (err) {
        setErrorContents(err instanceof Error ? err : new Error("Failed to load contents."));
      } finally {
        setIsLoadingContents(false);
      }
    };
    fetchContents();
  }, [vectorStoreId, fileId, client]);

  const handleDeleteContent = async (contentId: string) => {
    try {
      const contentsAPI = new ContentsAPI(client);
      await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
      setContents(contents.filter(content => content.id !== contentId));
    } catch (err) {
      console.error('Failed to delete content:', err);
    }
  };

  const handleViewContent = (contentId: string) => {
    router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`);
  };

  const title = `Contents in File: ${fileId}`;

  const breadcrumbSegments: BreadcrumbSegment[] = [
    { label: "Vector Stores", href: "/logs/vector-stores" },
    { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
    { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
    { label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` },
    { label: "Contents" },
  ];

  if (errorStore) {
    return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />;
  }
  if (isLoadingStore) {
    return <DetailLoadingView title={title} />;
  }
  if (!store) {
    return <DetailNotFoundView title={title} id={vectorStoreId} />;
  }

  const mainContent = (
    <>
      <Card>
        <CardHeader>
          <CardTitle>Content Chunks ({contents.length})</CardTitle>
        </CardHeader>
        <CardContent>
          {isLoadingContents ? (
            <div className="space-y-2">
              <Skeleton className="h-4 w-full" />
              <Skeleton className="h-4 w-3/4" />
              <Skeleton className="h-4 w-1/2" />
            </div>
          ) : errorContents ? (
            <div className="text-destructive text-sm">
              Error loading contents: {errorContents.message}
            </div>
          ) : contents.length > 0 ? (
            <Table>
              <TableCaption>Contents in this file</TableCaption>
              <TableHeader>
                <TableRow>
                  <TableHead>Content ID</TableHead>
                  <TableHead>Content Preview</TableHead>
                  <TableHead>Embedding</TableHead>
                  <TableHead>Position</TableHead>
                  <TableHead>Created</TableHead>
                  <TableHead>Actions</TableHead>
                </TableRow>
              </TableHeader>
              <TableBody>
                {contents.map((content) => (
                  <TableRow key={content.id}>
                    <TableCell className="font-mono text-xs">
                      <Button
                        variant="link"
                        className="p-0 h-auto font-mono text-xs text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
                        onClick={() => handleViewContent(content.id)}
                        title={content.id}
                      >
                        {content.id.substring(0, 10)}...
                      </Button>
                    </TableCell>
                    <TableCell>
                      <div className="max-w-md">
                        <p className="text-sm truncate" title={getTextFromContent(content.content)}>
                          {getTextFromContent(content.content)}
                        </p>
                      </div>
                    </TableCell>
                    <TableCell className="text-xs text-gray-500">
                      {content.embedding && content.embedding.length > 0 ? (
                        <div className="max-w-xs">
                          <span className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5" title={`${content.embedding.length}D vector: [${content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...]`}>
                            [{content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...] ({content.embedding.length}D)
                          </span>
                        </div>
                      ) : (
                        <span className="text-gray-400 dark:text-gray-500 italic">No embedding</span>
                      )}
                    </TableCell>
                    <TableCell className="text-xs text-gray-500">
                      {content.metadata.chunk_window
                        ? content.metadata.chunk_window
                        : `${content.metadata.content_length || 0} chars`}
                    </TableCell>
                    <TableCell className="text-xs">
                      {new Date(content.created_timestamp * 1000).toLocaleString()}
                    </TableCell>
                    <TableCell>
                      <div className="flex gap-1">
                        <Button
                          variant="ghost"
                          size="sm"
                          className="h-6 w-6 p-0"
                          title="View content details"
                          onClick={() => handleViewContent(content.id)}
                        >
                          <Eye className="h-3 w-3" />
                        </Button>
                        <Button
                          variant="ghost"
                          size="sm"
                          className="h-6 w-6 p-0"
                          title="Edit content"
                          onClick={() => handleViewContent(content.id)}
                        >
                          <Edit className="h-3 w-3" />
                        </Button>
                        <Button
                          variant="ghost"
                          size="sm"
                          className="h-6 w-6 p-0 text-destructive hover:text-destructive"
                          title="Delete content"
                          onClick={() => handleDeleteContent(content.id)}
                        >
                          <Trash2 className="h-3 w-3" />
                        </Button>
                      </div>
                    </TableCell>
                  </TableRow>
                ))}
              </TableBody>
            </Table>
          ) : (
            <p className="text-gray-500 italic text-sm">
              No contents found for this file.
            </p>
          )}
        </CardContent>
      </Card>
    </>
  );

  const sidebar = (
    <PropertiesCard>
      <PropertyItem label="File ID" value={fileId} />
      <PropertyItem label="Vector Store ID" value={vectorStoreId} />
      {file && (
        <>
          <PropertyItem label="Status" value={file.status} />
          <PropertyItem
            label="Created"
            value={new Date(file.created_at * 1000).toLocaleString()}
          />
          <PropertyItem label="Usage Bytes" value={file.usage_bytes} />
          <PropertyItem
            label="Chunking Strategy"
            value={file.chunking_strategy.type}
          />
        </>
      )}
      {store && (
        <>
          <PropertyItem label="Store Name" value={store.name || ""} />
          <PropertyItem
            label="Provider ID"
            value={(store.metadata.provider_id as string) || ""}
          />
        </>
      )}
    </PropertiesCard>
  );

  return (
    <>
      <PageBreadcrumb segments={breadcrumbSegments} />
      <DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
    </>
  );
}
@@ -0,0 +1,258 @@
"use client";

import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile, FileContentResponse } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from '@/components/ui/skeleton';
import { Button } from "@/components/ui/button";
import { List } from "lucide-react";
import {
  DetailLoadingView,
  DetailErrorView,
  DetailNotFoundView,
  DetailLayout,
  PropertiesCard,
  PropertyItem,
} from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb";

export default function FileDetailPage() {
  const params = useParams();
  const router = useRouter();
  const vectorStoreId = params.id as string;
  const fileId = params.fileId as string;
  const client = useAuthClient();

  const [store, setStore] = useState<VectorStore | null>(null);
  const [file, setFile] = useState<VectorStoreFile | null>(null);
  const [contents, setContents] = useState<FileContentResponse | null>(null);
  const [isLoadingStore, setIsLoadingStore] = useState(true);
  const [isLoadingFile, setIsLoadingFile] = useState(true);
  const [isLoadingContents, setIsLoadingContents] = useState(true);
  const [errorStore, setErrorStore] = useState<Error | null>(null);
  const [errorFile, setErrorFile] = useState<Error | null>(null);
  const [errorContents, setErrorContents] = useState<Error | null>(null);

  useEffect(() => {
    if (!vectorStoreId) return;

    const fetchStore = async () => {
      setIsLoadingStore(true);
      setErrorStore(null);
      try {
        const response = await client.vectorStores.retrieve(vectorStoreId);
        setStore(response as VectorStore);
      } catch (err) {
        setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store."));
      } finally {
        setIsLoadingStore(false);
      }
    };
    fetchStore();
  }, [vectorStoreId, client]);

  useEffect(() => {
    if (!vectorStoreId || !fileId) return;

    const fetchFile = async () => {
      setIsLoadingFile(true);
      setErrorFile(null);
      try {
        const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId);
        setFile(response as VectorStoreFile);
      } catch (err) {
        setErrorFile(err instanceof Error ? err : new Error("Failed to load file."));
      } finally {
        setIsLoadingFile(false);
      }
    };
    fetchFile();
  }, [vectorStoreId, fileId, client]);

  useEffect(() => {
    if (!vectorStoreId || !fileId) return;

    const fetchContents = async () => {
      setIsLoadingContents(true);
      setErrorContents(null);
      try {
        const response = await client.vectorStores.files.content(vectorStoreId, fileId);
        setContents(response);
      } catch (err) {
        setErrorContents(err instanceof Error ? err : new Error("Failed to load contents."));
      } finally {
        setIsLoadingContents(false);
      }
    };
    fetchContents();
  }, [vectorStoreId, fileId, client]);

  const handleViewContents = () => {
    router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`);
  };

  const title = `File: ${fileId}`;

  const breadcrumbSegments: BreadcrumbSegment[] = [
    { label: "Vector Stores", href: "/logs/vector-stores" },
    { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` },
    { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
    { label: fileId },
  ];

  if (errorStore) {
    return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />;
  }
  if (isLoadingStore) {
    return <DetailLoadingView title={title} />;
  }
  if (!store) {
    return <DetailNotFoundView title={title} id={vectorStoreId} />;
  }

  const mainContent = (
    <>
      <Card>
        <CardHeader>
          <CardTitle>File Information</CardTitle>
        </CardHeader>
        <CardContent>
          {isLoadingFile ? (
            <div className="space-y-2">
              <Skeleton className="h-4 w-full" />
              <Skeleton className="h-4 w-3/4" />
              <Skeleton className="h-4 w-1/2" />
            </div>
          ) : errorFile ? (
            <div className="text-destructive text-sm">
              Error loading file: {errorFile.message}
            </div>
          ) : file ? (
            <div className="space-y-4">
              <div>
                <h3 className="text-lg font-medium mb-2">File Details</h3>
                <div className="grid grid-cols-2 gap-4 text-sm">
                  <div>
                    <span className="font-medium text-gray-600 dark:text-gray-400">Status:</span>
                    <span className="ml-2">{file.status}</span>
                  </div>
                  <div>
                    <span className="font-medium text-gray-600 dark:text-gray-400">Size:</span>
                    <span className="ml-2">{file.usage_bytes} bytes</span>
                  </div>
                  <div>
                    <span className="font-medium text-gray-600 dark:text-gray-400">Created:</span>
                    <span className="ml-2">{new Date(file.created_at * 1000).toLocaleString()}</span>
                  </div>
                  <div>
                    <span className="font-medium text-gray-600 dark:text-gray-400">Content Strategy:</span>
                    <span className="ml-2">{file.chunking_strategy.type}</span>
                  </div>
                </div>
              </div>

              <div className="border-t pt-4">
                <h3 className="text-lg font-medium mb-3">Actions</h3>
                <Button
                  onClick={handleViewContents}
                  className="flex items-center gap-2 hover:bg-primary/90 dark:hover:bg-primary/80 hover:scale-105 transition-all duration-200"
                >
                  <List className="h-4 w-4" />
                  View Contents
                </Button>
              </div>
            </div>
          ) : (
            <p className="text-gray-500 italic text-sm">
              File not found.
            </p>
          )}
        </CardContent>
      </Card>

      <Card>
        <CardHeader>
          <CardTitle>Content Summary</CardTitle>
        </CardHeader>
        <CardContent>
          {isLoadingContents ? (
            <div className="space-y-2">
              <Skeleton className="h-4 w-full" />
              <Skeleton className="h-4 w-3/4" />
              <Skeleton className="h-4 w-1/2" />
            </div>
          ) : errorContents ? (
            <div className="text-destructive text-sm">
              Error loading content summary: {errorContents.message}
            </div>
          ) : contents && contents.content.length > 0 ? (
            <div className="space-y-3">
              <div className="grid grid-cols-2 gap-4 text-sm">
                <div>
                  <span className="font-medium text-gray-600 dark:text-gray-400">Content Items:</span>
                  <span className="ml-2">{contents.content.length}</span>
                </div>
                <div>
                  <span className="font-medium text-gray-600 dark:text-gray-400">Total Characters:</span>
                  <span className="ml-2">{contents.content.reduce((total, item) => total + item.text.length, 0)}</span>
                </div>
              </div>
              <div className="pt-2">
                <span className="text-sm font-medium text-gray-600 dark:text-gray-400">Preview:</span>
                <div className="mt-1 bg-gray-50 dark:bg-gray-800 rounded-md p-3">
                  <p className="text-sm text-gray-900 dark:text-gray-100 line-clamp-3">
                    {contents.content[0]?.text.substring(0, 200)}...
                  </p>
                </div>
              </div>
            </div>
          ) : (
            <p className="text-gray-500 italic text-sm">
              No contents found for this file.
            </p>
          )}
        </CardContent>
      </Card>
    </>
  );

  const sidebar = (
    <PropertiesCard>
      <PropertyItem label="File ID" value={fileId} />
      <PropertyItem label="Vector Store ID" value={vectorStoreId} />
      {file && (
        <>
          <PropertyItem label="Status" value={file.status} />
          <PropertyItem
            label="Created"
            value={new Date(file.created_at * 1000).toLocaleString()}
          />
          <PropertyItem label="Usage Bytes" value={file.usage_bytes} />
          <PropertyItem
            label="Content Strategy"
            value={file.chunking_strategy.type}
          />
        </>
      )}
      {store && (
        <>
          <PropertyItem label="Store Name" value={store.name || ""} />
          <PropertyItem
            label="Provider ID"
            value={(store.metadata.provider_id as string) || ""}
          />
        </>
      )}
    </PropertiesCard>
  );

  return (
    <>
      <PageBreadcrumb segments={breadcrumbSegments} />
      <DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
    </>
  );
}
@@ -1,16 +1,31 @@
"use client";

import React from "react";
import LogsLayout from "@/components/layout/logs-layout";
import { useParams, usePathname } from "next/navigation";
import {
  PageBreadcrumb,
  BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";

export default function VectorStoresLayout({
export default function VectorStoreDetailLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  const params = useParams();
  const pathname = usePathname();
  const vectorStoreId = params.id as string;

  const breadcrumbSegments: BreadcrumbSegment[] = [
    { label: "Vector Stores", href: "/logs/vector-stores" },
    { label: `Details (${vectorStoreId})` },
  ];

  const isBaseDetailPage = pathname === `/logs/vector-stores/${vectorStoreId}`;

  return (
    <LogsLayout sectionLabel="Vector Stores" basePath="/logs/vector-stores">
      <div className="space-y-4">
        {isBaseDetailPage && <PageBreadcrumb segments={breadcrumbSegments} />}
        {children}
    </LogsLayout>
      </div>
  );
}
@@ -8,6 +8,7 @@ import type {
} from "llama-stack-client/resources/vector-stores/vector-stores";
import { useRouter } from "next/navigation";
import { usePagination } from "@/hooks/use-pagination";
import { Button } from "@/components/ui/button";
import {
  Table,
  TableBody,
@@ -49,73 +50,92 @@ export default function VectorStoresPage() {
    }
  }, [status, hasMore, loadMore]);

  if (status === "loading") {
  const renderContent = () => {
    if (status === "loading") {
      return (
        <div className="space-y-2">
          <Skeleton className="h-8 w-full"/>
          <Skeleton className="h-4 w-full"/>
          <Skeleton className="h-4 w-full"/>
        </div>
      );
    }

    if (status === "error") {
      return <div className="text-destructive">Error: {error?.message}</div>;
    }

    if (!stores || stores.length === 0) {
      return <p>No vector stores found.</p>;
    }

    return (
      <div className="space-y-2">
        <Skeleton className="h-8 w-full" />
        <Skeleton className="h-4 w-full" />
        <Skeleton className="h-4 w-full" />
      </div>
      <div className="overflow-auto flex-1 min-h-0">
        <Table>
          <TableHeader>
            <TableRow>
              <TableHead>ID</TableHead>
              <TableHead>Name</TableHead>
              <TableHead>Created</TableHead>
              <TableHead>Completed</TableHead>
              <TableHead>Cancelled</TableHead>
              <TableHead>Failed</TableHead>
              <TableHead>In Progress</TableHead>
              <TableHead>Total</TableHead>
              <TableHead>Usage Bytes</TableHead>
              <TableHead>Provider ID</TableHead>
              <TableHead>Provider Vector DB ID</TableHead>
            </TableRow>
          </TableHeader>
          <TableBody>
            {stores.map((store) => {
              const fileCounts = store.file_counts;
              const metadata = store.metadata || {};
              const providerId = metadata.provider_id ?? "";
              const providerDbId = metadata.provider_vector_db_id ?? "";

              return (
                <TableRow
                  key={store.id}
                  onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
                  className="cursor-pointer hover:bg-muted/50"
                >
                  <TableCell>
                    <Button
                      variant="link"
                      className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
                      onClick={() =>
                        router.push(`/logs/vector-stores/${store.id}`)
                      }
                    >
                      {store.id}
                    </Button>
                  </TableCell>
                  <TableCell>{store.name}</TableCell>
                  <TableCell>
                    {new Date(store.created_at * 1000).toLocaleString()}
                  </TableCell>
                  <TableCell>{fileCounts.completed}</TableCell>
                  <TableCell>{fileCounts.cancelled}</TableCell>
                  <TableCell>{fileCounts.failed}</TableCell>
                  <TableCell>{fileCounts.in_progress}</TableCell>
                  <TableCell>{fileCounts.total}</TableCell>
                  <TableCell>{store.usage_bytes}</TableCell>
                  <TableCell>{providerId}</TableCell>
                  <TableCell>{providerDbId}</TableCell>
                </TableRow>
              );
            })}
          </TableBody>
        </Table>
      </div>
    );
  }

  if (status === "error") {
    return <div className="text-destructive">Error: {error?.message}</div>;
  }

  if (!stores || stores.length === 0) {
    return <p>No vector stores found.</p>;
  }
  };

  return (
    <div className="overflow-auto flex-1 min-h-0">
      <Table>
        <TableHeader>
          <TableRow>
            <TableHead>ID</TableHead>
            <TableHead>Name</TableHead>
            <TableHead>Created</TableHead>
            <TableHead>Completed</TableHead>
            <TableHead>Cancelled</TableHead>
            <TableHead>Failed</TableHead>
            <TableHead>In Progress</TableHead>
            <TableHead>Total</TableHead>
            <TableHead>Usage Bytes</TableHead>
            <TableHead>Provider ID</TableHead>
            <TableHead>Provider Vector DB ID</TableHead>
          </TableRow>
        </TableHeader>
        <TableBody>
          {stores.map((store) => {
            const fileCounts = store.file_counts;
            const metadata = store.metadata || {};
            const providerId = metadata.provider_id ?? "";
            const providerDbId = metadata.provider_vector_db_id ?? "";

            return (
              <TableRow
                key={store.id}
                onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
                className="cursor-pointer hover:bg-muted/50"
              >
                <TableCell>{store.id}</TableCell>
                <TableCell>{store.name}</TableCell>
                <TableCell>
                  {new Date(store.created_at * 1000).toLocaleString()}
                </TableCell>
                <TableCell>{fileCounts.completed}</TableCell>
                <TableCell>{fileCounts.cancelled}</TableCell>
                <TableCell>{fileCounts.failed}</TableCell>
                <TableCell>{fileCounts.in_progress}</TableCell>
                <TableCell>{fileCounts.total}</TableCell>
                <TableCell>{store.usage_bytes}</TableCell>
                <TableCell>{providerId}</TableCell>
                <TableCell>{providerDbId}</TableCell>
              </TableRow>
            );
          })}
        </TableBody>
      </Table>
    </div>
    <div className="space-y-4">
      <h1 className="text-2xl font-semibold">Vector Stores</h1>
      {renderContent()}
    </div>
  );
}
@@ -1,9 +1,11 @@
"use client";

import { useRouter } from "next/navigation";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import { Button } from "@/components/ui/button";
import {
  DetailLoadingView,
  DetailErrorView,
@@ -42,6 +44,11 @@ export function VectorStoreDetailView({
  id,
}: VectorStoreDetailViewProps) {
  const title = "Vector Store Details";
  const router = useRouter();

  const handleFileClick = (fileId: string) => {
    router.push(`/logs/vector-stores/${id}/files/${fileId}`);
  };

  if (errorStore) {
    return <DetailErrorView title={title} id={id} error={errorStore} />;
@@ -80,7 +87,15 @@ export function VectorStoreDetailView({
          <TableBody>
            {files.map((file) => (
              <TableRow key={file.id}>
                <TableCell>{file.id}</TableCell>
                <TableCell>
                  <Button
                    variant="link"
                    className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
                    onClick={() => handleFileClick(file.id)}
                  >
                    {file.id}
                  </Button>
                </TableCell>
                <TableCell>{file.status}</TableCell>
                <TableCell>
                  {new Date(file.created_at * 1000).toLocaleString()}
llama_stack/ui/lib/contents-api.ts (new file, 112 lines)
@@ -0,0 +1,112 @@
import type { FileContentResponse } from "llama-stack-client/resources/vector-stores/files";
import type { LlamaStackClient } from "llama-stack-client";

export type VectorStoreContent = FileContentResponse.Content;
export type VectorStoreContentsResponse = FileContentResponse;

export interface VectorStoreContentItem {
  id: string;
  object: string;
  created_timestamp: number;
  vector_store_id: string;
  file_id: string;
  content: VectorStoreContent;
  metadata: Record<string, any>;
  embedding?: number[];
}

export interface VectorStoreContentDeleteResponse {
  id: string;
  object: string;
  deleted: boolean;
}

export interface VectorStoreListContentsResponse {
  object: string;
  data: VectorStoreContentItem[];
  first_id?: string;
  last_id?: string;
  has_more: boolean;
}

export class ContentsAPI {
  constructor(private client: LlamaStackClient) {}

  async getFileContents(vectorStoreId: string, fileId: string): Promise<VectorStoreContentsResponse> {
    return this.client.vectorStores.files.content(vectorStoreId, fileId);
  }

  async getContent(vectorStoreId: string, fileId: string, contentId: string): Promise<VectorStoreContentItem> {
    const contentsResponse = await this.listContents(vectorStoreId, fileId);
    const targetContent = contentsResponse.data.find(c => c.id === contentId);

    if (!targetContent) {
      throw new Error(`Content ${contentId} not found`);
    }

    return targetContent;
  }

  async updateContent(
    vectorStoreId: string,
    fileId: string,
    contentId: string,
    updates: { content?: string; metadata?: Record<string, any> }
  ): Promise<VectorStoreContentItem> {
    throw new Error("Individual content updates not yet implemented in API");
  }

  async deleteContent(vectorStoreId: string, fileId: string, contentId: string): Promise<VectorStoreContentDeleteResponse> {
    throw new Error("Individual content deletion not yet implemented in API");
  }

  async listContents(
    vectorStoreId: string,
    fileId: string,
    options?: {
      limit?: number;
      order?: string;
      after?: string;
      before?: string;
    }
  ): Promise<VectorStoreListContentsResponse> {
    const fileContents = await this.client.vectorStores.files.content(vectorStoreId, fileId);
    const contentItems: VectorStoreContentItem[] = [];

    fileContents.content.forEach((content, contentIndex) => {
      const rawContent = content as any;

      // Extract actual fields from the API response
      const embedding = rawContent.embedding || undefined;
      const created_timestamp = rawContent.created_timestamp || rawContent.created_at || Date.now() / 1000;
      const chunkMetadata = rawContent.chunk_metadata || {};
      const contentId = rawContent.chunk_metadata?.chunk_id || rawContent.id || `content_${fileId}_${contentIndex}`;
      const objectType = rawContent.object || 'vector_store.file.content';
      contentItems.push({
        id: contentId,
        object: objectType,
        created_timestamp: created_timestamp,
        vector_store_id: vectorStoreId,
        file_id: fileId,
        content: content,
        embedding: embedding,
        metadata: {
          ...chunkMetadata, // chunk_metadata fields from API
          content_length: content.type === 'text' ? content.text.length : 0,
        },
      });
    });

    // apply pagination if needed
    let filteredItems = contentItems;
    if (options?.limit) {
      filteredItems = filteredItems.slice(0, options.limit);
    }

    return {
      object: 'list',
      data: filteredItems,
      has_more: contentItems.length > (options?.limit || contentItems.length),
    };
  }
}
llama_stack/ui/package-lock.json (generated, 10 lines)
@@ -18,7 +18,7 @@
        "class-variance-authority": "^0.7.1",
        "clsx": "^2.1.1",
        "framer-motion": "^11.18.2",
        "llama-stack-client": "0.2.16",
        "llama-stack-client": "0.2.17",
        "lucide-react": "^0.510.0",
        "next": "15.3.3",
        "next-auth": "^4.24.11",
@@ -9926,10 +9926,10 @@
      "license": "MIT"
    },
    "node_modules/llama-stack-client": {
      "version": "0.2.16",
      "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.16.tgz",
      "integrity": "sha512-jM7sh1CB5wVumutYb3qfmYJpoTe3IRAa5lm3Us4qO7zVP4tbo3eCE7BOFNWyChpjo9efafUItwogNh28pum9PQ==",
      "license": "Apache-2.0",
      "version": "0.2.17",
      "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.17.tgz",
      "integrity": "sha512-+/fEO8M7XPiVLjhH7ge18i1ijKp4+h3dOkE0C8g2cvGuDUtDYIJlf8NSyr9OMByjiWpCibWU7VOKL50LwGLS3Q==",
      "license": "MIT",
      "dependencies": {
        "@types/node": "^18.11.18",
        "@types/node-fetch": "^2.6.4",
@@ -23,7 +23,7 @@
    "class-variance-authority": "^0.7.1",
    "clsx": "^2.1.1",
    "framer-motion": "^11.18.2",
    "llama-stack-client": "0.2.17",
    "llama-stack-client": "^0.2.17",
    "lucide-react": "^0.510.0",
    "next": "15.3.3",
    "next-auth": "^4.24.11",