Merge conflicts

This commit is contained in:
Chantal D Gama Rose 2024-12-03 19:00:27 -08:00
commit 5b027d2de5
198 changed files with 6140 additions and 3477 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -3,12 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class HuggingfaceDatasetIOConfig(BaseModel):

View file

@ -9,6 +9,7 @@ from llama_stack.apis.datasetio import * # noqa: F403
import datasets as hf_datasets
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl

View file

@ -4,11 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
@json_schema_type
class BedrockConfig(BedrockBaseConfig):
pass

View file

@ -4,11 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from ._config import NVIDIAConfig
from ._nvidia import NVIDIAInferenceAdapter
from llama_stack.apis.inference import Inference
from .config import NVIDIAConfig
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> NVIDIAInferenceAdapter:
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
from .nvidia import NVIDIAInferenceAdapter
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config)

View file

@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class NVIDIAConfig(BaseModel):
"""
Configuration for the NVIDIA NIM inference endpoint.
Attributes:
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
By default the configuration is set to use the hosted APIs. This requires
an API key which can be obtained from https://ngc.nvidia.com/.
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
If you are using a self-hosted NVIDIA NIM, you can set the url to the
URL of your running NVIDIA NIM and do not need to set the api_key.
"""
url: str = Field(
default_factory=lambda: os.getenv(
"NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
),
description="A base url for accessing the NVIDIA NIM",
)
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)

View file

@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from typing import AsyncIterator, List, Optional, Union
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import CoreModelId
from openai import APIConnectionError, AsyncOpenAI
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
LogProbConfig,
ResponseFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from . import NVIDIAConfig
from .openai_utils import (
convert_chat_completion_request,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health
_MODEL_ALIASES = [
build_model_alias(
"meta/llama3-8b-instruct",
CoreModelId.llama3_8b_instruct.value,
),
build_model_alias(
"meta/llama3-70b-instruct",
CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta/llama-3.1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
print(f"Initializing NVIDIAInferenceAdapter({config.url})...")
if _is_nvidia_hosted(config):
if not config.api_key:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. "
"Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
# we don't raise this warning because a user may have deployed their
# self-hosted NIM with an API key requirement.
#
# warnings.warn(
# "API key is not required for self-hosted NVIDIA NIM. "
# "Consider removing the api_key from the configuration."
# )
self._config = config
# make sure the client lives longer than any async calls
self._client = AsyncOpenAI(
base_url=f"{self._config.url}/v1",
api_key=self._config.api_key or "NO KEY",
timeout=self._config.timeout,
)
def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
raise NotImplementedError()
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[
ToolPromptFormat
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
await check_health(self._config) # this raises errors
request = convert_chat_completion_request(
request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id),
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
),
n=1,
)
try:
response = await self._client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
) from e
if stream:
return convert_openai_chat_completion_stream(response)
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])

View file

@ -0,0 +1,581 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
from llama_models.llama3.api.datatypes import (
BuiltinTool,
CompletionMessage,
StopReason,
TokenLogProbs,
ToolCall,
ToolDefinition,
)
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
ChatCompletionChunk as OpenAIChatCompletionChunk,
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolResponseMessage,
UserMessage,
)
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
ToolDefinition:
tool_name: str | BuiltinTool
description: Optional[str]
parameters: Optional[Dict[str, ToolParamDefinition]]
ToolParamDefinition:
param_type: str
description: Optional[str]
required: Optional[bool]
default: Optional[Any]
OpenAI spec -
{
"type": "function",
"function": {
"name": tool_name,
"description": description,
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_type,
"description": description,
"default": default,
},
...
},
"required": [param_name, ...],
},
},
}
"""
out = {
"type": "function",
"function": {},
}
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
else:
function.update(name=tool.tool_name)
if tool.description:
function.update(description=tool.description)
if tool.parameters:
parameters = {
"type": "object",
"properties": {},
}
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
properties[param_name].update(default=param.default)
if param.required:
required.append(param_name)
if required:
parameters.update(required=required)
function.update(parameters=parameters)
return out
def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
# users can supply a dict instead of a Message object, we'll
# convert it to a Message object and proceed with some type safety.
if isinstance(message, dict):
if "role" not in message:
raise ValueError("role is required in message")
if message["role"] == "user":
message = UserMessage(**message)
elif message["role"] == "assistant":
message = CompletionMessage(**message)
elif message["role"] == "ipython":
message = ToolResponseMessage(**message)
elif message["role"] == "system":
message = SystemMessage(**message)
else:
raise ValueError(f"Unsupported message role: {message['role']}")
out: OpenAIChatCompletionMessage = None
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
content=message.content, # TODO(mf): handle image content
)
elif isinstance(message, CompletionMessage):
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=message.content,
tool_calls=[
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name,
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
],
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=message.content,
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=message.content,
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return out
def convert_chat_completion_request(
request: ChatCompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# messages -> messages
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> GrammarResponseFormat TODO(mf)
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
# tools -> tools
# tool_choice ("auto", "required") -> tool_choice
# tool_prompt_format -> TBD
# stream -> stream
# logprobs -> logprobs
if request.response_format and not isinstance(
request.response_format, JsonSchemaResponseFormat
):
raise ValueError(
f"Unsupported response format: {request.response_format}. "
"Only JsonSchemaResponseFormat is supported."
)
nvext = {}
payload: Dict[str, Any] = dict(
model=request.model,
messages=[_convert_message(message) for message in request.messages],
stream=request.stream,
n=n,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
)
if request.response_format:
# server bug - setting guided_json changes the behavior of response_format resulting in an error
# payload.update(response_format="json_object")
nvext.update(guided_json=request.response_format.json_schema)
if request.tools:
payload.update(
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
)
if request.tool_choice:
payload.update(
tool_choice=request.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs:
payload.update(logprobs=True)
payload.update(top_logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if (
request.sampling_params.top_k != -1
and request.sampling_params.top_k < 1
):
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
return payload
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls", ...]
- stop: model hit a natural stop point or a provided stop sequence
- length: maximum number of tokens specified in the request was reached
- tool_calls: model called a tool
->
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# TODO(mf): are end_of_turn and end_of_message semantics correct?
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
"""
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
OpenAI ChatCompletionMessageToolCall:
id: str
function: Function
type: Literal["function"]
OpenAI Function:
arguments: str
name: str
->
ToolCall:
call_id: str
tool_name: str
arguments: Dict[str, ...]
"""
if not tool_calls:
return [] # CompletionMessage tool_calls is not optional
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
)
for call in tool_calls
]
def _convert_openai_logprobs(
logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
OpenAI ChoiceLogprobs:
content: Optional[List[ChatCompletionTokenLogprob]]
OpenAI ChatCompletionTokenLogprob:
token: str
logprob: float
top_logprobs: List[TopLogprob]
OpenAI TopLogprob:
token: str
logprob: float
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
"""
if not logprobs:
return None
return [
TokenLogProbs(
logprobs_by_token={
logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
}
)
for content in logprobs.content
]
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
"""
Convert an OpenAI Choice into a ChatCompletionResponse.
OpenAI Choice:
message: ChatCompletionMessage
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
OpenAI ChatCompletionMessage:
role: Literal["assistant"]
content: Optional[str]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
->
ChatCompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
assert (
hasattr(choice, "message") and choice.message
), "error in server response: message not found"
assert (
hasattr(choice, "finish_reason") and choice.finish_reason
), "error in server response: finish_reason not found"
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content
or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
async def convert_openai_chat_completion_stream(
stream: AsyncStream[OpenAIChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
OpenAI ChatCompletionChunk:
choices: List[Choice]
OpenAI Choice: # different from the non-streamed Choice
delta: ChoiceDelta
finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
logprobs: Optional[ChoiceLogprobs]
OpenAI ChoiceDelta:
content: Optional[str]
role: Optional[Literal["system", "user", "assistant", "tool"]]
tool_calls: Optional[List[ChoiceDeltaToolCall]]
OpenAI ChoiceDeltaToolCall:
index: int
id: Optional[str]
function: Optional[ChoiceDeltaToolCallFunction]
type: Optional[Literal["function"]]
OpenAI ChoiceDeltaToolCallFunction:
name: Optional[str]
arguments: Optional[str]
->
ChatCompletionResponseStreamChunk:
event: ChatCompletionResponseEvent
ChatCompletionResponseEvent:
event_type: ChatCompletionResponseEventType
delta: Union[str, ToolCallDelta]
logprobs: Optional[List[TokenLogProbs]]
stop_reason: Optional[StopReason]
ChatCompletionResponseEventType:
start = "start"
progress = "progress"
complete = "complete"
ToolCallDelta:
content: Union[str, ToolCall]
parse_status: ToolCallParseStatus
ToolCall:
call_id: str
tool_name: str
arguments: str
ToolCallParseStatus:
started = "started"
in_progress = "in_progress"
failure = "failure"
success = "success"
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
StopReason:
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> (
Generator[ChatCompletionResponseEventType, None, None]
):
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
# we implement NIM specific semantics, the main difference from OpenAI
# is that tool_calls are always produced as a complete call. there is no
# intermediate / partial tool call streamed. because of this, we can
# simplify the logic and not concern outselves with parse_status of
# started/in_progress/failed. we can always assume success.
#
# a stream of ChatCompletionResponseStreamChunk consists of
# 0. a start event
# 1. zero or more progress events
# - each progress event has a delta
# - each progress event may have a stop_reason
# - each progress event may have logprobs
# - each progress event may have tool_calls
# if a progress event has tool_calls,
# it is fully formed and
# can be emitted with a parse_status of success
# 2. a complete event
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
# if there's a tool call, emit an event for each tool in the list
# if tool call and content, emit both separately
if choice.delta.tool_calls:
# the call may have content and a tool call. ChatCompletionResponseEvent
# does not support both, so we emit the content first
if choice.delta.content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=choice.delta.content,
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest"
)
# NIM only produces fully formed tool calls, so we can assume success
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.success,
),
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=choice.delta.content or "", # content is not optional
logprobs=_convert_openai_logprobs(choice.logprobs),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Tuple
import httpx
from . import NVIDIAConfig
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
async def _get_health(url: str) -> Tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready
Args:
url (str): URL of the server
Returns:
Tuple[bool, bool]: (is_live, is_ready)
"""
async with httpx.AsyncClient() as client:
live = await client.get(f"{url}/v1/health/live")
ready = await client.get(f"{url}/v1/health/ready")
return live.status_code == 200, ready.status_code == 200
async def check_health(config: NVIDIAConfig) -> None:
"""
Check if the server is running and ready
Args:
url (str): URL of the server
Raises:
RuntimeError: If the server is not running or ready
"""
if not _is_nvidia_hosted(config):
print("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.url)
if not is_live:
raise ConnectionError("NVIDIA NIM is not running")
if not is_ready:
raise ConnectionError("NVIDIA NIM is not ready")
# TODO(mf): should we wait for the server to be ready?
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator
import httpx
@ -39,6 +40,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
log = logging.getLogger(__name__)
model_aliases = [
build_model_alias(
@ -57,18 +59,26 @@ model_aliases = [
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
@ -81,6 +91,14 @@ model_aliases = [
"llama3.2-vision",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_model_alias(
@ -105,7 +123,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
print(f"checking connectivity to Ollama at `{self.url}`...")
log.info(f"checking connectivity to Ollama at `{self.url}`...")
try:
await self.client.ps()
except httpx.ConnectError as e:

View file

@ -37,6 +37,18 @@ class InferenceEndpointImplConfig(BaseModel):
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"endpoint_name": endpoint_name,
"api_token": api_token,
}
@json_schema_type
class InferenceAPIImplConfig(BaseModel):
@ -47,3 +59,15 @@ class InferenceAPIImplConfig(BaseModel):
default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
)
@classmethod
def sample_run_config(
cls,
repo: str = "${env.INFERENCE_MODEL}",
api_token: str = "${env.HF_API_TOKEN}",
**kwargs,
):
return {
"huggingface_repo": repo,
"api_token": api_token,
}

View file

@ -17,6 +17,10 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -34,7 +38,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
logger = logging.getLogger(__name__)
log = logging.getLogger(__name__)
def build_model_aliases():
return [
build_model_alias(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class _HfAdapter(Inference, ModelsProtocolPrivate):
@ -44,45 +59,39 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
def __init__(self) -> None:
self.formatter = ChatFormat(Tokenizer.get_instance())
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
}
async def register_model(self, model: Model) -> None:
pass
async def list_models(self) -> List[Model]:
repo = self.model_id
identifier = self.huggingface_repo_to_llama_model_id[repo]
return [
Model(
identifier=identifier,
llama_model=identifier,
metadata={
"huggingface_repo": repo,
},
)
]
async def shutdown(self) -> None:
pass
async def register_model(self, model: Model) -> None:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
)
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@ -176,7 +185,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@ -186,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -241,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter
request, self.register_helper.get_llama_model(request.model), self.formatter
)
return dict(
prompt=prompt,
@ -256,7 +266,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
@ -264,7 +274,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None:
print(f"Initializing TGI client with url={config.url}")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
endpoint_info = await self.client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]

View file

@ -3,6 +3,8 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
@ -34,6 +36,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__)
def build_model_aliases():
return [
build_model_alias(
@ -53,7 +58,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
print(f"Initializing VLLM client with base_url={self.config.url}")
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def shutdown(self) -> None:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import json
import logging
from typing import List
from urllib.parse import urlparse
@ -21,6 +22,8 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
)
log = logging.getLogger(__name__)
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: chromadb.AsyncHttpClient, collection):
@ -56,10 +59,7 @@ class ChromaIndex(EmbeddingIndex):
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {doc}")
log.exception(f"Failed to parse document: {doc}")
continue
chunks.append(chunk)
@ -73,7 +73,7 @@ class ChromaIndex(EmbeddingIndex):
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None:
print(f"Initializing ChromaMemoryAdapter with url: {url}")
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
parsed = urlparse(url)
@ -88,12 +88,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def initialize(self) -> None:
try:
print(f"Connecting to Chroma server at: {self.host}:{self.port}")
log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
except Exception as e:
import traceback
traceback.print_exc()
log.exception("Could not connect to Chroma server")
raise RuntimeError("Could not connect to Chroma server") from e
async def shutdown(self) -> None:
@ -109,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
collection = await self.client.get_or_create_collection(
name=memory_bank.identifier,
metadata={"bank": memory_bank.json()},
metadata={"bank": memory_bank.model_dump_json()},
)
bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection)
@ -123,10 +121,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(VectorMemoryBank, data)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse bank: {collection.metadata}")
log.exception(f"Failed to parse bank: {collection.metadata}")
continue
index = BankWithIndex(
@ -147,9 +142,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
@ -159,8 +152,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
self.cache[bank_id] = index
return index

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import List, Tuple
import psycopg2
@ -24,6 +25,8 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorConfig
log = logging.getLogger(__name__)
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@ -124,7 +127,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
self.cache = {}
async def initialize(self) -> None:
print(f"Initializing PGVector memory adapter with config: {self.config}")
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
try:
self.conn = psycopg2.connect(
host=self.config.host,
@ -138,7 +141,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
version = check_extension_version(self.cursor)
if version:
print(f"Vector extension version: {version}")
log.info(f"Vector extension version: {version}")
else:
raise RuntimeError("Vector extension is not installed.")
@ -151,9 +154,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
"""
)
except Exception as e:
import traceback
traceback.print_exc()
log.exception("Could not connect to PGVector database server")
raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None:
@ -201,10 +202,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
async def query_documents(
@ -213,8 +211,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = self.cache.get(bank_id, None)
if not index:
raise ValueError(f"Bank {bank_id} not found")
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank_id] = index
return index

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import traceback
import logging
import uuid
from typing import Any, Dict, List
@ -23,6 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
)
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
@ -90,7 +91,7 @@ class QdrantIndex(EmbeddingIndex):
try:
chunk = Chunk(**point.payload["chunk_content"])
except Exception:
traceback.print_exc()
log.exception("Failed to parse chunk")
continue
chunks.append(chunk)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List, Optional
@ -22,6 +23,8 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import WeaviateConfig, WeaviateRequestProviderData
log = logging.getLogger(__name__)
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection_name: str):
@ -69,10 +72,7 @@ class WeaviateIndex(EmbeddingIndex):
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
import traceback
traceback.print_exc()
print(f"Failed to parse document: {chunk_json}")
log.exception(f"Failed to parse document: {chunk_json}")
continue
chunks.append(chunk)

View file

@ -4,9 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from typing import Any, Dict
from pydantic import BaseModel, Field
class OpenTelemetryConfig(BaseModel):
jaeger_host: str = "localhost"
jaeger_port: int = 6831
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
}

View file

@ -4,24 +4,31 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
import threading
from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
ConsoleMetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
_GLOBAL_STORAGE = {
"active_spans": {},
"counters": {},
"gauges": {},
"up_down_counters": {},
}
_global_lock = threading.Lock()
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
@ -42,33 +49,37 @@ class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
self.config = config
self.resource = Resource.create(
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
)
# Set up tracing with Jaeger exporter
jaeger_exporter = JaegerExporter(
agent_host_name=self.config.jaeger_host,
agent_port=self.config.jaeger_port,
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
trace_provider = TracerProvider(resource=self.resource)
trace_processor = BatchSpanProcessor(jaeger_exporter)
trace_provider.add_span_processor(trace_processor)
trace.set_tracer_provider(trace_provider)
self.tracer = trace.get_tracer(__name__)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=self.resource, metric_readers=[metric_reader]
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
@ -81,121 +92,117 @@ class OpenTelemetryAdapter(Telemetry):
self._log_structured(event)
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
span = trace.get_current_span()
span.add_event(
name=event.message,
attributes={"severity": event.severity.value, **event.attributes},
timestamp=event.timestamp,
)
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.type,
attributes={
"message": event.message,
"severity": event.severity.value,
**event.attributes,
},
timestamp=timestamp_ns,
)
else:
print(
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
)
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
)
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
self.meter.create_counter(
name=event.metric,
unit=event.unit,
description=f"Counter for {event.metric}",
).add(event.value, attributes=event.attributes)
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
self.meter.create_gauge(
name=event.metric,
unit=event.unit,
description=f"Gauge for {event.metric}",
).set(event.value, attributes=event.attributes)
up_down_counter = self._get_or_create_up_down_counter(
event.metric, event.unit
)
up_down_counter.add(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(
self, name: str, unit: str
) -> metrics.UpDownCounter:
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = (
self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent) -> None:
if isinstance(event.payload, SpanStartPayload):
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.span_id),
is_remote=True,
)
)
)
span = self.tracer.start_span(
name=event.payload.name,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
)
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__)
if event.payload.parent_span_id:
span.set_parent(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True,
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]:
return
parent_span = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
# Create a new trace context with the trace_id
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
start_time=int(event.timestamp.timestamp() * 1e9),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
# Set as current span using context manager
with trace.use_span(span, end_on_exit=False):
pass # Let the span continue beyond this block
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
)
)
elif isinstance(event.payload, SpanEndPayload):
span = trace.get_current_span()
span.set_status(
trace.Status(
trace.StatusCode.OK
if event.payload.status == SpanStatus.OK
else trace.StatusCode.ERROR
)
)
span.end(end_time=event.timestamp)
span.set_status(status)
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
# Remove from active spans
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
async def get_trace(self, trace_id: str) -> Trace:
# we need to look up the root span id
raise NotImplementedError("not yet no")
# Usage example
async def main():
telemetry = OpenTelemetryTelemetry("my-service")
await telemetry.initialize()
# Log an unstructured event
await telemetry.log_event(
UnstructuredLogEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
message="This is a log message",
severity=LogSeverity.INFO,
)
)
# Log a metric event
await telemetry.log_event(
MetricEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
metric="my_metric",
value=42,
unit="count",
)
)
# Log a structured event (span start)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"),
)
)
# Log a structured event (span end)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanEndPayload(status=SpanStatus.OK),
)
)
await telemetry.shutdown()
if __name__ == "__main__":
import asyncio
asyncio.run(main())
raise NotImplementedError("Trace retrieval not implemented yet")