mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
Separate chat_completion stream and non-stream implementations
This is a pretty important requirement. The streaming response type is an AsyncGenerator while the non-stream one is a single object. So far this has worked _sometimes_ due to various pre-existing hacks (and in some cases, just failed.)
This commit is contained in:
parent
f8752ab8dc
commit
0c9eb3341c
5 changed files with 346 additions and 287 deletions
|
@ -42,10 +42,10 @@ class InferenceClient(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -66,8 +66,30 @@ class InferenceClient(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
yield ChatCompletionResponse(**j)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
|
@ -91,23 +113,10 @@ class InferenceClient(Inference):
|
|||
cprint(data, "red")
|
||||
continue
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
**json.loads(data)
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
else:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/inference/chat_completion",
|
||||
json=encodable_dict(request),
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
yield ChatCompletionResponse(**j)
|
||||
|
||||
|
||||
async def run_main(
|
||||
|
|
|
@ -180,8 +180,10 @@ class ModelStore(Protocol):
|
|||
class Inference(Protocol):
|
||||
model_store: ModelStore
|
||||
|
||||
# This method is not `async def` because it can result in either an
|
||||
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
|
||||
@webmethod(route="/inference/completion")
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@ -190,8 +192,10 @@ class Inference(Protocol):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
|
||||
|
||||
# This method is not `async def` because it can result in either an
|
||||
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
|
||||
@webmethod(route="/inference/chat_completion")
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
|
@ -70,7 +70,7 @@ class InferenceRouter(Inference):
|
|||
async def register_model(self, model: ModelDef) -> None:
|
||||
await self.routing_table.register_model(model)
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -91,27 +91,32 @@ class InferenceRouter(Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
**params
|
||||
):
|
||||
yield chunk
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.chat_completion(**params))
|
||||
else:
|
||||
return provider.chat_completion(**params)
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
) -> AsyncGenerator:
|
||||
provider = self.routing_table.get_provider_impl(model)
|
||||
params = dict(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return (chunk async for chunk in provider.completion(**params))
|
||||
else:
|
||||
return provider.completion(**params)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
|
@ -55,7 +55,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
@ -79,7 +79,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -90,24 +90,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
ollama_model = self.map_to_provider_model(request.model)
|
||||
ollama_model = self.map_to_provider_model(model)
|
||||
|
||||
res = await self.client.ps()
|
||||
need_model_pull = True
|
||||
|
@ -123,16 +106,43 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
status["status"] == "success"
|
||||
), f"Failed to pull model {self.model} in ollama"
|
||||
|
||||
common_params = {
|
||||
"model": ollama_model,
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": prompt,
|
||||
"options": options,
|
||||
"raw": True,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
if not request.stream:
|
||||
r = await self.client.generate(**common_params)
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
stop_reason = None
|
||||
if r["done"]:
|
||||
if r["done_reason"] == "stop":
|
||||
|
@ -143,18 +153,24 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
completion_message = self.formatter.decode_assistant_message_from_content(
|
||||
r["response"], stop_reason
|
||||
)
|
||||
yield ChatCompletionResponse(
|
||||
return ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
stream = await self.client.generate(**params)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
stream = await self.client.generate(**common_params)
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
|
|
|
@ -46,9 +46,7 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
|
||||
# hm, when stream=False, we should not be doing SSE :/ which is what the
|
||||
# top-level server is going to do. make the typing more specific here
|
||||
async def chat_completion(
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
@ -59,6 +57,9 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
|
@ -71,7 +72,6 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = augment_messages_for_tools(request)
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
|
@ -87,6 +87,57 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
|
||||
async with SEMAPHORE:
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
messages=messages,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
max_gen_len=request.sampling_params.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={token_result.text: token_result.logprobs[0]}
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
|
@ -96,10 +147,7 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
|
||||
tokens = []
|
||||
logprobs = []
|
||||
|
||||
stop_reason = None
|
||||
|
||||
buffer = ""
|
||||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
|
@ -110,12 +158,10 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
buffer += token_result.text
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and buffer.startswith("<|python_tag|>"):
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
if request.stream:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
|
@ -125,27 +171,6 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
),
|
||||
)
|
||||
)
|
||||
|
||||
buffer = buffer[len("<|python_tag|>") :]
|
||||
continue
|
||||
|
||||
if not request.stream:
|
||||
if request.logprobs:
|
||||
assert (
|
||||
len(token_result.logprobs) == 1
|
||||
), "Expected logprob to contain 1 result for the current token"
|
||||
assert (
|
||||
request.logprobs.top_k == 1
|
||||
), "Only top_k=1 is supported for LogProbConfig"
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
|
@ -166,23 +191,30 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
delta = text
|
||||
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
# TODO(ashwin): parse tool calls separately here and report errors?
|
||||
# if someone breaks the iteration before coming here we are toast
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
if request.stream:
|
||||
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
|
@ -215,10 +247,3 @@ class MetaReferenceInferenceImpl(Inference):
|
|||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
||||
else:
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue