From a30aaaa2e5a2cd39f57f3b001b377e7196dcdc5e Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 16 Dec 2024 23:45:15 -0800
Subject: [PATCH] fix meta-reference, test vllm

---
 .../inline/inference/meta_reference/inference.py  | 13 ++++++++++---
 llama_stack/providers/tests/inference/fixtures.py |  1 +
 .../providers/utils/inference/prompt_adapter.py   |  2 ++
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index e1f56af72..4c4e7cb82 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
@@ -315,11 +316,15 @@ class MetaReferenceInferenceImpl(
         if stop_reason is None:
             stop_reason = StopReason.out_of_tokens

-        message = self.generator.formatter.decode_assistant_message(
+        raw_message = self.generator.formatter.decode_assistant_message(
             tokens, stop_reason
         )
         return ChatCompletionResponse(
-            completion_message=message,
+            completion_message=CompletionMessage(
+                content=raw_message.content,
+                stop_reason=raw_message.stop_reason,
+                tool_calls=raw_message.tool_calls,
+            ),
             logprobs=logprobs if request.logprobs else None,
         )

@@ -454,7 +459,9 @@ async def convert_request_to_raw(
         messages = []
         for m in request.messages:
             content = await interleaved_content_convert_to_raw(m.content)
-            messages.append(RawMessage(**m.model_dump(), content=content))
+            d = m.model_dump()
+            d["content"] = content
+            messages.append(RawMessage(**d))
         request.messages = messages
     else:
         request.content = await interleaved_content_convert_to_raw(request.content)
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index fcad03c49..7cc15bd9d 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -113,6 +113,7 @@ def inference_vllm_remote() -> ProviderFixture:
             provider_type="remote::vllm",
             config=VLLMInferenceAdapterConfig(
                 url=get_env_or_fail("VLLM_URL"),
+                max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
             ).model_dump(),
         )
     ],
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index fb6a6dcfc..928b089e0 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -103,6 +103,8 @@ async def interleaved_content_convert_to_raw(
                     data = response.content
                 else:
                     raise ValueError("Unsupported URL type")
+            else:
+                data = c.data
             return RawMediaItem(data=data)
         else:
             raise ValueError(f"Unsupported content type: {type(c)}")
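
Note on the inference.py hunks (kept outside the patch itself): the rename to raw_message and the explicit wrap suggest decode_assistant_message yields a raw message object that must be converted into the CompletionMessage that ChatCompletionResponse expects. The convert_request_to_raw change avoids a duplicate keyword argument: m.model_dump() already contains a "content" key, so also passing content=content raises a TypeError. Below is a minimal, self-contained sketch of that failure mode and the fix, using a hypothetical Message pydantic model rather than the actual llama_stack types:

    from pydantic import BaseModel


    class Message(BaseModel):
        # Illustrative stand-in for the llama_stack message types, not the real schema.
        role: str
        content: str


    m = Message(role="user", content="hello")

    # Old pattern: model_dump() already carries "content", so the extra keyword
    # collides and Python raises "got multiple values for keyword argument 'content'".
    try:
        Message(**m.model_dump(), content="converted raw content")
    except TypeError as err:
        print(f"old pattern fails: {err}")

    # New pattern from the hunk above: overwrite the key in the dict, then unpack once.
    d = m.model_dump()
    d["content"] = "converted raw content"
    print(Message(**d))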