forked from phoenix-oss/llama-stack-mirror
This reverts commit b8535417e0.
Test plan:
LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/together/together-run.yaml
python -m examples.agents.e2e_loop_with_client_tools localhost 8321
This commit is contained in:
parent
df4fbae35c
commit
60e7f3d705
5 changed files with 14 additions and 161 deletions
|
@ -285,7 +285,7 @@ class CompletionRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CompletionResponse(MetricResponseMixin):
|
class CompletionResponse(BaseModel):
|
||||||
"""Response from a completion request.
|
"""Response from a completion request.
|
||||||
|
|
||||||
:param content: The generated completion text
|
:param content: The generated completion text
|
||||||
|
@ -299,7 +299,7 @@ class CompletionResponse(MetricResponseMixin):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
class CompletionResponseStreamChunk(BaseModel):
|
||||||
"""A chunk of a streamed completion response.
|
"""A chunk of a streamed completion response.
|
||||||
|
|
||||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||||
|
@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
|
||||||
"""A chunk of a streamed chat completion response.
|
"""A chunk of a streamed chat completion response.
|
||||||
|
|
||||||
:param event: The event containing the new content
|
:param event: The event containing the new content
|
||||||
|
@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ChatCompletionResponse(MetricResponseMixin):
|
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
|
||||||
"""Response from a chat completion request.
|
"""Response from a chat completion request.
|
||||||
|
|
||||||
:param completion_message: The complete response message
|
:param completion_message: The complete response message
|
||||||
|
|
|
@ -163,9 +163,7 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
|
||||||
module="llama_stack.distribution.routers",
|
module="llama_stack.distribution.routers",
|
||||||
routing_table_api=info.routing_table_api,
|
routing_table_api=info.routing_table_api,
|
||||||
api_dependencies=[info.routing_table_api],
|
api_dependencies=[info.routing_table_api],
|
||||||
# Add telemetry as an optional dependency to all auto-routed providers
|
deps__=[info.routing_table_api.value],
|
||||||
optional_api_dependencies=[Api.telemetry],
|
|
||||||
deps__=([info.routing_table_api.value, Api.telemetry.value]),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||||
from .routers import (
|
from .routers import (
|
||||||
DatasetIORouter,
|
DatasetIORouter,
|
||||||
EvalRouter,
|
EvalRouter,
|
||||||
|
@ -65,17 +65,9 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict
|
||||||
"eval": EvalRouter,
|
"eval": EvalRouter,
|
||||||
"tool_runtime": ToolRuntimeRouter,
|
"tool_runtime": ToolRuntimeRouter,
|
||||||
}
|
}
|
||||||
api_to_deps = {
|
|
||||||
"inference": {"telemetry": Api.telemetry},
|
|
||||||
}
|
|
||||||
if api.value not in api_to_routers:
|
if api.value not in api_to_routers:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
api_to_dep_impl = {}
|
impl = api_to_routers[api.value](routing_table)
|
||||||
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
|
|
||||||
if dep_api in deps:
|
|
||||||
api_to_dep_impl[dep_name] = deps[dep_api]
|
|
||||||
|
|
||||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -4,8 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import time
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from llama_stack import logcat
|
from llama_stack import logcat
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
@ -22,10 +21,6 @@ from llama_stack.apis.eval import (
|
||||||
JobStatus,
|
JobStatus,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponse,
|
|
||||||
ChatCompletionResponseEventType,
|
|
||||||
ChatCompletionResponseStreamChunk,
|
|
||||||
CompletionMessage,
|
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
EmbeddingTaskType,
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
|
@ -33,14 +28,13 @@ from llama_stack.apis.inference import (
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
StopReason,
|
|
||||||
TextTruncation,
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||||
from llama_stack.apis.scoring import (
|
from llama_stack.apis.scoring import (
|
||||||
ScoreBatchResponse,
|
ScoreBatchResponse,
|
||||||
|
@ -49,7 +43,6 @@ from llama_stack.apis.scoring import (
|
||||||
ScoringFnParams,
|
ScoringFnParams,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.apis.telemetry import MetricEvent, Telemetry
|
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
RAGDocument,
|
RAGDocument,
|
||||||
RAGQueryConfig,
|
RAGQueryConfig,
|
||||||
|
@ -59,10 +52,7 @@ from llama_stack.apis.tools import (
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
|
||||||
|
|
||||||
|
|
||||||
class VectorIORouter(VectorIO):
|
class VectorIORouter(VectorIO):
|
||||||
|
@ -131,14 +121,9 @@ class InferenceRouter(Inference):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
telemetry: Optional[Telemetry] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
logcat.debug("core", "Initializing InferenceRouter")
|
logcat.debug("core", "Initializing InferenceRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
self.telemetry = telemetry
|
|
||||||
if self.telemetry:
|
|
||||||
self.tokenizer = Tokenizer.get_instance()
|
|
||||||
self.formatter = ChatFormat(self.tokenizer)
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
logcat.debug("core", "InferenceRouter.initialize")
|
logcat.debug("core", "InferenceRouter.initialize")
|
||||||
|
@ -162,57 +147,6 @@ class InferenceRouter(Inference):
|
||||||
)
|
)
|
||||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||||
|
|
||||||
def _construct_metrics(
|
|
||||||
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
|
|
||||||
) -> List[MetricEvent]:
|
|
||||||
span = get_current_span()
|
|
||||||
metrics = [
|
|
||||||
("prompt_tokens", prompt_tokens),
|
|
||||||
("completion_tokens", completion_tokens),
|
|
||||||
("total_tokens", total_tokens),
|
|
||||||
]
|
|
||||||
metric_events = []
|
|
||||||
for metric_name, value in metrics:
|
|
||||||
metric_events.append(
|
|
||||||
MetricEvent(
|
|
||||||
trace_id=span.trace_id,
|
|
||||||
span_id=span.span_id,
|
|
||||||
metric=metric_name,
|
|
||||||
value=value,
|
|
||||||
timestamp=time.time(),
|
|
||||||
unit="tokens",
|
|
||||||
attributes={
|
|
||||||
"model_id": model.model_id,
|
|
||||||
"provider_id": model.provider_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return metric_events
|
|
||||||
|
|
||||||
async def _compute_and_log_token_usage(
|
|
||||||
self,
|
|
||||||
prompt_tokens: int,
|
|
||||||
completion_tokens: int,
|
|
||||||
total_tokens: int,
|
|
||||||
model: Model,
|
|
||||||
) -> List[MetricEvent]:
|
|
||||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
|
||||||
if self.telemetry:
|
|
||||||
for metric in metrics:
|
|
||||||
await self.telemetry.log_event(metric)
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
async def _count_tokens(
|
|
||||||
self,
|
|
||||||
messages: List[Message] | InterleavedContent,
|
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
|
||||||
) -> Optional[int]:
|
|
||||||
if isinstance(messages, list):
|
|
||||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
|
||||||
else:
|
|
||||||
encoded = self.formatter.encode_content(messages)
|
|
||||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -225,7 +159,7 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
) -> AsyncGenerator:
|
||||||
logcat.debug(
|
logcat.debug(
|
||||||
"core",
|
"core",
|
||||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||||
|
@ -276,47 +210,10 @@ class InferenceRouter(Inference):
|
||||||
tool_config=tool_config,
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
return (chunk async for chunk in await provider.chat_completion(**params))
|
||||||
async def stream_generator():
|
|
||||||
completion_text = ""
|
|
||||||
async for chunk in await provider.chat_completion(**params):
|
|
||||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
|
||||||
if chunk.event.delta.type == "text":
|
|
||||||
completion_text += chunk.event.delta.text
|
|
||||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
|
||||||
completion_tokens = await self._count_tokens(
|
|
||||||
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
|
|
||||||
tool_config.tool_prompt_format,
|
|
||||||
)
|
|
||||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
|
||||||
metrics = await self._compute_and_log_token_usage(
|
|
||||||
prompt_tokens or 0,
|
|
||||||
completion_tokens or 0,
|
|
||||||
total_tokens,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return stream_generator()
|
|
||||||
else:
|
else:
|
||||||
response = await provider.chat_completion(**params)
|
return await provider.chat_completion(**params)
|
||||||
completion_tokens = await self._count_tokens(
|
|
||||||
[response.completion_message],
|
|
||||||
tool_config.tool_prompt_format,
|
|
||||||
)
|
|
||||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
|
||||||
metrics = await self._compute_and_log_token_usage(
|
|
||||||
prompt_tokens or 0,
|
|
||||||
completion_tokens or 0,
|
|
||||||
total_tokens,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
|
@ -347,41 +244,10 @@ class InferenceRouter(Inference):
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_tokens = await self._count_tokens(content)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
return (chunk async for chunk in await provider.completion(**params))
|
||||||
async def stream_generator():
|
|
||||||
completion_text = ""
|
|
||||||
async for chunk in await provider.completion(**params):
|
|
||||||
if hasattr(chunk, "delta"):
|
|
||||||
completion_text += chunk.delta
|
|
||||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
|
|
||||||
completion_tokens = await self._count_tokens(completion_text)
|
|
||||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
|
||||||
metrics = await self._compute_and_log_token_usage(
|
|
||||||
prompt_tokens or 0,
|
|
||||||
completion_tokens or 0,
|
|
||||||
total_tokens,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return stream_generator()
|
|
||||||
else:
|
else:
|
||||||
response = await provider.completion(**params)
|
return await provider.completion(**params)
|
||||||
completion_tokens = await self._count_tokens(response.content)
|
|
||||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
|
||||||
metrics = await self._compute_and_log_token_usage(
|
|
||||||
prompt_tokens or 0,
|
|
||||||
completion_tokens or 0,
|
|
||||||
total_tokens,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -73,7 +73,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasetio_api = deps.get(Api.datasetio)
|
self.datasetio_api = deps.get(Api.datasetio)
|
||||||
self.meter = None
|
|
||||||
|
|
||||||
resource = Resource.create(
|
resource = Resource.create(
|
||||||
{
|
{
|
||||||
|
@ -172,8 +171,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||||
return _GLOBAL_STORAGE["gauges"][name]
|
return _GLOBAL_STORAGE["gauges"][name]
|
||||||
|
|
||||||
def _log_metric(self, event: MetricEvent) -> None:
|
def _log_metric(self, event: MetricEvent) -> None:
|
||||||
if self.meter is None:
|
|
||||||
return
|
|
||||||
if isinstance(event.value, int):
|
if isinstance(event.value, int):
|
||||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||||
counter.add(event.value, attributes=event.attributes)
|
counter.add(event.value, attributes=event.attributes)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue