Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-05 12:21:52 +00:00)
Commit 32b87bf88a: Merge branch 'main' into remove-batch-inference
748 changed files with 127607 additions and 50032 deletions
@@ -61,7 +61,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import get_current_span
+from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span

 logger = get_logger(name=__name__, category="core::routers")
@@ -88,6 +88,11 @@ class InferenceRouter(Inference):

     async def shutdown(self) -> None:
         logger.debug("InferenceRouter.shutdown")
+        if self.store:
+            try:
+                await self.store.shutdown()
+            except Exception as e:
+                logger.warning(f"Error during InferenceStore shutdown: {e}")

     async def register_model(
         self,
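
The new shutdown body drains the inference store before the router goes away, but deliberately swallows failures so one misbehaving store cannot block process teardown. A minimal runnable sketch of that pattern; `DummyStore` is a hypothetical stand-in for `InferenceStore`, not the real implementation:

```python
import asyncio
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("core::routers")

class DummyStore:
    # Hypothetical stand-in for InferenceStore.
    async def shutdown(self) -> None:
        raise RuntimeError("connection already closed")

async def shutdown(store: DummyStore | None) -> None:
    logger.debug("InferenceRouter.shutdown")
    if store:
        try:
            await store.shutdown()
        except Exception as e:
            # Swallowing the error keeps teardown of other components moving.
            logger.warning(f"Error during InferenceStore shutdown: {e}")

asyncio.run(shutdown(DummyStore()))
```
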
@@ -158,7 +163,7 @@
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
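
Every call site that previously awaited `telemetry.log_event(metric)` now hands the event to `enqueue_event`, so the request path no longer blocks on the telemetry sink. The diff does not show `enqueue_event` itself; one plausible reading is a queue-plus-worker arrangement like the sketch below, where the queue, the worker, and `PrintSink` are hypothetical stand-ins rather than the actual llama_stack tracing internals:

```python
import asyncio

_event_queue: asyncio.Queue = asyncio.Queue()

def enqueue_event(event) -> None:
    # Non-blocking: the caller returns immediately instead of awaiting I/O.
    _event_queue.put_nowait(event)

async def _telemetry_worker(sink) -> None:
    # A background task drains the queue and performs the awaited writes.
    while True:
        event = await _event_queue.get()
        await sink.log_event(event)
        _event_queue.task_done()

class PrintSink:
    async def log_event(self, event) -> None:
        print("logged:", event)

async def main() -> None:
    worker = asyncio.create_task(_telemetry_worker(PrintSink()))
    enqueue_event({"metric": "total_tokens", "value": 42})
    await _event_queue.join()  # blocks only here, not on the request path
    worker.cancel()

asyncio.run(main())
```
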
@@ -391,7 +396,7 @@
                 model=model_obj,
             )
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)

             # these metrics will show up in the client response.
             response.metrics = (
@@ -487,7 +492,7 @@

         # Store the response with the ID that will be returned to the client
         if self.store:
-            await self.store.store_chat_completion(response, messages)
+            asyncio.create_task(self.store.store_chat_completion(response, messages))

         if self.telemetry:
            metrics = self._construct_metrics(
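
Persisting the chat completion also moves off the response path: `asyncio.create_task` schedules the write and returns immediately. Two caveats apply to this pattern in general (my observations, not shown in the diff): hold a reference so the task is not garbage-collected before it finishes, and attach a done-callback so a failed write is logged rather than silently dropped. A hedged sketch, with `store_in_background` as a hypothetical helper:

```python
import asyncio
import logging

logger = logging.getLogger("core::routers")
_background_tasks: set[asyncio.Task] = set()

def _log_failure(task: asyncio.Task) -> None:
    # Surface exceptions from fire-and-forget tasks instead of losing them.
    if not task.cancelled() and task.exception() is not None:
        logger.warning(f"background store failed: {task.exception()}")

def store_in_background(coro) -> None:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)  # keep a strong reference until done
    task.add_done_callback(_background_tasks.discard)
    task.add_done_callback(_log_failure)
```
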
@@ -497,7 +502,7 @@
                 model=model_obj,
             )
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)
             # these metrics will show up in the client response.
             response.metrics = (
                 metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
@@ -624,7 +629,7 @@
                         "completion_tokens",
                         "total_tokens",
                     ]:  # Only log completion and total tokens
-                        await self.telemetry.log_event(metric)
+                        enqueue_event(metric)

             # Return metrics in response
             async_metrics = [
@@ -670,7 +675,7 @@
             )
             for metric in completion_metrics:
                 if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
-                    await self.telemetry.log_event(metric)
+                    enqueue_event(metric)

         # Return metrics in response
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
@@ -715,7 +720,7 @@
                     choices_data[idx] = {
                         "content_parts": [],
                         "tool_calls_builder": {},
-                        "finish_reason": None,
+                        "finish_reason": "stop",
                         "logprobs_content_parts": [],
                     }
                 current_choice_data = choices_data[idx]
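
Seeding `finish_reason` with `"stop"` rather than `None` means every accumulated streaming choice carries a valid OpenAI-style finish reason even if the provider never sends a terminal chunk; presumably any explicit finish reason observed later in the stream overwrites this default.
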
@@ -766,7 +771,7 @@
                         model=model,
                     )
                     for metric in metrics:
-                        await self.telemetry.log_event(metric)
+                        enqueue_event(metric)

                 yield chunk
         finally:
@@ -815,4 +820,4 @@
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
-        await self.store.store_chat_completion(final_response, messages)
+        asyncio.create_task(self.store.store_chat_completion(final_response, messages))