Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 16:42:44 +00:00)

Commit 5b027d2de5 ("Merge conflicts")
198 changed files with 6140 additions and 3477 deletions

llama_stack/providers/remote/datasetio/__init__.py (new file, 5 lines)
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@ -3,12 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel

from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)
from pydantic import BaseModel


class HuggingfaceDatasetIOConfig(BaseModel):

@ -9,6 +9,7 @@ from llama_stack.apis.datasetio import *  # noqa: F403


import datasets as hf_datasets

from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl

@ -4,11 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type

from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig


@json_schema_type
class BedrockConfig(BedrockBaseConfig):
    pass

@ -4,11 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from ._config import NVIDIAConfig
from ._nvidia import NVIDIAInferenceAdapter
from llama_stack.apis.inference import Inference

from .config import NVIDIAConfig


async def get_adapter_impl(config: NVIDIAConfig, _deps) -> NVIDIAInferenceAdapter:
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
    # import dynamically so `llama stack build` does not fail due to missing dependencies
    from .nvidia import NVIDIAInferenceAdapter

    if not isinstance(config, NVIDIAConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")
    adapter = NVIDIAInferenceAdapter(config)

llama_stack/providers/remote/inference/nvidia/config.py (new file, 50 lines)
@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class NVIDIAConfig(BaseModel):
    """
    Configuration for the NVIDIA NIM inference endpoint.

    Attributes:
        url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
        api_key (str): The access key for the hosted NIM endpoints

    There are two ways to access NVIDIA NIMs -
     0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
     1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure

    By default the configuration is set to use the hosted APIs. This requires
    an API key which can be obtained from https://ngc.nvidia.com/.

    By default the configuration will attempt to read the NVIDIA_API_KEY environment
    variable to set the api_key. Please do not put your API key in code.

    If you are using a self-hosted NVIDIA NIM, you can set the url to the
    URL of your running NVIDIA NIM and do not need to set the api_key.
    """

    url: str = Field(
        default_factory=lambda: os.getenv(
            "NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
        ),
        description="A base url for accessing the NVIDIA NIM",
    )
    api_key: Optional[str] = Field(
        default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
        description="The NVIDIA API key, only needed if using the hosted service",
    )
    timeout: int = Field(
        default=60,
        description="Timeout for the HTTP requests",
    )

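Since both fields resolve their defaults from the environment via default_factory, a minimal sketch of instantiating the config follows (the localhost URL is an assumption for a self-hosted NIM; this snippet is illustrative and not part of the diff):

import os

from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig

# illustrative only: point at an assumed self-hosted NIM instead of the hosted preview APIs
os.environ["NVIDIA_BASE_URL"] = "http://localhost:8000"

config = NVIDIAConfig()   # url and api_key are resolved by the default_factory lambdas
print(config.url)         # -> http://localhost:8000
print(config.api_key)     # -> None unless NVIDIA_API_KEY is set in the environment
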
llama_stack/providers/remote/inference/nvidia/nvidia.py (new file, 183 lines)
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
|
||||
from llama_models.datatypes import SamplingParams
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
InterleavedTextMedia,
|
||||
Message,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.sku_list import CoreModelId
|
||||
from openai import APIConnectionError, AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
ResponseFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
)
|
||||
from .utils import _is_nvidia_hosted, check_health
|
||||
|
||||
_MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"meta/llama3-8b-instruct",
|
||||
CoreModelId.llama3_8b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama3-70b-instruct",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama-3.1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama-3.1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama-3.1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama-3.2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama-3.2-3b-instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama-3.2-11b-vision-instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"meta/llama-3.2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
# TODO(mf): how do we handle Nemotron models?
|
||||
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
|
||||
]
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
|
||||
|
||||
print(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(config):
|
||||
if not config.api_key:
|
||||
raise RuntimeError(
|
||||
"API key is required for hosted NVIDIA NIM. "
|
||||
"Either provide an API key or use a self-hosted NIM."
|
||||
)
|
||||
# elif self._config.api_key:
|
||||
#
|
||||
# we don't raise this warning because a user may have deployed their
|
||||
# self-hosted NIM with an API key requirement.
|
||||
#
|
||||
# warnings.warn(
|
||||
# "API key is not required for self-hosted NVIDIA NIM. "
|
||||
# "Consider removing the api_key from the configuration."
|
||||
# )
|
||||
|
||||
self._config = config
|
||||
# make sure the client lives longer than any async calls
|
||||
self._client = AsyncOpenAI(
|
||||
base_url=f"{self._config.url}/v1",
|
||||
api_key=self._config.api_key or "NO KEY",
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[
|
||||
ToolPromptFormat
|
||||
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
|
||||
]:
|
||||
if tool_prompt_format:
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
|
||||
|
||||
await check_health(self._config) # this raises errors
|
||||
|
||||
request = convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
|
||||
) from e
|
||||
|
||||
if stream:
|
||||
return convert_openai_chat_completion_stream(response)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
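For orientation, a minimal sketch of driving the new adapter end to end. This is not part of the diff; the registry identifier, the message construction, and the import paths are assumptions for illustration:

import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter


async def main() -> None:
    config = NVIDIAConfig()  # hosted endpoint by default; requires NVIDIA_API_KEY
    adapter = NVIDIAInferenceAdapter(config)

    # "Llama3.1-8B-Instruct" is a placeholder identifier; the adapter resolves it to the
    # provider model name (e.g. meta/llama-3.1-8b-instruct) through the aliases above
    response = await adapter.chat_completion(
        model_id="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Say hello in one word.")],
        stream=False,
    )
    print(response.completion_message.content)


asyncio.run(main())
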
llama_stack/providers/remote/inference/nvidia/openai_utils.py (new file, 581 lines)
|
|
@ -0,0 +1,581 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
BuiltinTool,
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
TokenLogProbs,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
)
|
||||
from openai import AsyncStream
|
||||
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
||||
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
||||
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
||||
ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
|
||||
ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
|
||||
)
|
||||
from openai.types.chat.chat_completion import (
|
||||
Choice as OpenAIChoice,
|
||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||
)
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
Function as OpenAIFunction,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
SystemMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
|
||||
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
||||
"""
|
||||
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
|
||||
|
||||
ToolDefinition:
|
||||
tool_name: str | BuiltinTool
|
||||
description: Optional[str]
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
|
||||
ToolParamDefinition:
|
||||
param_type: str
|
||||
description: Optional[str]
|
||||
required: Optional[bool]
|
||||
default: Optional[Any]
|
||||
|
||||
|
||||
OpenAI spec -
|
||||
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
param_name: {
|
||||
"type": param_type,
|
||||
"description": description,
|
||||
"default": default,
|
||||
},
|
||||
...
|
||||
},
|
||||
"required": [param_name, ...],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
out = {
|
||||
"type": "function",
|
||||
"function": {},
|
||||
}
|
||||
function = out["function"]
|
||||
|
||||
if isinstance(tool.tool_name, BuiltinTool):
|
||||
function.update(name=tool.tool_name.value) # TODO(mf): is this sufficient?
|
||||
else:
|
||||
function.update(name=tool.tool_name)
|
||||
|
||||
if tool.description:
|
||||
function.update(description=tool.description)
|
||||
|
||||
if tool.parameters:
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
properties = parameters["properties"]
|
||||
required = []
|
||||
for param_name, param in tool.parameters.items():
|
||||
properties[param_name] = {"type": param.param_type}
|
||||
if param.description:
|
||||
properties[param_name].update(description=param.description)
|
||||
if param.default:
|
||||
properties[param_name].update(default=param.default)
|
||||
if param.required:
|
||||
required.append(param_name)
|
||||
|
||||
if required:
|
||||
parameters.update(required=required)
|
||||
|
||||
function.update(parameters=parameters)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
|
||||
"""
|
||||
Convert a Message to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
# users can supply a dict instead of a Message object, we'll
|
||||
# convert it to a Message object and proceed with some type safety.
|
||||
if isinstance(message, dict):
|
||||
if "role" not in message:
|
||||
raise ValueError("role is required in message")
|
||||
if message["role"] == "user":
|
||||
message = UserMessage(**message)
|
||||
elif message["role"] == "assistant":
|
||||
message = CompletionMessage(**message)
|
||||
elif message["role"] == "ipython":
|
||||
message = ToolResponseMessage(**message)
|
||||
elif message["role"] == "system":
|
||||
message = SystemMessage(**message)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message role: {message['role']}")
|
||||
|
||||
out: OpenAIChatCompletionMessage = None
|
||||
if isinstance(message, UserMessage):
|
||||
out = OpenAIChatCompletionUserMessage(
|
||||
role="user",
|
||||
content=message.content, # TODO(mf): handle image content
|
||||
)
|
||||
elif isinstance(message, CompletionMessage):
|
||||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=message.content,
|
||||
tool_calls=[
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=tool.tool_name,
|
||||
arguments=json.dumps(tool.arguments),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
],
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=message.call_id,
|
||||
content=message.content,
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
out = OpenAIChatCompletionSystemMessage(
|
||||
role="system",
|
||||
content=message.content,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
n: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
# model -> model
|
||||
# messages -> messages
|
||||
# sampling_params TODO(mattf): review strategy
|
||||
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
|
||||
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
|
||||
# strategy=top_k -> nvext.top_k = top_k
|
||||
# temperature -> temperature
|
||||
# top_p -> top_p
|
||||
# top_k -> nvext.top_k
|
||||
# max_tokens -> max_tokens
|
||||
# repetition_penalty -> nvext.repetition_penalty
|
||||
# response_format -> GrammarResponseFormat TODO(mf)
|
||||
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
|
||||
# tools -> tools
|
||||
# tool_choice ("auto", "required") -> tool_choice
|
||||
# tool_prompt_format -> TBD
|
||||
# stream -> stream
|
||||
# logprobs -> logprobs
|
||||
|
||||
if request.response_format and not isinstance(
|
||||
request.response_format, JsonSchemaResponseFormat
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported response format: {request.response_format}. "
|
||||
"Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
nvext = {}
|
||||
payload: Dict[str, Any] = dict(
|
||||
model=request.model,
|
||||
messages=[_convert_message(message) for message in request.messages],
|
||||
stream=request.stream,
|
||||
n=n,
|
||||
extra_body=dict(nvext=nvext),
|
||||
extra_headers={
|
||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
||||
},
|
||||
)
|
||||
|
||||
if request.response_format:
|
||||
# server bug - setting guided_json changes the behavior of response_format resulting in an error
|
||||
# payload.update(response_format="json_object")
|
||||
nvext.update(guided_json=request.response_format.json_schema)
|
||||
|
||||
if request.tools:
|
||||
payload.update(
|
||||
tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
||||
)
|
||||
if request.tool_choice:
|
||||
payload.update(
|
||||
tool_choice=request.tool_choice.value
|
||||
) # we cannot include tool_choice w/o tools, server will complain
|
||||
|
||||
if request.logprobs:
|
||||
payload.update(logprobs=True)
|
||||
payload.update(top_logprobs=request.logprobs.top_k)
|
||||
|
||||
if request.sampling_params:
|
||||
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
|
||||
|
||||
if request.sampling_params.max_tokens:
|
||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
||||
|
||||
if request.sampling_params.strategy == "top_p":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(top_p=request.sampling_params.top_p)
|
||||
elif request.sampling_params.strategy == "top_k":
|
||||
if (
|
||||
request.sampling_params.top_k != -1
|
||||
and request.sampling_params.top_k < 1
|
||||
):
|
||||
warnings.warn("top_k must be -1 or >= 1")
|
||||
nvext.update(top_k=request.sampling_params.top_k)
|
||||
elif request.sampling_params.strategy == "greedy":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(temperature=request.sampling_params.temperature)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
||||
"""
|
||||
Convert an OpenAI chat completion finish_reason to a StopReason.
|
||||
|
||||
finish_reason: Literal["stop", "length", "tool_calls", ...]
|
||||
- stop: model hit a natural stop point or a provided stop sequence
|
||||
- length: maximum number of tokens specified in the request was reached
|
||||
- tool_calls: model called a tool
|
||||
|
||||
->
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
"""
|
||||
|
||||
# TODO(mf): are end_of_turn and end_of_message semantics correct?
|
||||
return {
|
||||
"stop": StopReason.end_of_turn,
|
||||
"length": StopReason.out_of_tokens,
|
||||
"tool_calls": StopReason.end_of_message,
|
||||
}.get(finish_reason, StopReason.end_of_turn)
|
||||
|
||||
|
||||
def _convert_openai_tool_calls(
|
||||
tool_calls: List[OpenAIChatCompletionMessageToolCall],
|
||||
) -> List[ToolCall]:
|
||||
"""
|
||||
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.
|
||||
|
||||
OpenAI ChatCompletionMessageToolCall:
|
||||
id: str
|
||||
function: Function
|
||||
type: Literal["function"]
|
||||
|
||||
OpenAI Function:
|
||||
arguments: str
|
||||
name: str
|
||||
|
||||
->
|
||||
|
||||
ToolCall:
|
||||
call_id: str
|
||||
tool_name: str
|
||||
arguments: Dict[str, ...]
|
||||
"""
|
||||
if not tool_calls:
|
||||
return [] # CompletionMessage tool_calls is not optional
|
||||
|
||||
return [
|
||||
ToolCall(
|
||||
call_id=call.id,
|
||||
tool_name=call.function.name,
|
||||
arguments=json.loads(call.function.arguments),
|
||||
)
|
||||
for call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def _convert_openai_logprobs(
|
||||
logprobs: OpenAIChoiceLogprobs,
|
||||
) -> Optional[List[TokenLogProbs]]:
|
||||
"""
|
||||
Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.
|
||||
|
||||
OpenAI ChoiceLogprobs:
|
||||
content: Optional[List[ChatCompletionTokenLogprob]]
|
||||
|
||||
OpenAI ChatCompletionTokenLogprob:
|
||||
token: str
|
||||
logprob: float
|
||||
top_logprobs: List[TopLogprob]
|
||||
|
||||
OpenAI TopLogprob:
|
||||
token: str
|
||||
logprob: float
|
||||
|
||||
->
|
||||
|
||||
TokenLogProbs:
|
||||
logprobs_by_token: Dict[str, float]
|
||||
- token, logprob
|
||||
|
||||
"""
|
||||
if not logprobs:
|
||||
return None
|
||||
|
||||
return [
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
|
||||
}
|
||||
)
|
||||
for content in logprobs.content
|
||||
]
|
||||
|
||||
|
||||
def convert_openai_chat_completion_choice(
|
||||
choice: OpenAIChoice,
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Convert an OpenAI Choice into a ChatCompletionResponse.
|
||||
|
||||
OpenAI Choice:
|
||||
message: ChatCompletionMessage
|
||||
finish_reason: str
|
||||
logprobs: Optional[ChoiceLogprobs]
|
||||
|
||||
OpenAI ChatCompletionMessage:
|
||||
role: Literal["assistant"]
|
||||
content: Optional[str]
|
||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
|
||||
|
||||
->
|
||||
|
||||
ChatCompletionResponse:
|
||||
completion_message: CompletionMessage
|
||||
logprobs: Optional[List[TokenLogProbs]]
|
||||
|
||||
CompletionMessage:
|
||||
role: Literal["assistant"]
|
||||
content: str | ImageMedia | List[str | ImageMedia]
|
||||
stop_reason: StopReason
|
||||
tool_calls: List[ToolCall]
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
"""
|
||||
assert (
|
||||
hasattr(choice, "message") and choice.message
|
||||
), "error in server response: message not found"
|
||||
assert (
|
||||
hasattr(choice, "finish_reason") and choice.finish_reason
|
||||
), "error in server response: finish_reason not found"
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=choice.message.content
|
||||
or "", # CompletionMessage content is not optional
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
|
||||
|
||||
async def convert_openai_chat_completion_stream(
|
||||
stream: AsyncStream[OpenAIChatCompletionChunk],
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
"""
|
||||
Convert a stream of OpenAI chat completion chunks into a stream
|
||||
of ChatCompletionResponseStreamChunk.
|
||||
|
||||
OpenAI ChatCompletionChunk:
|
||||
choices: List[Choice]
|
||||
|
||||
OpenAI Choice: # different from the non-streamed Choice
|
||||
delta: ChoiceDelta
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
|
||||
logprobs: Optional[ChoiceLogprobs]
|
||||
|
||||
OpenAI ChoiceDelta:
|
||||
content: Optional[str]
|
||||
role: Optional[Literal["system", "user", "assistant", "tool"]]
|
||||
tool_calls: Optional[List[ChoiceDeltaToolCall]]
|
||||
|
||||
OpenAI ChoiceDeltaToolCall:
|
||||
index: int
|
||||
id: Optional[str]
|
||||
function: Optional[ChoiceDeltaToolCallFunction]
|
||||
type: Optional[Literal["function"]]
|
||||
|
||||
OpenAI ChoiceDeltaToolCallFunction:
|
||||
name: Optional[str]
|
||||
arguments: Optional[str]
|
||||
|
||||
->
|
||||
|
||||
ChatCompletionResponseStreamChunk:
|
||||
event: ChatCompletionResponseEvent
|
||||
|
||||
ChatCompletionResponseEvent:
|
||||
event_type: ChatCompletionResponseEventType
|
||||
delta: Union[str, ToolCallDelta]
|
||||
logprobs: Optional[List[TokenLogProbs]]
|
||||
stop_reason: Optional[StopReason]
|
||||
|
||||
ChatCompletionResponseEventType:
|
||||
start = "start"
|
||||
progress = "progress"
|
||||
complete = "complete"
|
||||
|
||||
ToolCallDelta:
|
||||
content: Union[str, ToolCall]
|
||||
parse_status: ToolCallParseStatus
|
||||
|
||||
ToolCall:
|
||||
call_id: str
|
||||
tool_name: str
|
||||
arguments: str
|
||||
|
||||
ToolCallParseStatus:
|
||||
started = "started"
|
||||
in_progress = "in_progress"
|
||||
failure = "failure"
|
||||
success = "success"
|
||||
|
||||
TokenLogProbs:
|
||||
logprobs_by_token: Dict[str, float]
|
||||
- token, logprob
|
||||
|
||||
StopReason:
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
"""
|
||||
|
||||
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
|
||||
def _event_type_generator() -> (
|
||||
Generator[ChatCompletionResponseEventType, None, None]
|
||||
):
|
||||
yield ChatCompletionResponseEventType.start
|
||||
while True:
|
||||
yield ChatCompletionResponseEventType.progress
|
||||
|
||||
event_type = _event_type_generator()
|
||||
|
||||
# we implement NIM specific semantics, the main difference from OpenAI
|
||||
# is that tool_calls are always produced as a complete call. there is no
|
||||
# intermediate / partial tool call streamed. because of this, we can
|
||||
# simplify the logic and not concern ourselves with parse_status of
|
||||
# started/in_progress/failed. we can always assume success.
|
||||
#
|
||||
# a stream of ChatCompletionResponseStreamChunk consists of
|
||||
# 0. a start event
|
||||
# 1. zero or more progress events
|
||||
# - each progress event has a delta
|
||||
# - each progress event may have a stop_reason
|
||||
# - each progress event may have logprobs
|
||||
# - each progress event may have tool_calls
|
||||
# if a progress event has tool_calls,
|
||||
# it is fully formed and
|
||||
# can be emitted with a parse_status of success
|
||||
# 2. a complete event
|
||||
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0] # assuming only one choice per chunk
|
||||
|
||||
# we assume there's only one finish_reason in the stream
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
|
||||
|
||||
# if there's a tool call, emit an event for each tool in the list
|
||||
# if tool call and content, emit both separately
|
||||
|
||||
if choice.delta.tool_calls:
|
||||
# the call may have content and a tool call. ChatCompletionResponseEvent
|
||||
# does not support both, so we emit the content first
|
||||
if choice.delta.content:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=choice.delta.content,
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
|
||||
# it is possible to have parallel tool calls in stream, but
|
||||
# ChatCompletionResponseEvent only supports one per stream
|
||||
if len(choice.delta.tool_calls) > 1:
|
||||
warnings.warn(
|
||||
"multiple tool calls found in a single delta, using the first, ignoring the rest"
|
||||
)
|
||||
|
||||
# NIM only produces fully formed tool calls, so we can assume success
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=ToolCallDelta(
|
||||
content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=next(event_type),
|
||||
delta=choice.delta.content or "", # content is not optional
|
||||
logprobs=_convert_openai_logprobs(choice.logprobs),
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
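The tool-definition mapping documented in _convert_tooldef_to_openai_tool can be exercised on its own. A hedged sketch follows; ToolParamDefinition's import path is an assumption, and the expected output shape is taken from the docstring above:

# illustrative only; ToolParamDefinition import path is assumed
from llama_models.llama3.api.datatypes import ToolDefinition, ToolParamDefinition

from llama_stack.providers.remote.inference.nvidia.openai_utils import (
    _convert_tooldef_to_openai_tool,
)

tool = ToolDefinition(
    tool_name="get_weather",
    description="Look up the current weather",
    parameters={
        "city": ToolParamDefinition(
            param_type="string",
            description="City name",
            required=True,
        ),
    },
)

# expected shape, per the docstring:
# {"type": "function",
#  "function": {"name": "get_weather", "description": "...",
#               "parameters": {"type": "object",
#                              "properties": {"city": {"type": "string", "description": "City name"}},
#                              "required": ["city"]}}}
print(_convert_tooldef_to_openai_tool(tool))
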
llama_stack/providers/remote/inference/nvidia/utils.py (new file, 54 lines)
@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Tuple

import httpx

from . import NVIDIAConfig


def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
    return "integrate.api.nvidia.com" in config.url


async def _get_health(url: str) -> Tuple[bool, bool]:
    """
    Query {url}/v1/health/{live,ready} to check if the server is running and ready

    Args:
        url (str): URL of the server

    Returns:
        Tuple[bool, bool]: (is_live, is_ready)
    """
    async with httpx.AsyncClient() as client:
        live = await client.get(f"{url}/v1/health/live")
        ready = await client.get(f"{url}/v1/health/ready")
        return live.status_code == 200, ready.status_code == 200


async def check_health(config: NVIDIAConfig) -> None:
    """
    Check if the server is running and ready

    Args:
        url (str): URL of the server

    Raises:
        RuntimeError: If the server is not running or ready
    """
    if not _is_nvidia_hosted(config):
        print("Checking NVIDIA NIM health...")
        try:
            is_live, is_ready = await _get_health(config.url)
            if not is_live:
                raise ConnectionError("NVIDIA NIM is not running")
            if not is_ready:
                raise ConnectionError("NVIDIA NIM is not ready")
            # TODO(mf): should we wait for the server to be ready?
        except httpx.ConnectError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e

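A short sketch of how the health check composes with the config; the self-hosted URL below is an assumption (hosted endpoints skip the check entirely), and the snippet is illustrative rather than part of the diff:

import asyncio

from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.utils import check_health

# illustrative only: an assumed self-hosted NIM address
config = NVIDIAConfig(url="http://localhost:8000", api_key=None)

# raises ConnectionError if the NIM is unreachable, not live, or not ready
asyncio.run(check_health(config))
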
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
|
@ -39,6 +40,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
request_has_media,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
model_aliases = [
|
||||
build_model_alias(
|
||||
|
|
@ -57,18 +59,26 @@ model_aliases = [
|
|||
"llama3.1:70b",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"llama3.1:405b-instruct-fp16",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_model_alias_with_just_provider_model_id(
|
||||
"llama3.1:405b",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"llama3.2:1b-instruct-fp16",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_model_alias_with_just_provider_model_id(
|
||||
"llama3.2:1b",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_model_alias_with_just_provider_model_id(
|
||||
"llama3.2:1b",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_model_alias_with_just_provider_model_id(
|
||||
"llama3.2:3b",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
|
|
@ -81,6 +91,14 @@ model_aliases = [
|
|||
"llama3.2-vision",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"llama3.2-vision:90b-instruct-fp16",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_model_alias_with_just_provider_model_id(
|
||||
"llama3.2-vision:90b",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
# The Llama Guard models don't have their full fp16 versions
|
||||
# so we are going to alias their default version to the canonical SKU
|
||||
build_model_alias(
|
||||
|
|
@ -105,7 +123,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
return AsyncClient(host=self.url)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
log.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
try:
|
||||
await self.client.ps()
|
||||
except httpx.ConnectError as e:
|
||||
|
|
|
|||
|
|
@ -37,6 +37,18 @@ class InferenceEndpointImplConfig(BaseModel):
|
|||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"endpoint_name": endpoint_name,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
|
|
@ -47,3 +59,15 @@ class InferenceAPIImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
repo: str = "${env.INFERENCE_MODEL}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"huggingface_repo": repo,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
|
|
@ -34,7 +38,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_model_aliases():
|
||||
return [
|
||||
build_model_alias(
|
||||
model.huggingface_repo,
|
||||
model.descriptor(),
|
||||
)
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
]
|
||||
|
||||
|
||||
class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||
|
|
@ -44,45 +59,39 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
def __init__(self) -> None:
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.register_helper = ModelRegistryHelper(build_model_aliases())
|
||||
self.huggingface_repo_to_llama_model_id = {
|
||||
model.huggingface_repo: model.descriptor()
|
||||
for model in all_registered_models()
|
||||
if model.huggingface_repo
|
||||
}
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[Model]:
|
||||
repo = self.model_id
|
||||
identifier = self.huggingface_repo_to_llama_model_id[repo]
|
||||
return [
|
||||
Model(
|
||||
identifier=identifier,
|
||||
llama_model=identifier,
|
||||
metadata={
|
||||
"huggingface_repo": repo,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
model = await self.register_helper.register_model(model)
|
||||
if model.provider_resource_id != self.model_id:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
|
||||
)
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
|
@ -176,7 +185,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
|
|
@ -186,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
|
|
@ -241,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
prompt, input_tokens = chat_completion_request_to_model_input_info(
|
||||
request, self.formatter
|
||||
request, self.register_helper.get_llama_model(request.model), self.formatter
|
||||
)
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
|
|
@ -256,7 +266,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -264,7 +274,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
print(f"Initializing TGI client with url={config.url}")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
|
@ -34,6 +36,9 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_model_aliases():
|
||||
return [
|
||||
build_model_alias(
|
||||
|
|
@ -53,7 +58,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
|
@ -21,6 +22,8 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChromaIndex(EmbeddingIndex):
|
||||
def __init__(self, client: chromadb.AsyncHttpClient, collection):
|
||||
|
|
@ -56,10 +59,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
doc = json.loads(doc)
|
||||
chunk = Chunk(**doc)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {doc}")
|
||||
log.exception(f"Failed to parse document: {doc}")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
|
|
@ -73,7 +73,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
print(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
|
|
@ -88,12 +88,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
print(f"Connecting to Chroma server at: {self.host}:{self.port}")
|
||||
log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
|
||||
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
log.exception("Could not connect to Chroma server")
|
||||
raise RuntimeError("Could not connect to Chroma server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
@ -109,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
metadata={"bank": memory_bank.json()},
|
||||
metadata={"bank": memory_bank.model_dump_json()},
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
|
|
@ -123,10 +121,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
data = json.loads(collection.metadata["bank"])
|
||||
bank = parse_obj_as(VectorMemoryBank, data)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse bank: {collection.metadata}")
|
||||
log.exception(f"Failed to parse bank: {collection.metadata}")
|
||||
continue
|
||||
|
||||
index = BankWithIndex(
|
||||
|
|
@ -147,9 +142,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
|
||||
await index.insert_documents(documents)
|
||||
|
||||
|
|
@ -159,8 +152,20 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
|
||||
return await index.query_documents(query, params)
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
|
||||
collection = await self.client.get_collection(bank_id)
|
||||
if not collection:
|
||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import psycopg2
|
||||
|
|
@ -24,6 +25,8 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import PGVectorConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_extension_version(cur):
|
||||
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
||||
|
|
@ -124,7 +127,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
try:
|
||||
self.conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
|
|
@ -138,7 +141,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
version = check_extension_version(self.cursor)
|
||||
if version:
|
||||
print(f"Vector extension version: {version}")
|
||||
log.info(f"Vector extension version: {version}")
|
||||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
|
|
@ -151,9 +154,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
log.exception("Could not connect to PGVector database server")
|
||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
@ -201,10 +202,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
await index.insert_documents(documents)
|
||||
|
||||
async def query_documents(
|
||||
|
|
@ -213,8 +211,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id, None)
|
||||
if not index:
|
||||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
return await index.query_documents(query, params)
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import traceback
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
|
@ -23,6 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
EmbeddingIndex,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
|
||||
|
|
@ -90,7 +91,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk(**point.payload["chunk_content"])
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
log.exception("Failed to parse chunk")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -22,6 +23,8 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import WeaviateConfig, WeaviateRequestProviderData
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(self, client: weaviate.Client, collection_name: str):
|
||||
|
|
@ -69,10 +72,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"Failed to parse document: {chunk_json}")
|
||||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel
from typing import Any, Dict

from pydantic import BaseModel, Field


class OpenTelemetryConfig(BaseModel):
    jaeger_host: str = "localhost"
    jaeger_port: int = 6831
    otel_endpoint: str = Field(
        default="http://localhost:4318/v1/traces",
        description="The OpenTelemetry collector endpoint URL",
    )
    service_name: str = Field(
        default="llama-stack",
        description="The service name to use for telemetry",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
            "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
        }
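A small sketch of the new config surface; the import path and the collector address are assumptions, and the ${env.…} placeholders returned by sample_run_config are resolved by the stack's run-config machinery, not by this snippet:

# import path assumed for illustration
from llama_stack.providers.remote.telemetry.opentelemetry.config import OpenTelemetryConfig

config = OpenTelemetryConfig(
    otel_endpoint="http://otel-collector:4318/v1/traces",  # assumed collector address
    service_name="my-llama-stack",
)
print(config.otel_endpoint)
print(OpenTelemetryConfig.sample_run_config())
# -> {'otel_endpoint': '${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}',
#     'service_name': '${env.OTEL_SERVICE_NAME:llama-stack}'}
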
|
|
|
|||
|
|
@ -4,24 +4,31 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
import threading
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import (
|
||||
ConsoleMetricExporter,
|
||||
PeriodicExportingMetricReader,
|
||||
)
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from .config import OpenTelemetryConfig
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
def string_to_trace_id(s: str) -> int:
|
||||
# Convert the string to bytes and then to an integer
|
||||
|
|
@ -42,33 +49,37 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
def __init__(self, config: OpenTelemetryConfig):
|
||||
self.config = config
|
||||
|
||||
self.resource = Resource.create(
|
||||
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
|
||||
resource = Resource.create(
|
||||
{
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
}
|
||||
)
|
||||
|
||||
# Set up tracing with Jaeger exporter
|
||||
jaeger_exporter = JaegerExporter(
|
||||
agent_host_name=self.config.jaeger_host,
|
||||
agent_port=self.config.jaeger_port,
|
||||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
trace_provider = TracerProvider(resource=self.resource)
|
||||
trace_processor = BatchSpanProcessor(jaeger_exporter)
|
||||
trace_provider.add_span_processor(trace_processor)
|
||||
trace.set_tracer_provider(trace_provider)
|
||||
self.tracer = trace.get_tracer(__name__)
|
||||
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
# Set up metrics
|
||||
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(
|
||||
resource=self.resource, metric_readers=[metric_reader]
|
||||
resource=resource, metric_readers=[metric_reader]
|
||||
)
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
trace.get_tracer_provider().shutdown()
|
||||
metrics.get_meter_provider().shutdown()
|
||||
|
||||
|
|
@ -81,121 +92,117 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
self._log_structured(event)
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
|
||||
span = trace.get_current_span()
|
||||
span.add_event(
|
||||
name=event.message,
|
||||
attributes={"severity": event.severity.value, **event.attributes},
|
||||
timestamp=event.timestamp,
|
||||
)
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
**event.attributes,
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
|
||||
)
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["counters"][name]
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if isinstance(event.value, int):
|
||||
self.meter.create_counter(
|
||||
name=event.metric,
|
||||
unit=event.unit,
|
||||
description=f"Counter for {event.metric}",
|
||||
).add(event.value, attributes=event.attributes)
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
self.meter.create_gauge(
|
||||
name=event.metric,
|
||||
unit=event.unit,
|
||||
description=f"Gauge for {event.metric}",
|
||||
).set(event.value, attributes=event.attributes)
|
||||
up_down_counter = self._get_or_create_up_down_counter(
|
||||
event.metric, event.unit
|
||||
)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(
|
||||
self, name: str, unit: str
|
||||
) -> metrics.UpDownCounter:
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = (
|
||||
self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent) -> None:
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
context = trace.set_span_in_context(
|
||||
trace.NonRecordingSpan(
|
||||
trace.SpanContext(
|
||||
trace_id=string_to_trace_id(event.trace_id),
|
||||
span_id=string_to_span_id(event.span_id),
|
||||
is_remote=True,
|
||||
)
|
||||
)
|
||||
)
|
||||
span = self.tracer.start_span(
|
||||
name=event.payload.name,
|
||||
kind=trace.SpanKind.INTERNAL,
|
||||
context=context,
|
||||
attributes=event.attributes,
|
||||
)
|
||||
with self._lock:
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
trace_id = string_to_trace_id(event.trace_id)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
if event.payload.parent_span_id:
|
||||
span.set_parent(
|
||||
trace.SpanContext(
|
||||
trace_id=string_to_trace_id(event.trace_id),
|
||||
span_id=string_to_span_id(event.payload.parent_span_id),
|
||||
is_remote=True,
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
parent_span = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
|
||||
# Create a new trace context with the trace_id
|
||||
context = trace.Context(trace_id=trace_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span, context)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
start_time=int(event.timestamp.timestamp() * 1e9),
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
# Set as current span using context manager
|
||||
with trace.use_span(span, end_on_exit=False):
|
||||
pass # Let the span continue beyond this block
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
if span:
|
||||
if event.attributes:
|
||||
span.set_attributes(event.attributes)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
)
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = trace.get_current_span()
|
||||
span.set_status(
|
||||
trace.Status(
|
||||
trace.StatusCode.OK
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.StatusCode.ERROR
|
||||
)
|
||||
)
|
||||
span.end(end_time=event.timestamp)
|
||||
span.set_status(status)
|
||||
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
|
||||
|
||||
# Remove from active spans
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
# we need to look up the root span id
|
||||
raise NotImplementedError("not yet no")
|
||||
|
||||
|
||||
# Usage example
|
||||
async def main():
|
||||
telemetry = OpenTelemetryTelemetry("my-service")
|
||||
await telemetry.initialize()
|
||||
|
||||
# Log an unstructured event
|
||||
await telemetry.log_event(
|
||||
UnstructuredLogEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span456",
|
||||
timestamp=datetime.now(),
|
||||
message="This is a log message",
|
||||
severity=LogSeverity.INFO,
|
||||
)
|
||||
)
|
||||
|
||||
# Log a metric event
|
||||
await telemetry.log_event(
|
||||
MetricEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span456",
|
||||
timestamp=datetime.now(),
|
||||
metric="my_metric",
|
||||
value=42,
|
||||
unit="count",
|
||||
)
|
||||
)
|
||||
|
||||
# Log a structured event (span start)
|
||||
await telemetry.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span789",
|
||||
timestamp=datetime.now(),
|
||||
payload=SpanStartPayload(name="my_operation"),
|
||||
)
|
||||
)
|
||||
|
||||
# Log a structured event (span end)
|
||||
await telemetry.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id="trace123",
|
||||
span_id="span789",
|
||||
timestamp=datetime.now(),
|
||||
payload=SpanEndPayload(status=SpanStatus.OK),
|
||||
)
|
||||
)
|
||||
|
||||
await telemetry.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
raise NotImplementedError("Trace retrieval not implemented yet")
|
||||