forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -3,54 +3,53 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BedrockBaseConfig(BaseModel):
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
aws_access_key_id: str | None = Field(
|
||||
default=None,
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: Optional[str] = Field(
|
||||
aws_secret_access_key: str | None = Field(
|
||||
default=None,
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: Optional[str] = Field(
|
||||
aws_session_token: str | None = Field(
|
||||
default=None,
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: Optional[str] = Field(
|
||||
region_name: str | None = Field(
|
||||
default=None,
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: Optional[str] = Field(
|
||||
profile_name: str | None = Field(
|
||||
default=None,
|
||||
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
|
||||
)
|
||||
total_max_attempts: Optional[int] = Field(
|
||||
total_max_attempts: int | None = Field(
|
||||
default=None,
|
||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||
)
|
||||
retry_mode: Optional[str] = Field(
|
||||
retry_mode: str | None = Field(
|
||||
default=None,
|
||||
description="A string representing the type of retries Boto3 will perform."
|
||||
"Default use environment variable: AWS_RETRY_MODE",
|
||||
)
|
||||
connect_timeout: Optional[float] = Field(
|
||||
connect_timeout: float | None = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
read_timeout: Optional[float] = Field(
|
||||
read_timeout: float | None = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
session_ttl: Optional[int] = Field(
|
||||
session_ttl: int | None = Field(
|
||||
default=3600,
|
||||
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||
)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.type_system import (
|
||||
ChatCompletionInputType,
|
||||
|
@ -85,16 +85,16 @@ def get_valid_schemas(api_str: str):
|
|||
|
||||
|
||||
def validate_dataset_schema(
|
||||
dataset_schema: Dict[str, Any],
|
||||
expected_schemas: List[Dict[str, Any]],
|
||||
dataset_schema: dict[str, Any],
|
||||
expected_schemas: list[dict[str, Any]],
|
||||
):
|
||||
if dataset_schema not in expected_schemas:
|
||||
raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
|
||||
|
||||
|
||||
def validate_row_schema(
|
||||
input_row: Dict[str, Any],
|
||||
expected_schemas: List[Dict[str, Any]],
|
||||
input_row: dict[str, Any],
|
||||
expected_schemas: list[dict[str, Any]],
|
||||
):
|
||||
for schema in expected_schemas:
|
||||
if all(key in input_row for key in schema):
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
|
||||
|
||||
def paginate_records(
|
||||
records: List[Dict[str, Any]],
|
||||
records: list[dict[str, Any]],
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
|
|
|
@ -4,8 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.models.llama.sku_types import * # noqa: F403
|
||||
|
||||
|
@ -22,7 +20,7 @@ def is_supported_safety_model(model: Model) -> bool:
|
|||
]
|
||||
|
||||
|
||||
def supported_inference_models() -> List[Model]:
|
||||
def supported_inference_models() -> list[Model]:
|
||||
return [
|
||||
m
|
||||
for m in all_registered_models()
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
@ -31,10 +31,10 @@ class SentenceTransformerEmbeddingMixin:
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
|
@ -64,7 +65,7 @@ class LiteLLMOpenAIMixin(
|
|||
def __init__(
|
||||
self,
|
||||
model_entries,
|
||||
api_key_from_config: Optional[str],
|
||||
api_key_from_config: str | None,
|
||||
provider_data_api_key_field: str,
|
||||
openai_compat_api_base: str | None = None,
|
||||
):
|
||||
|
@ -97,26 +98,26 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError("LiteLLM does not support completion requests")
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
|
@ -243,10 +244,10 @@ class LiteLLMOpenAIMixin(
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
@ -261,24 +262,24 @@ class LiteLLMOpenAIMixin(
|
|||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
|
@ -309,29 +310,29 @@ class LiteLLMOpenAIMixin(
|
|||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
|
@ -365,21 +366,21 @@ class LiteLLMOpenAIMixin(
|
|||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
):
|
||||
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -20,13 +20,13 @@ from llama_stack.providers.utils.inference import (
|
|||
# more closer to the Model class.
|
||||
class ProviderModelEntry(BaseModel):
|
||||
provider_model_id: str
|
||||
aliases: List[str] = Field(default_factory=list)
|
||||
llama_model: Optional[str] = None
|
||||
aliases: list[str] = Field(default_factory=list)
|
||||
llama_model: str | None = None
|
||||
model_type: ModelType = ModelType.llm
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
|
||||
def get_huggingface_repo(model_descriptor: str) -> str | None:
|
||||
for model in all_registered_models():
|
||||
if model.descriptor() == model_descriptor:
|
||||
return model.huggingface_repo
|
||||
|
@ -34,7 +34,7 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
|
|||
|
||||
|
||||
def build_hf_repo_model_entry(
|
||||
provider_model_id: str, model_descriptor: str, additional_aliases: Optional[List[str]] = None
|
||||
provider_model_id: str, model_descriptor: str, additional_aliases: list[str] | None = None
|
||||
) -> ProviderModelEntry:
|
||||
aliases = [
|
||||
get_huggingface_repo(model_descriptor),
|
||||
|
@ -58,7 +58,7 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
|
|||
|
||||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
def __init__(self, model_entries: List[ProviderModelEntry]):
|
||||
def __init__(self, model_entries: list[ProviderModelEntry]):
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
|
@ -72,11 +72,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
|
||||
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
|
||||
|
||||
def get_provider_model_id(self, identifier: str) -> Optional[str]:
|
||||
def get_provider_model_id(self, identifier: str) -> str | None:
|
||||
return self.alias_to_provider_id_map.get(identifier, None)
|
||||
|
||||
# TODO: why keep a separate llama model mapping?
|
||||
def get_llama_model(self, provider_model_id: str) -> Optional[str]:
|
||||
def get_llama_model(self, provider_model_id: str) -> str | None:
|
||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
|
|
|
@ -8,16 +8,9 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterable
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from openai import AsyncStream
|
||||
|
@ -141,24 +134,24 @@ class OpenAICompatCompletionChoiceDelta(BaseModel):
|
|||
|
||||
|
||||
class OpenAICompatLogprobs(BaseModel):
|
||||
text_offset: Optional[List[int]] = None
|
||||
text_offset: list[int] | None = None
|
||||
|
||||
token_logprobs: Optional[List[float]] = None
|
||||
token_logprobs: list[float] | None = None
|
||||
|
||||
tokens: Optional[List[str]] = None
|
||||
tokens: list[str] | None = None
|
||||
|
||||
top_logprobs: Optional[List[Dict[str, float]]] = None
|
||||
top_logprobs: list[dict[str, float]] | None = None
|
||||
|
||||
|
||||
class OpenAICompatCompletionChoice(BaseModel):
|
||||
finish_reason: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
delta: Optional[OpenAICompatCompletionChoiceDelta] = None
|
||||
logprobs: Optional[OpenAICompatLogprobs] = None
|
||||
finish_reason: str | None = None
|
||||
text: str | None = None
|
||||
delta: OpenAICompatCompletionChoiceDelta | None = None
|
||||
logprobs: OpenAICompatLogprobs | None = None
|
||||
|
||||
|
||||
class OpenAICompatCompletionResponse(BaseModel):
|
||||
choices: List[OpenAICompatCompletionChoice]
|
||||
choices: list[OpenAICompatCompletionChoice]
|
||||
|
||||
|
||||
def get_sampling_strategy_options(params: SamplingParams) -> dict:
|
||||
|
@ -217,8 +210,8 @@ def get_stop_reason(finish_reason: str) -> StopReason:
|
|||
|
||||
|
||||
def convert_openai_completion_logprobs(
|
||||
logprobs: Optional[OpenAICompatLogprobs],
|
||||
) -> Optional[List[TokenLogProbs]]:
|
||||
logprobs: OpenAICompatLogprobs | None,
|
||||
) -> list[TokenLogProbs] | None:
|
||||
if not logprobs:
|
||||
return None
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
|
@ -235,7 +228,7 @@ def convert_openai_completion_logprobs(
|
|||
return None
|
||||
|
||||
|
||||
def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
|
||||
def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenAICompatLogprobs | None):
|
||||
if logprobs is None:
|
||||
return None
|
||||
if isinstance(logprobs, float):
|
||||
|
@ -562,7 +555,7 @@ class UnparseableToolCall(BaseModel):
|
|||
|
||||
|
||||
async def convert_message_to_openai_dict_new(
|
||||
message: Message | Dict,
|
||||
message: Message | dict,
|
||||
) -> OpenAIChatCompletionMessage:
|
||||
"""
|
||||
Convert a Message to an OpenAI API-compatible dictionary.
|
||||
|
@ -591,14 +584,10 @@ async def convert_message_to_openai_dict_new(
|
|||
# List[...] -> List[...]
|
||||
async def _convert_message_content(
|
||||
content: InterleavedContent,
|
||||
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
|
||||
) -> str | Iterable[OpenAIChatCompletionContentPartParam]:
|
||||
async def impl(
|
||||
content_: InterleavedContent,
|
||||
) -> Union[
|
||||
str,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
List[OpenAIChatCompletionContentPartParam],
|
||||
]:
|
||||
) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]:
|
||||
# Llama Stack and OpenAI spec match for str and text input
|
||||
if isinstance(content_, str):
|
||||
return content_
|
||||
|
@ -670,7 +659,7 @@ async def convert_message_to_openai_dict_new(
|
|||
|
||||
def convert_tool_call(
|
||||
tool_call: ChatCompletionMessageToolCall,
|
||||
) -> Union[ToolCall, UnparseableToolCall]:
|
||||
) -> ToolCall | UnparseableToolCall:
|
||||
"""
|
||||
Convert a ChatCompletionMessageToolCall tool call to either a
|
||||
ToolCall or UnparseableToolCall. Returns an UnparseableToolCall
|
||||
|
@ -846,7 +835,7 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
|||
}.get(finish_reason, StopReason.end_of_turn)
|
||||
|
||||
|
||||
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
|
||||
def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None = None) -> ToolConfig:
|
||||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
try:
|
||||
|
@ -857,7 +846,7 @@ def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[st
|
|||
return tool_config
|
||||
|
||||
|
||||
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
|
||||
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
|
||||
lls_tools = []
|
||||
if not tools:
|
||||
return lls_tools
|
||||
|
@ -903,8 +892,8 @@ def _convert_openai_request_response_format(
|
|||
|
||||
|
||||
def _convert_openai_tool_calls(
|
||||
tool_calls: List[OpenAIChatCompletionMessageToolCall],
|
||||
) -> List[ToolCall]:
|
||||
tool_calls: list[OpenAIChatCompletionMessageToolCall],
|
||||
) -> list[ToolCall]:
|
||||
"""
|
||||
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
|
||||
|
||||
|
@ -940,7 +929,7 @@ def _convert_openai_tool_calls(
|
|||
|
||||
def _convert_openai_logprobs(
|
||||
logprobs: OpenAIChoiceLogprobs,
|
||||
) -> Optional[List[TokenLogProbs]]:
|
||||
) -> list[TokenLogProbs] | None:
|
||||
"""
|
||||
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
|
||||
|
||||
|
@ -973,9 +962,9 @@ def _convert_openai_logprobs(
|
|||
|
||||
|
||||
def _convert_openai_sampling_params(
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
) -> SamplingParams:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
|
@ -998,8 +987,8 @@ def _convert_openai_sampling_params(
|
|||
|
||||
|
||||
def openai_messages_to_messages(
|
||||
messages: List[OpenAIChatCompletionMessage],
|
||||
) -> List[Message]:
|
||||
messages: list[OpenAIChatCompletionMessage],
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Convert a list of OpenAIChatCompletionMessage into a list of Message.
|
||||
"""
|
||||
|
@ -1027,7 +1016,7 @@ def openai_messages_to_messages(
|
|||
return converted_messages
|
||||
|
||||
|
||||
def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
|
||||
def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionContentPartParam]):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
|
@ -1273,24 +1262,24 @@ class OpenAICompletionToLlamaStackMixin:
|
|||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
) -> OpenAICompletion:
|
||||
if stream:
|
||||
raise ValueError(f"{self.__class__.__name__} doesn't support streaming openai completions")
|
||||
|
@ -1342,29 +1331,29 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIChatCompletionMessage],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
messages: list[OpenAIChatCompletionMessage],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
messages = openai_messages_to_messages(messages)
|
||||
response_format = _convert_openai_request_response_format(response_format)
|
||||
sampling_params = _convert_openai_sampling_params(
|
||||
|
@ -1403,7 +1392,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
async def _process_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
|
||||
outstanding_responses: list[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
|
||||
):
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
for outstanding_response in outstanding_responses:
|
||||
|
@ -1466,7 +1455,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
i = i + 1
|
||||
|
||||
async def _process_non_stream_response(
|
||||
self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
|
||||
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
|
||||
) -> OpenAIChatCompletion:
|
||||
choices = []
|
||||
for outstanding_response in outstanding_responses:
|
||||
|
|
|
@ -9,7 +9,6 @@ import base64
|
|||
import io
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from PIL import Image as PIL_Image
|
||||
|
@ -63,7 +62,7 @@ log = get_logger(name=__name__, category="inference")
|
|||
|
||||
|
||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||
messages: List[RawMessage]
|
||||
messages: list[RawMessage]
|
||||
|
||||
|
||||
class CompletionRequestWithRawContent(CompletionRequest):
|
||||
|
@ -93,8 +92,8 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> s
|
|||
|
||||
|
||||
async def convert_request_to_raw(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
|
||||
request: ChatCompletionRequest | CompletionRequest,
|
||||
) -> ChatCompletionRequestWithRawContent | CompletionRequestWithRawContent:
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
messages = []
|
||||
for m in request.messages:
|
||||
|
@ -170,18 +169,18 @@ def content_has_media(content: InterleavedContent):
|
|||
return _has_media_content(content)
|
||||
|
||||
|
||||
def messages_have_media(messages: List[Message]):
|
||||
def messages_have_media(messages: list[Message]):
|
||||
return any(content_has_media(m.content) for m in messages)
|
||||
|
||||
|
||||
def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
||||
def request_has_media(request: ChatCompletionRequest | CompletionRequest):
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
return messages_have_media(request.messages)
|
||||
else:
|
||||
return content_has_media(request.content)
|
||||
|
||||
|
||||
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||
async def localize_image_content(media: ImageContentItem) -> tuple[bytes, str]:
|
||||
image = media.image
|
||||
if image.url and image.url.uri.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
@ -228,7 +227,7 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
|
|||
|
||||
async def completion_request_to_prompt_model_input_info(
|
||||
request: CompletionRequest,
|
||||
) -> Tuple[str, int]:
|
||||
) -> tuple[str, int]:
|
||||
content = augment_content_with_response_format_prompt(request.response_format, request.content)
|
||||
request.content = content
|
||||
request = await convert_request_to_raw(request)
|
||||
|
@ -265,7 +264,7 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
|
|||
|
||||
async def chat_completion_request_to_model_input_info(
|
||||
request: ChatCompletionRequest, llama_model: str
|
||||
) -> Tuple[str, int]:
|
||||
) -> tuple[str, int]:
|
||||
messages = chat_completion_request_to_messages(request, llama_model)
|
||||
request.messages = messages
|
||||
request = await convert_request_to_raw(request)
|
||||
|
@ -284,7 +283,7 @@ async def chat_completion_request_to_model_input_info(
|
|||
def chat_completion_request_to_messages(
|
||||
request: ChatCompletionRequest,
|
||||
llama_model: str,
|
||||
) -> List[Message]:
|
||||
) -> list[Message]:
|
||||
"""Reads chat completion request and augments the messages to handle tools.
|
||||
For eg. for llama_3_1, add system message with the appropriate tools or
|
||||
add user messsage for custom tools, etc.
|
||||
|
@ -323,7 +322,7 @@ def chat_completion_request_to_messages(
|
|||
return messages
|
||||
|
||||
|
||||
def response_format_prompt(fmt: Optional[ResponseFormat]):
|
||||
def response_format_prompt(fmt: ResponseFormat | None):
|
||||
if not fmt:
|
||||
return None
|
||||
|
||||
|
@ -337,7 +336,7 @@ def response_format_prompt(fmt: Optional[ResponseFormat]):
|
|||
|
||||
def augment_messages_for_tools_llama_3_1(
|
||||
request: ChatCompletionRequest,
|
||||
) -> List[Message]:
|
||||
) -> list[Message]:
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
if existing_messages[0].role == Role.system.value:
|
||||
|
@ -406,7 +405,7 @@ def augment_messages_for_tools_llama_3_1(
|
|||
def augment_messages_for_tools_llama(
|
||||
request: ChatCompletionRequest,
|
||||
custom_tool_prompt_generator,
|
||||
) -> List[Message]:
|
||||
) -> list[Message]:
|
||||
existing_messages = request.messages
|
||||
existing_system_message = None
|
||||
if existing_messages[0].role == Role.system.value:
|
||||
|
@ -457,7 +456,7 @@ def augment_messages_for_tools_llama(
|
|||
return messages
|
||||
|
||||
|
||||
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
|
||||
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
|
||||
if tool_choice == ToolChoice.auto:
|
||||
return ""
|
||||
elif tool_choice == ToolChoice.required:
|
||||
|
|
|
@ -5,15 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class KVStore(Protocol):
|
||||
# TODO: make the value type bytes instead of str
|
||||
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: ...
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None: ...
|
||||
|
||||
async def get(self, key: str) -> Optional[str]: ...
|
||||
async def get(self, key: str) -> str | None: ...
|
||||
|
||||
async def delete(self, key: str) -> None: ...
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]: ...
|
||||
async def range(self, start_key: str, end_key: str) -> list[str]: ...
|
||||
|
|
|
@ -6,10 +6,9 @@
|
|||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
@ -22,7 +21,7 @@ class KVStoreType(Enum):
|
|||
|
||||
|
||||
class CommonConfig(BaseModel):
|
||||
namespace: Optional[str] = Field(
|
||||
namespace: str | None = Field(
|
||||
default=None,
|
||||
description="All keys will be prefixed with this namespace",
|
||||
)
|
||||
|
@ -69,7 +68,7 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
port: int = 5432
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: Optional[str] = None
|
||||
password: str | None = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
|
@ -108,7 +107,7 @@ class MongoDBKVStoreConfig(CommonConfig):
|
|||
port: int = 27017
|
||||
db: str = "llamastack"
|
||||
user: str = None
|
||||
password: Optional[str] = None
|
||||
password: str | None = None
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
|
@ -126,6 +125,6 @@ class MongoDBKVStoreConfig(CommonConfig):
|
|||
|
||||
|
||||
KVStoreConfig = Annotated[
|
||||
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig],
|
||||
RedisKVStoreConfig | SqliteKVStoreConfig | PostgresKVStoreConfig | MongoDBKVStoreConfig,
|
||||
Field(discriminator="type", default=KVStoreType.sqlite.value),
|
||||
]
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .api import KVStore
|
||||
from .config import KVStoreConfig, KVStoreType
|
||||
|
@ -21,13 +20,13 @@ class InmemoryKVStoreImpl(KVStore):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self._store.get(key)
|
||||
|
||||
async def set(self, key: str, value: str) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
||||
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pymongo import AsyncMongoClient
|
||||
|
||||
|
@ -43,12 +42,12 @@ class MongoDBKVStoreImpl(KVStore):
|
|||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
update_query = {"$set": {"value": value, "expiration": expiration}}
|
||||
await self.collection.update_one({"key": key}, update_query, upsert=True)
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
query = {"key": key}
|
||||
result = await self.collection.find_one(query, {"value": 1, "_id": 0})
|
||||
|
@ -58,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
|
|||
key = self._namespaced_key(key)
|
||||
await self.collection.delete_one({"key": key})
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
query = {
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import DictCursor
|
||||
|
@ -54,7 +53,7 @@ class PostgresKVStoreImpl(KVStore):
|
|||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
self.cursor.execute(
|
||||
f"""
|
||||
|
@ -66,7 +65,7 @@ class PostgresKVStoreImpl(KVStore):
|
|||
(key, value, expiration),
|
||||
)
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
self.cursor.execute(
|
||||
f"""
|
||||
|
@ -86,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
|
|||
(key,),
|
||||
)
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
@ -25,13 +24,13 @@ class RedisKVStoreImpl(KVStore):
|
|||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.set(key, value)
|
||||
if expiration:
|
||||
await self.redis.expireat(key, expiration)
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
value = await self.redis.get(key)
|
||||
if value is None:
|
||||
|
@ -43,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
|
|||
key = self._namespaced_key(key)
|
||||
await self.redis.delete(key)
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
cursor = 0
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
@ -33,7 +32,7 @@ class SqliteKVStoreImpl(KVStore):
|
|||
)
|
||||
await db.commit()
|
||||
|
||||
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
|
@ -41,7 +40,7 @@ class SqliteKVStoreImpl(KVStore):
|
|||
)
|
||||
await db.commit()
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> str | None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
|
@ -55,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
|
|||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
|
|
|
@ -9,7 +9,7 @@ import logging
|
|||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
|
@ -94,7 +94,7 @@ def content_from_data(data_url: str) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
|
||||
def concat_interleaved_content(content: list[InterleavedContent]) -> InterleavedContent:
|
||||
"""concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
|
||||
|
||||
ret = []
|
||||
|
@ -141,7 +141,7 @@ async def content_from_doc(doc: RAGDocument) -> str:
|
|||
return interleaved_content_as_str(doc.content)
|
||||
|
||||
|
||||
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> List[Chunk]:
|
||||
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]:
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
tokens = tokenizer.encode(text, bos=False, eos=False)
|
||||
|
||||
|
@ -165,7 +165,7 @@ def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap
|
|||
|
||||
class EmbeddingIndex(ABC):
|
||||
@abstractmethod
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
|
@ -185,7 +185,7 @@ class VectorDBWithIndex:
|
|||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
chunks: List[Chunk],
|
||||
chunks: list[Chunk],
|
||||
) -> None:
|
||||
embeddings_response = await self.inference_api.embeddings(
|
||||
self.vector_db.embedding_model, [x.content for x in chunks]
|
||||
|
@ -197,7 +197,7 @@ class VectorDBWithIndex:
|
|||
async def query_chunks(
|
||||
self,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
if params is None:
|
||||
params = {}
|
||||
|
|
|
@ -8,9 +8,10 @@ import abc
|
|||
import asyncio
|
||||
import functools
|
||||
import threading
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -38,7 +39,7 @@ class JobArtifact(BaseModel):
|
|||
name: str
|
||||
# TODO: uri should be a reference to /files API; revisit when /files is implemented
|
||||
uri: str | None = None
|
||||
metadata: Dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
JobHandler = Callable[
|
||||
|
@ -46,7 +47,7 @@ JobHandler = Callable[
|
|||
]
|
||||
|
||||
|
||||
LogMessage: TypeAlias = Tuple[datetime, str]
|
||||
LogMessage: TypeAlias = tuple[datetime, str]
|
||||
|
||||
|
||||
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
|
||||
|
@ -60,7 +61,7 @@ class Job:
|
|||
self._handler = handler
|
||||
self._artifacts: list[JobArtifact] = []
|
||||
self._logs: list[LogMessage] = []
|
||||
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
|
||||
self._state_transitions: list[tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
|
||||
|
||||
@property
|
||||
def handler(self) -> JobHandler:
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import statistics
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import AggregationFunctionType
|
||||
|
||||
|
||||
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_accuracy(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
num_correct = sum(result["score"] for result in scoring_results)
|
||||
avg_score = num_correct / len(scoring_results)
|
||||
|
||||
|
@ -21,14 +21,14 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any
|
|||
}
|
||||
|
||||
|
||||
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
return {
|
||||
"average": sum(result["score"] for result in scoring_results if result["score"] is not None)
|
||||
/ len([_ for _ in scoring_results if _["score"] is not None]),
|
||||
}
|
||||
|
||||
|
||||
def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_weighted_average(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
return {
|
||||
"weighted_average": sum(
|
||||
result["score"] * result["weight"]
|
||||
|
@ -40,14 +40,14 @@ def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[
|
|||
|
||||
|
||||
def aggregate_categorical_count(
|
||||
scoring_results: List[ScoringResultRow],
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow],
|
||||
) -> dict[str, Any]:
|
||||
scores = [str(r["score"]) for r in scoring_results]
|
||||
unique_scores = sorted(set(scores))
|
||||
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
|
||||
|
||||
|
||||
def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
def aggregate_median(scoring_results: list[ScoringResultRow]) -> dict[str, Any]:
|
||||
scores = [r["score"] for r in scoring_results if r["score"] is not None]
|
||||
median = statistics.median(scores) if scores else None
|
||||
return {"median": median}
|
||||
|
@ -64,8 +64,8 @@ AGGREGATION_FUNCTIONS = {
|
|||
|
||||
|
||||
def aggregate_metrics(
|
||||
scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow], metrics: list[AggregationFunctionType]
|
||||
) -> dict[str, Any]:
|
||||
agg_results = {}
|
||||
for metric in metrics:
|
||||
if metric not in AGGREGATION_FUNCTIONS:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
|
@ -28,28 +28,28 @@ class BaseScoringFn(ABC):
|
|||
@abstractmethod
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(
|
||||
self,
|
||||
scoring_results: List[ScoringResultRow],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> List[ScoringResultRow]:
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> list[ScoringResultRow]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
@ -65,7 +65,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
|
|||
def __str__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
|
||||
def get_supported_scoring_fn_defs(self) -> list[ScoringFn]:
|
||||
return list(self.supported_fn_defs_registry.values())
|
||||
|
||||
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
|
||||
|
@ -81,18 +81,18 @@ class RegisteredBaseScoringFn(BaseScoringFn):
|
|||
@abstractmethod
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def aggregate(
|
||||
self,
|
||||
scoring_results: List[ScoringResultRow],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> Dict[str, Any]:
|
||||
scoring_results: list[ScoringResultRow],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> dict[str, Any]:
|
||||
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
|
||||
if scoring_params is not None:
|
||||
if params is None:
|
||||
|
@ -107,8 +107,8 @@ class RegisteredBaseScoringFn(BaseScoringFn):
|
|||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> List[ScoringResultRow]:
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_fn_identifier: str | None = None,
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> list[ScoringResultRow]:
|
||||
return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
import contextlib
|
||||
import signal
|
||||
from collections.abc import Iterator
|
||||
from types import FrameType
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
|
@ -15,7 +15,7 @@ class TimeoutError(Exception):
|
|||
|
||||
@contextlib.contextmanager
|
||||
def time_limit(seconds: float) -> Iterator[None]:
|
||||
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
|
||||
def signal_handler(signum: int, frame: FrameType | None) -> None:
|
||||
raise TimeoutError("Timed out!")
|
||||
|
||||
signal.setitimer(signal.ITIMER_REAL, seconds)
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
|
||||
|
@ -17,10 +16,10 @@ class TelemetryDatasetMixin:
|
|||
|
||||
async def save_spans_to_dataset(
|
||||
self,
|
||||
attribute_filters: List[QueryCondition],
|
||||
attributes_to_save: List[str],
|
||||
attribute_filters: list[QueryCondition],
|
||||
attributes_to_save: list[str],
|
||||
dataset_id: str,
|
||||
max_depth: Optional[int] = None,
|
||||
max_depth: int | None = None,
|
||||
) -> None:
|
||||
if self.datasetio_api is None:
|
||||
raise RuntimeError("DatasetIO API not available")
|
||||
|
@ -48,9 +47,9 @@ class TelemetryDatasetMixin:
|
|||
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: List[QueryCondition],
|
||||
attributes_to_return: List[str],
|
||||
max_depth: Optional[int] = None,
|
||||
attribute_filters: list[QueryCondition],
|
||||
attributes_to_return: list[str],
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpansResponse:
|
||||
traces = await self.query_traces(attribute_filters=attribute_filters)
|
||||
spans = []
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
@ -16,18 +16,18 @@ from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Tra
|
|||
class TraceStore(Protocol):
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]: ...
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> list[Trace]: ...
|
||||
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> Dict[str, SpanWithStatus]: ...
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> dict[str, SpanWithStatus]: ...
|
||||
|
||||
|
||||
class SQLiteTraceStore(TraceStore):
|
||||
|
@ -36,11 +36,11 @@ class SQLiteTraceStore(TraceStore):
|
|||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> list[Trace]:
|
||||
def build_where_clause() -> tuple[str, list]:
|
||||
if not attribute_filters:
|
||||
return "", []
|
||||
|
@ -112,9 +112,9 @@ class SQLiteTraceStore(TraceStore):
|
|||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> Dict[str, SpanWithStatus]:
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> dict[str, SpanWithStatus]:
|
||||
# Build the attributes selection
|
||||
attributes_select = "s.attributes"
|
||||
if attributes_to_return:
|
||||
|
|
|
@ -7,8 +7,9 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from functools import wraps
|
||||
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -25,13 +26,13 @@ def _prepare_for_json(value: Any) -> str:
|
|||
"""Serialize a single value into JSON-compatible format."""
|
||||
if value is None:
|
||||
return ""
|
||||
elif isinstance(value, (str, int, float, bool)):
|
||||
elif isinstance(value, str | int | float | bool):
|
||||
return value
|
||||
elif hasattr(value, "_name_"):
|
||||
return value._name_
|
||||
elif isinstance(value, BaseModel):
|
||||
return json.loads(value.model_dump_json())
|
||||
elif isinstance(value, (list, tuple, set)):
|
||||
elif isinstance(value, list | tuple | set):
|
||||
return [_prepare_for_json(item) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {str(k): _prepare_for_json(v) for k, v in value.items()}
|
||||
|
@ -43,7 +44,7 @@ def _prepare_for_json(value: Any) -> str:
|
|||
return str(value)
|
||||
|
||||
|
||||
def trace_protocol(cls: Type[T]) -> Type[T]:
|
||||
def trace_protocol(cls: type[T]) -> type[T]:
|
||||
"""
|
||||
A class decorator that automatically traces all methods in a protocol/base class
|
||||
and its inheriting classes.
|
||||
|
|
|
@ -10,9 +10,10 @@ import logging
|
|||
import queue
|
||||
import random
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
LogSeverity,
|
||||
|
@ -106,13 +107,13 @@ class BackgroundLogger:
|
|||
|
||||
|
||||
class TraceContext:
|
||||
spans: List[Span] = []
|
||||
spans: list[Span] = []
|
||||
|
||||
def __init__(self, logger: BackgroundLogger, trace_id: str):
|
||||
self.logger = logger
|
||||
self.trace_id = trace_id
|
||||
|
||||
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
|
||||
def push_span(self, name: str, attributes: dict[str, Any] = None) -> Span:
|
||||
current_span = self.get_current_span()
|
||||
span = Span(
|
||||
span_id=generate_span_id(),
|
||||
|
@ -168,7 +169,7 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
|
|||
root_logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
||||
async def start_trace(name: str, attributes: dict[str, Any] = None) -> TraceContext:
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
|
@ -246,7 +247,7 @@ class TelemetryHandler(logging.Handler):
|
|||
|
||||
|
||||
class SpanContextManager:
|
||||
def __init__(self, name: str, attributes: Dict[str, Any] = None):
|
||||
def __init__(self, name: str, attributes: dict[str, Any] = None):
|
||||
self.name = name
|
||||
self.attributes = attributes
|
||||
self.span = None
|
||||
|
@ -316,11 +317,11 @@ class SpanContextManager:
|
|||
return wrapper
|
||||
|
||||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
def span(name: str, attributes: dict[str, Any] = None):
|
||||
return SpanContextManager(name, attributes)
|
||||
|
||||
|
||||
def get_current_span() -> Optional[Span]:
|
||||
def get_current_span() -> Span | None:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
if CURRENT_TRACE_CONTEXT is None:
|
||||
logger.debug("No trace context to get current span")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue