Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00
Fix conversion to RawMessage everywhere

parent fbca51d6da
commit b7a7caa9a8

11 changed files with 87 additions and 78 deletions
@@ -25,7 +25,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import RawContent, RawMessage
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@@ -39,6 +38,10 @@ from llama_stack.apis.inference import * # noqa: F403
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)

 from .config import (
     Fp8QuantizationConfig,
@@ -50,14 +53,6 @@ from .config import (
 log = logging.getLogger(__name__)


-class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
-    messages: List[RawMessage]
-
-
-class CompletionRequestWithRawContent(CompletionRequest):
-    content: RawContent
-
-
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
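The two request classes deleted above carry the same names the earlier hunk starts importing from llama_stack.providers.utils.inference.prompt_adapter, so they now live in that shared module rather than in this provider. A minimal sketch of what the shared definitions presumably look like, mirroring the removed lines; the surrounding imports, in particular pulling ChatCompletionRequest and CompletionRequest from llama_stack.apis.inference, are assumptions:

from typing import List

from llama_models.llama3.api.datatypes import RawContent, RawMessage

# Assumption: the base request types come from the inference API package the
# provider star-imports; only the two subclasses below appear in the diff.
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest


class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
    # Messages whose interleaved content has already been converted to the
    # tokenizer-level RawMessage form.
    messages: List[RawMessage]


class CompletionRequestWithRawContent(CompletionRequest):
    content: RawContent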
@@ -12,7 +12,6 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import Model

from llama_models.llama3.api.datatypes import (
    RawMessage,
    SamplingParams,
    StopReason,
    ToolDefinition,
@@ -53,14 +52,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,
     chat_completion_request_to_messages,
-    interleaved_content_convert_to_raw,
+    convert_request_to_raw,
 )
 from .config import MetaReferenceInferenceConfig
-from .generation import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-    Llama,
-)
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

 log = logging.getLogger(__name__)
@@ -450,20 +445,3 @@ class MetaReferenceInferenceImpl(
         else:
             for x in impl():
                 yield x
-
-
-async def convert_request_to_raw(
-    request: Union[ChatCompletionRequest, CompletionRequest],
-) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
-    if isinstance(request, ChatCompletionRequest):
-        messages = []
-        for m in request.messages:
-            content = await interleaved_content_convert_to_raw(m.content)
-            d = m.model_dump()
-            d["content"] = content
-            messages.append(RawMessage(**d))
-        request.messages = messages
-    else:
-        request.content = await interleaved_content_convert_to_raw(request.content)
-
-    return request
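The convert_request_to_raw coroutine removed here matches the name the import hunk above now takes from llama_stack.providers.utils.inference.prompt_adapter, i.e. the helper was moved into the shared module rather than dropped. A hedged sketch of how a provider is now expected to call it before handing the request to the generator; only the helper and the *WithRawContent types come from the diff, while the wrapper function and the llama_stack.apis.inference import path are illustrative assumptions:

from typing import Union

from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
    convert_request_to_raw,
)


async def to_raw_request(
    request: Union[ChatCompletionRequest, CompletionRequest],
) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
    # Await is required: the helper converts each message's content with
    # interleaved_content_convert_to_raw, as in the body removed above.
    return await convert_request_to_raw(request)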
@@ -120,15 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | CompletionResponseStreamChunk:
-        log.info("vLLM completion")
-        messages = [UserMessage(content=content)]
-        return self.chat_completion(
-            model=model_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            stream=stream,
-            logprobs=logprobs,
-        )
+        raise NotImplementedError("Completion not implemented for vLLM")

     async def chat_completion(
         self,
@@ -142,8 +134,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
        log.info("vLLM chat completion")

        assert self.engine is not None

        request = ChatCompletionRequest(
@@ -160,7 +150,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()

-        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.formatter)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
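The only change in this final hunk is that chat_completion_request_to_prompt is now awaited, which implies the prompt-adapter helper became a coroutine as part of the same raw-content conversion work. A minimal call-site sketch, assuming the helper takes a request and a formatter and returns a prompt string as in the line above; the surrounding class is hypothetical:

from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)


class ExamplePromptBuilder:
    def __init__(self, formatter):
        self.formatter = formatter

    async def build(self, request) -> str:
        # Forgetting the await would pass a coroutine object, not a prompt
        # string, to the downstream engine.generate() call.
        return await chat_completion_request_to_prompt(request, self.formatter)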