update vllm; not quite tested yet

Authored by Ashwin Bharambe on 2024-10-08 13:38:32 -07:00; committed by Ashwin Bharambe
parent ed899a5dec
commit 336cf7a674


@@ -10,39 +10,26 @@ import uuid
 from typing import Any
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import (
-    CompletionMessage,
-    InterleavedTextMedia,
-    Message,
-    StopReason,
-    ToolChoice,
-    ToolDefinition,
-    ToolPromptFormat,
-)
+from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 
-from llama_stack.apis.inference import ChatCompletionRequest, Inference
+from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.apis.inference.inference import (
-    ChatCompletionResponse,
-    ChatCompletionResponseEvent,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
-    EmbeddingsResponse,
-    LogProbConfig,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_prompt,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
 
 from .config import VLLMConfig
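
The import changes track the body changes below: the hand-rolled prompt assembly is replaced by the shared chat_completion_request_to_prompt helper, and the OpenAI-compat utilities take over response decoding. A minimal sketch of what the prompt helper plausibly does, assuming it reuses augment_messages_for_tools and the ChatFormat/Tokenizer pair imported above (the real implementation lives in llama_stack.providers.utils.inference):

    # Hypothetical sketch, not the shipped helper.
    from llama_stack.providers.utils.inference.augment_messages import (
        augment_messages_for_tools,
    )

    def chat_completion_request_to_prompt(request, formatter) -> str:
        # Fold tool definitions and tool-prompt-format instructions
        # into the message list before rendering.
        messages = augment_messages_for_tools(request)
        # Render with the Llama 3 chat template rather than naively
        # concatenating message contents as the removed code did.
        model_input = formatter.encode_dialog_prompt(messages)
        return formatter.tokenizer.decode(model_input.tokens)
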
@@ -72,10 +59,10 @@ def _vllm_sampling_params(sampling_params: Any) -> SamplingParams:
     if sampling_params.repetition_penalty > 0:
         kwargs["repetition_penalty"] = sampling_params.repetition_penalty
 
-    return SamplingParams().from_optional(**kwargs)
+    return SamplingParams(**kwargs)
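
On the constructor change: SamplingParams.from_optional appears to exist for callers that may pass None values and want vLLM to substitute defaults; since kwargs here is only populated with explicitly-set fields, the plain constructor suffices. For example (parameter values are illustrative):

    from vllm.sampling_params import SamplingParams

    # kwargs only ever contains keys that were explicitly set, so the
    # plain constructor is safe; unset fields keep vLLM's defaults.
    params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)
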
 
-class VLLMInferenceImpl(Inference, ModelRegistryHelper):
+class VLLMInferenceImpl(ModelRegistryHelper, Inference):
     """Inference implementation for vLLM."""
 
     HF_MODEL_MAPPINGS = {
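
Swapping the base-class order also swaps Python's method resolution order, so attributes defined on ModelRegistryHelper now shadow any same-named attributes on Inference rather than the other way around. A self-contained illustration of the rule:

    class A:
        def who(self) -> str:
            return "A"

    class B:
        def who(self) -> str:
            return "B"

    class First(A, B):
        pass

    class Second(B, A):
        pass

    print(First().who(), Second().who())  # prints: A B  (leftmost base wins)
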
@@ -148,7 +135,7 @@ class VLLMInferenceImpl(Inference, ModelRegistryHelper):
         if self.engine:
             self.engine.shutdown_background_loop()
 
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -157,17 +144,16 @@ class VLLMInferenceImpl(Inference, ModelRegistryHelper):
         logprobs: LogProbConfig | None = None,
     ) -> CompletionResponse | CompletionResponseStreamChunk:
         log.info("vLLM completion")
-        messages = [Message(role="user", content=content)]
-        async for result in self.chat_completion(
+        messages = [UserMessage(content=content)]
+        return self.chat_completion(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
             stream=stream,
             logprobs=logprobs,
-        ):
-            yield result
+        )
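
Because chat_completion (changed just below) becomes a plain def that hands back either a coroutine or an async generator, completion no longer needs to be an async generator itself; it can return whatever chat_completion produced. A sketch of that dispatch pattern with simplified names:

    from typing import AsyncGenerator

    class Sketch:
        # A plain `def` can return either an awaitable or an async
        # generator; the caller awaits or iterates accordingly.
        def chat(self, stream: bool):
            return self._stream() if stream else self._once()

        async def _once(self) -> str:
            return "full response"

        async def _stream(self) -> AsyncGenerator[str, None]:
            for token in ("full", " response"):
                yield token

    # usage: `await sketch.chat(stream=False)` or
    #        `async for tok in sketch.chat(stream=True): ...`
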
 
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: list[Message],
@@ -194,159 +180,59 @@ class VLLMInferenceImpl(Inference, ModelRegistryHelper):
         )
 
         log.info("Sampling params: %s", sampling_params)
-        vllm_sampling_params = _vllm_sampling_params(sampling_params)
-
-        messages = augment_messages_for_tools(request)
-        log.info("Augmented messages: %s", messages)
-        prompt = "".join([str(message.content) for message in messages])
 
         request_id = _random_uuid()
+
+        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
         )
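
For reference, AsyncLLMEngine.generate yields RequestOutput objects whose outputs carry the text generated so far, which is why both paths below can simply iterate the generator. A rough standalone consumption pattern (the model id is a placeholder):

    import uuid

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.sampling_params import SamplingParams

    async def generate_once(prompt: str) -> str:
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="meta-llama/Llama-3.1-8B-Instruct")
        )
        final = None
        async for request_output in engine.generate(
            prompt, SamplingParams(max_tokens=64), str(uuid.uuid4())
        ):
            final = request_output  # the last item holds the full text
        return "".join(o.text for o in final.outputs)
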
 
-        if not stream:
-            # Non-streaming case
-            final_output = None
-            stop_reason = None
-            async for request_output in results_generator:
-                final_output = request_output
-                if stop_reason is None and request_output.outputs:
-                    reason = request_output.outputs[-1].stop_reason
-                    if reason == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif reason == "length":
-                        stop_reason = StopReason.out_of_tokens
-
-            if not stop_reason:
-                stop_reason = StopReason.end_of_message
-
-            if final_output:
-                response = "".join([output.text for output in final_output.outputs])
-                yield ChatCompletionResponse(
-                    completion_message=CompletionMessage(
-                        content=response,
-                        stop_reason=stop_reason,
-                    ),
-                    logprobs=None,
-                )
+        if stream:
+            return self._stream_chat_completion(request, results_generator)
         else:
-            # Streaming case
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
+            return self._nonstream_chat_completion(request, results_generator)
 
-            buffer = ""
-            last_chunk = ""
-            ipython = False
-            stop_reason = None
 
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
+    ) -> ChatCompletionResponse:
+        outputs = [o async for o in results_generator]
+        final_output = outputs[-1]
+
+        assert final_output is not None
+        outputs = final_output.outputs
+        finish_reason = outputs[-1].stop_reason
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=finish_reason,
+            text="".join([output.text for output in outputs]),
+        )
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+        return process_chat_completion_response(request, response, self.formatter)
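
The non-streaming path now just normalizes vLLM's final RequestOutput into an OpenAI-compatible shape and lets the shared helper do the decoding. Presumably process_chat_completion_response maps the finish reason onto a StopReason and runs the formatter's tool-call-aware decoder, roughly like this (hypothetical sketch; the real code lives in llama_stack.providers.utils.inference.openai_compat):

    from llama_models.llama3.api.datatypes import StopReason
    from llama_stack.apis.inference import ChatCompletionResponse

    def process_chat_completion_response(request, response, formatter):
        choice = response.choices[0]
        stop_reason = {
            "stop": StopReason.end_of_turn,
            "length": StopReason.out_of_tokens,
        }.get(choice.finish_reason, StopReason.end_of_message)
        # decode_assistant_message_from_content also recovers tool calls
        # embedded in the generated text (e.g. after <|python_tag|>).
        completion_message = formatter.decode_assistant_message_from_content(
            choice.text, stop_reason
        )
        return ChatCompletionResponse(
            completion_message=completion_message,
            logprobs=None,
        )
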
 
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, results_generator: AsyncGenerator
+    ) -> AsyncGenerator:
+        async def _generate_and_convert_to_openai_compat():
             async for chunk in results_generator:
                 if not chunk.outputs:
                     log.warning("Empty chunk received")
                     continue
 
-                if chunk.outputs[-1].stop_reason:
-                    reason = chunk.outputs[-1].stop_reason
-                    if stop_reason is None and reason == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif stop_reason is None and reason == "length":
-                        stop_reason = StopReason.out_of_tokens
-                    break
 
                 text = "".join([output.text for output in chunk.outputs])
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    last_chunk_len = len(last_chunk)
-                    last_chunk = text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text[last_chunk_len:],
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            if not stop_reason:
-                stop_reason = StopReason.end_of_message
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=chunk.outputs[-1].stop_reason,
+                    text=text,
+                )
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
+                )
 
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
 
     async def embeddings(
         self, model: str, contents: list[InterleavedTextMedia]
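
Net effect of the hunk above: everything the removed inline code did per chunk (the start event, <|python_tag|> buffering, <|eot_id|>/<|eom_id|> handling, and tool-call parsing) is now centralized in process_chat_completion_stream_response, and this provider only adapts vLLM chunks into the OpenAI-compat shape. Driving the new streaming surface would look roughly like this (construction and initialization details are assumptions based on the class shown above):

    import asyncio

    async def main() -> None:
        impl = VLLMInferenceImpl(VLLMConfig())  # assumed constructor
        await impl.initialize()                 # assumed setup hook
        # stream=True makes chat_completion return an async generator,
        # so it is iterated directly rather than awaited.
        async for chunk in impl.chat_completion(
            model="Llama3.1-8B-Instruct",
            messages=[UserMessage(content="Say hi")],
            stream=True,
        ):
            event = chunk.event
            if event.event_type == ChatCompletionResponseEventType.progress:
                print(event.delta, end="", flush=True)

    asyncio.run(main())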