Separate chat_completion stream and non-stream implementations

This is a pretty important requirement. The streaming response type is
an AsyncGenerator while the non-stream one is a single object. So far
this has worked _sometimes_ due to various pre-existing hacks (and in
some cases, just failed.)
This commit is contained in:
Ashwin Bharambe 2024-10-08 10:52:16 -07:00 committed by Ashwin Bharambe
parent f8752ab8dc
commit 0c9eb3341c
5 changed files with 346 additions and 287 deletions

View file

@ -70,7 +70,7 @@ class InferenceRouter(Inference):
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@ -91,27 +91,32 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
**params
):
yield chunk
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(model).completion(
) -> AsyncGenerator:
provider = self.routing_table.get_provider_impl(model)
params = dict(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
else:
return provider.completion(**params)
async def embeddings(
self,