Fix conversion to RawMessage everywhere

This commit is contained in:
Ashwin Bharambe 2024-12-17 13:38:01 -08:00
parent fbca51d6da
commit b7a7caa9a8
11 changed files with 87 additions and 78 deletions

View file

@ -25,7 +25,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import RawContent, RawMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -39,6 +38,10 @@ from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .config import (
Fp8QuantizationConfig,
@ -50,14 +53,6 @@ from .config import (
log = logging.getLogger(__name__)
# Variant of ChatCompletionRequest whose `messages` field is re-declared as a
# list of RawMessage objects; nothing else is overridden here.
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
    # Overrides the parent's annotation for `messages`.
    messages: List[RawMessage]
# Variant of CompletionRequest whose `content` field is re-declared as
# RawContent; nothing else is overridden here.
class CompletionRequestWithRawContent(CompletionRequest):
    # Overrides the parent's annotation for `content`.
    content: RawContent
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))

View file

@ -12,7 +12,6 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import Model
from llama_models.llama3.api.datatypes import (
RawMessage,
SamplingParams,
StopReason,
ToolDefinition,
@ -53,14 +52,10 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
interleaved_content_convert_to_raw,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generation import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
Llama,
)
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
@ -450,20 +445,3 @@ class MetaReferenceInferenceImpl(
else:
for x in impl():
yield x
async def convert_request_to_raw(
    request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
    """Replace the request's content with its raw form and return the request.

    For a ``ChatCompletionRequest`` every message is rebuilt as a
    ``RawMessage`` that carries the converted content plus all other dumped
    fields of the original message. For any other request only
    ``request.content`` is converted.

    NOTE(review): the request object passed in is mutated and that same
    object is returned; no new ``*WithRawContent`` instance is constructed
    here, despite the return annotation -- confirm callers rely only on the
    field contents.

    Args:
        request: The chat-completion or completion request to convert.

    Returns:
        The same ``request`` object with its content field(s) replaced by the
        result of ``interleaved_content_convert_to_raw``.
    """
    # Plain completion: a single content field to convert.
    if not isinstance(request, ChatCompletionRequest):
        request.content = await interleaved_content_convert_to_raw(request.content)
        return request

    # Chat completion: convert each message's content one at a time, in order.
    raw_messages = []
    for message in request.messages:
        raw_content = await interleaved_content_convert_to_raw(message.content)
        # Keep every dumped field of the message, overriding only `content`.
        fields = {**message.model_dump(), "content": raw_content}
        raw_messages.append(RawMessage(**fields))

    request.messages = raw_messages
    return request

View file

@ -120,15 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
log.info("vLLM completion")
messages = [UserMessage(content=content)]
return self.chat_completion(
model=model_id,
messages=messages,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
raise NotImplementedError("Completion not implemented for vLLM")
async def chat_completion(
self,
@ -142,8 +134,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
log.info("vLLM chat completion")
assert self.engine is not None
request = ChatCompletionRequest(
@ -160,7 +150,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = chat_completion_request_to_prompt(request, self.formatter)
prompt = await chat_completion_request_to_prompt(request, self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id