Mirror of https://github.com/meta-llama/llama-stack.git
Separate chat_completion stream and non-stream implementations
This is a pretty important requirement. The streaming response type is an AsyncGenerator, while the non-stream one is a single object. So far this has worked _sometimes_ due to various pre-existing hacks (and in some cases, it simply failed).
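To make the conflict concrete, here is a minimal runnable sketch (all names are illustrative, not from the llama-stack codebase) of why a single `async def` cannot cover both modes: once its body contains a `yield`, Python compiles it as an async generator function, so even the non-stream code path hands back an async generator instead of a response object.

import asyncio
import inspect

# Illustrative only: an `async def` whose body contains `yield` is
# compiled as an async generator function, so *every* call returns an
# async generator -- even on a path that never reaches the `yield`.
async def chat_completion(stream: bool = False):
    if stream:
        yield "chunk"
    # A `return <value>` for the non-stream case would be a SyntaxError
    # here: async generators cannot return values.

async def main() -> None:
    print(inspect.isasyncgenfunction(chat_completion))  # True
    result = chat_completion(stream=False)
    print(type(result).__name__)  # async_generator, not a single response
    await result.aclose()  # clean up the never-iterated generator

asyncio.run(main())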
parent f8752ab8dc
commit 0c9eb3341c

5 changed files with 346 additions and 287 deletions
@@ -70,7 +70,7 @@ class InferenceRouter(Inference):
     async def register_model(self, model: ModelDef) -> None:
         await self.routing_table.register_model(model)
 
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -91,27 +91,32 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
-            **params
-        ):
-            yield chunk
+        provider = self.routing_table.get_provider_impl(model)
+        if stream:
+            return (chunk async for chunk in provider.chat_completion(**params))
+        else:
+            return provider.chat_completion(**params)
 
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        return await self.routing_table.get_provider_impl(model).completion(
-            model=model,
-            content=content,
-            sampling_params=sampling_params,
-            stream=stream,
-            logprobs=logprobs,
-        )
+    ) -> AsyncGenerator:
+        provider = self.routing_table.get_provider_impl(model)
+        params = dict(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return (chunk async for chunk in provider.completion(**params))
+        else:
+            return provider.completion(**params)
 
     async def embeddings(
         self,
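And a hedged sketch of the dispatch pattern the router now follows, with a stub provider standing in for real llama-stack types (`FakeProvider` and the string chunks are hypothetical): a plain `def` inspects `stream` and returns either an awaitable or an async generator, and the caller picks `await` or `async for` accordingly.

import asyncio
from typing import AsyncGenerator

# Hypothetical stand-in for a provider implementation; not real
# llama-stack code.
class FakeProvider:
    async def _stream_chat(self) -> AsyncGenerator[str, None]:
        for tok in ("hello", " ", "world"):
            yield tok

    async def _chat(self) -> str:
        return "hello world"

    # A plain `def`, mirroring the router change above: it returns an
    # async generator when streaming and a coroutine otherwise.
    def chat_completion(self, stream: bool = False):
        if stream:
            return self._stream_chat()
        return self._chat()

async def main() -> None:
    provider = FakeProvider()
    # Non-stream: await a single object.
    print(await provider.chat_completion(stream=False))
    # Stream: iterate chunks without awaiting the call itself.
    async for chunk in provider.chat_completion(stream=True):
        print(chunk, end="")
    print()

asyncio.run(main())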