mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
fix non-streaming api in inference server
This commit is contained in:
parent
404af06e02
commit
cc98fbb058
3 changed files with 32 additions and 26 deletions
|
@ -14,6 +14,7 @@ from termcolor import cprint
|
||||||
|
|
||||||
from .api import (
|
from .api import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
Inference,
|
Inference,
|
||||||
|
@ -50,35 +51,33 @@ class InferenceClient(Inference):
|
||||||
if line.startswith("data:"):
|
if line.startswith("data:"):
|
||||||
data = line[len("data: ") :]
|
data = line[len("data: ") :]
|
||||||
try:
|
try:
|
||||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
if request.stream:
|
||||||
|
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||||
|
else:
|
||||||
|
yield ChatCompletionResponse(**json.loads(data))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(data)
|
print(data)
|
||||||
print(f"Error with parsing or validation: {e}")
|
print(f"Error with parsing or validation: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def run_main(host: str, port: int):
|
async def run_main(host: str, port: int, stream: bool):
|
||||||
client = InferenceClient(f"http://{host}:{port}")
|
client = InferenceClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
message = UserMessage(content="hello world, help me out here")
|
message = UserMessage(content="hello world, help me out here")
|
||||||
cprint(f"User>{message.content}", "green")
|
cprint(f"User>{message.content}", "green")
|
||||||
req = ChatCompletionRequest(
|
|
||||||
model=InstructModel.llama3_70b_chat,
|
|
||||||
messages=[message],
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
iterator = client.chat_completion(
|
iterator = client.chat_completion(
|
||||||
ChatCompletionRequest(
|
ChatCompletionRequest(
|
||||||
model=InstructModel.llama3_8b_chat,
|
model=InstructModel.llama3_8b_chat,
|
||||||
messages=[message],
|
messages=[message],
|
||||||
stream=True,
|
stream=stream,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
async for log in EventLogger().log(iterator):
|
async for log in EventLogger().log(iterator):
|
||||||
log.print()
|
log.print()
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int):
|
def main(host: str, port: int, stream: bool = True):
|
||||||
asyncio.run(run_main(host, port))
|
asyncio.run(run_main(host, port, stream))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -6,7 +6,10 @@
|
||||||
|
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_toolchain.inference.api import ChatCompletionResponseEventType
|
from llama_toolchain.inference.api import (
|
||||||
|
ChatCompletionResponseEventType,
|
||||||
|
ChatCompletionResponseStreamChunk
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LogEvent:
|
class LogEvent:
|
||||||
|
@ -25,12 +28,16 @@ class LogEvent:
|
||||||
|
|
||||||
|
|
||||||
class EventLogger:
|
class EventLogger:
|
||||||
async def log(self, event_generator, stream=True):
|
async def log(self, event_generator):
|
||||||
async for chunk in event_generator:
|
async for chunk in event_generator:
|
||||||
event = chunk.event
|
if isinstance(chunk, ChatCompletionResponseStreamChunk):
|
||||||
if event.event_type == ChatCompletionResponseEventType.start:
|
event = chunk.event
|
||||||
|
if event.event_type == ChatCompletionResponseEventType.start:
|
||||||
|
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||||
|
elif event.event_type == ChatCompletionResponseEventType.progress:
|
||||||
|
yield LogEvent(event.delta, color="yellow", end="")
|
||||||
|
elif event.event_type == ChatCompletionResponseEventType.complete:
|
||||||
|
yield LogEvent("")
|
||||||
|
else:
|
||||||
yield LogEvent("Assistant> ", color="cyan", end="")
|
yield LogEvent("Assistant> ", color="cyan", end="")
|
||||||
elif event.event_type == ChatCompletionResponseEventType.progress:
|
yield LogEvent(chunk.completion_message.content, color="yellow")
|
||||||
yield LogEvent(event.delta, color="yellow", end="")
|
|
||||||
elif event.event_type == ChatCompletionResponseEventType.complete:
|
|
||||||
yield LogEvent("")
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ from .api.datatypes import (
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
)
|
)
|
||||||
from .api.endpoints import (
|
from .api.endpoints import (
|
||||||
|
ChatCompletionResponse,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
|
@ -40,12 +41,13 @@ class InferenceImpl(Inference):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||||
yield ChatCompletionResponseStreamChunk(
|
if request.stream:
|
||||||
event=ChatCompletionResponseEvent(
|
yield ChatCompletionResponseStreamChunk(
|
||||||
event_type=ChatCompletionResponseEventType.start,
|
event=ChatCompletionResponseEvent(
|
||||||
delta="",
|
event_type=ChatCompletionResponseEventType.start,
|
||||||
|
delta="",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
logprobs = []
|
logprobs = []
|
||||||
|
@ -152,8 +154,6 @@ class InferenceImpl(Inference):
|
||||||
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
# TODO(ashwin): what else do we need to send out here when everything finishes?
|
||||||
else:
|
else:
|
||||||
yield ChatCompletionResponse(
|
yield ChatCompletionResponse(
|
||||||
content=message.content,
|
completion_message=message,
|
||||||
tool_calls=message.tool_calls,
|
|
||||||
stop_reason=stop_reason,
|
|
||||||
logprobs=logprobs if request.logprobs else None,
|
logprobs=logprobs if request.logprobs else None,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue