# What does this PR do?

Adds an optional inference store to `InferenceRouter`: when a store is configured, OpenAI chat completions are persisted on both the streaming and non-streaming paths, and two new endpoints, `list_chat_completions` and `get_chat_completion`, expose the stored records.
## Test Plan
Commit a1b356a2a2 by Eric Huang on 2025-05-21 21:59:34 -07:00, parent 85b5f3172b.
71 changed files with 1111 additions and 10 deletions.

```diff
@@ -32,8 +32,11 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
+    ListOpenAIChatCompletionResponse,
     LogProbConfig,
     Message,
+    OpenAICompletionWithInputMessages,
+    Order,
     ResponseFormat,
     SamplingParams,
     StopReason,
```
```diff
@@ -73,6 +76,8 @@ from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
+from llama_stack.providers.utils.inference.inference_store import InferenceStore
+from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

 logger = get_logger(name=__name__, category="core")
```
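
Of the two new imports, `stream_and_store_openai_completion` is the interesting one: the name implies a wrap-the-stream pattern, re-yielding chunks to the caller while buffering them, then persisting the assembled completion once the stream ends. A minimal sketch of that pattern, with simplified stand-in types rather than the actual llama_stack definitions:

```python
# Sketch of the wrap-and-store streaming pattern; the chunk and store types
# here are simplified stand-ins, not the actual llama_stack definitions.
from collections.abc import AsyncIterator
from typing import Any


async def stream_and_store(
    stream: AsyncIterator[Any],
    model: str,
    store: Any,  # anything exposing an async store_chat_completion(...)
    messages: list[Any],
) -> AsyncIterator[Any]:
    chunks: list[Any] = []
    async for chunk in stream:
        chunks.append(chunk)
        yield chunk  # the caller still sees every chunk, unmodified
    # Stream exhausted: persist the assembled completion together with the
    # input messages so it can be listed and fetched later.
    await store.store_chat_completion({"model": model, "chunks": chunks}, messages)
```

Note that the router below returns the helper's result without awaiting it, which is consistent with the helper being an async generator like this rather than a coroutine.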
```diff
@@ -141,10 +146,12 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
         telemetry: Telemetry | None = None,
+        store: InferenceStore | None = None,
     ) -> None:
         logger.debug("Initializing InferenceRouter")
         self.routing_table = routing_table
         self.telemetry = telemetry
+        self.store = store
         if self.telemetry:
             self.tokenizer = Tokenizer.get_instance()
             self.formatter = ChatFormat(self.tokenizer)
```
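
The constructor change shows the store is purely optional. The diff only ever calls three methods on it, so the contract `InferenceRouter` depends on is small; a hedged sketch of that surface as a hypothetical Protocol, with `Any` standing in for the concrete llama_stack types:

```python
# The store surface the router relies on in this diff; a sketch only, with
# Any standing in for OpenAIChatCompletion and related llama_stack types.
from typing import Any, Protocol


class InferenceStoreLike(Protocol):  # hypothetical name, for illustration
    async def store_chat_completion(self, response: Any, messages: list[Any]) -> None: ...

    async def list_chat_completions(
        self, after: str | None, limit: int | None, model: str | None, order: Any
    ) -> Any: ...

    async def get_chat_completion(self, completion_id: str) -> Any: ...
```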
```diff
@@ -607,9 +614,31 @@ class InferenceRouter(Inference):
         provider = self.routing_table.get_provider_impl(model_obj.identifier)
         if stream:
-            return await provider.openai_chat_completion(**params)
+            response_stream = await provider.openai_chat_completion(**params)
+            if self.store:
+                return stream_and_store_openai_completion(response_stream, model, self.store, messages)
+            return response_stream
         else:
-            return await self._nonstream_openai_chat_completion(provider, params)
+            response = await self._nonstream_openai_chat_completion(provider, params)
+            if self.store:
+                await self.store.store_chat_completion(response, messages)
+            return response
+
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 20,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIChatCompletionResponse:
+        if self.store:
+            return await self.store.list_chat_completions(after, limit, model, order)
+        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
+
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+        if self.store:
+            return await self.store.get_chat_completion(completion_id)
+        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

     async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
```
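
With a store configured, the new endpoints can be exercised end to end. A hypothetical usage sketch: `router` is assumed to be an `InferenceRouter` wired with a store, and attribute names like `.data` are assumptions based on OpenAI-style list responses, not confirmed by this diff:

```python
# Hypothetical usage of the new endpoints; assumes `router` is an
# InferenceRouter constructed with an inference store.
from llama_stack.apis.inference import Order


async def dump_recent_completions(router) -> None:
    # Without a configured store, both calls raise NotImplementedError.
    page = await router.list_chat_completions(limit=10, order=Order.desc)
    for item in page.data:  # .data is assumed from OpenAI-style list responses
        detail = await router.get_chat_completion(item.id)
        print(detail.id)
```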