fix meta-reference, test vllm

Ashwin Bharambe 2024-12-16 23:45:15 -08:00
parent b75e4eb6b9
commit a30aaaa2e5
3 changed files with 13 additions and 3 deletions

@@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
@@ -315,11 +316,15 @@ class MetaReferenceInferenceImpl(
         if stop_reason is None:
             stop_reason = StopReason.out_of_tokens
-        message = self.generator.formatter.decode_assistant_message(
+        raw_message = self.generator.formatter.decode_assistant_message(
             tokens, stop_reason
         )
         return ChatCompletionResponse(
-            completion_message=message,
+            completion_message=CompletionMessage(
+                content=raw_message.content,
+                stop_reason=raw_message.stop_reason,
+                tool_calls=raw_message.tool_calls,
+            ),
             logprobs=logprobs if request.logprobs else None,
         )
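
The change above wraps the decoded output in a CompletionMessage (hence the new import) instead of returning it directly, copying content, stop_reason, and tool_calls across. A minimal sketch of that conversion pattern, using simplified stand-in dataclasses rather than the actual llama_stack models:

```python
# A minimal sketch; RawMessage and CompletionMessage here are simplified
# stand-ins with the three fields the diff copies, not the real llama_stack types.
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class RawMessage:
    content: str
    stop_reason: str
    tool_calls: List[Any] = field(default_factory=list)


@dataclass
class CompletionMessage:
    content: str
    stop_reason: str
    tool_calls: List[Any] = field(default_factory=list)


def to_completion_message(raw: RawMessage) -> CompletionMessage:
    # Copy each field across explicitly, mirroring the hunk above.
    return CompletionMessage(
        content=raw.content,
        stop_reason=raw.stop_reason,
        tool_calls=raw.tool_calls,
    )


print(to_completion_message(RawMessage(content="hi", stop_reason="end_of_turn")))
```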
@@ -454,7 +459,9 @@ async def convert_request_to_raw(
         messages = []
         for m in request.messages:
             content = await interleaved_content_convert_to_raw(m.content)
-            messages.append(RawMessage(**m.model_dump(), content=content))
+            d = m.model_dump()
+            d["content"] = content
+            messages.append(RawMessage(**d))
         request.messages = messages
     else:
         request.content = await interleaved_content_convert_to_raw(request.content)
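
The earlier call RawMessage(**m.model_dump(), content=content) passes content twice, since model_dump() already contains a content key, which raises a TypeError at runtime; the rewrite overwrites the dict entry before unpacking. A small standalone sketch of that failure mode, with a plain function standing in for the Pydantic model:

```python
# Demonstrates the duplicate-keyword failure the hunk above fixes: unpacking a
# dict that already holds "content" alongside an explicit content= keyword.
def make_message(**kwargs):
    return kwargs


d = {"role": "user", "content": "original interleaved content"}

try:
    make_message(**d, content="converted raw content")  # old pattern
except TypeError as err:
    print(err)  # got multiple values for keyword argument 'content'

d["content"] = "converted raw content"  # fixed pattern: overwrite, then unpack
print(make_message(**d))
```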

@@ -113,6 +113,7 @@ def inference_vllm_remote() -> ProviderFixture:
                 provider_type="remote::vllm",
                 config=VLLMInferenceAdapterConfig(
                     url=get_env_or_fail("VLLM_URL"),
+                    max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
                 ).model_dump(),
             )
         ],
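
The fixture now takes its token cap from the environment, defaulting to 2048. One note on the pattern: os.getenv returns the default unchanged when the variable is unset and a string when it is set, so the int(...) cast normalizes both cases (this assumes os is already imported in the fixture module). A quick sketch:

```python
import os

# Unset: os.getenv returns the default (2048); int() leaves it unchanged.
os.environ.pop("VLLM_MAX_TOKENS", None)
print(int(os.getenv("VLLM_MAX_TOKENS", 2048)))  # 2048

# Set: the value comes back as a string and int() converts it.
os.environ["VLLM_MAX_TOKENS"] = "4096"
print(int(os.getenv("VLLM_MAX_TOKENS", 2048)))  # 4096
```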

@@ -103,6 +103,8 @@ async def interleaved_content_convert_to_raw(
                 data = response.content
             else:
                 raise ValueError("Unsupported URL type")
+        else:
+            data = c.data
         return RawMediaItem(data=data)
     else:
         raise ValueError(f"Unsupported content type: {type(c)}")
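
The added branch covers media items that carry inline bytes in c.data rather than a URL, so data is defined before RawMediaItem(data=data) is built even when no URL is present. A simplified sketch of the branching, with a hypothetical stand-in type rather than the actual llama_stack content item:

```python
# Simplified branching sketch; MediaItem and fetch are hypothetical stand-ins
# for the real content item type and its URL download path.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class MediaItem:
    url: Optional[str] = None
    data: Optional[bytes] = None


def resolve_media_bytes(c: MediaItem, fetch: Callable[[str], bytes]) -> Optional[bytes]:
    if c.url:
        return fetch(c.url)  # URL present: download / read it, as before
    else:
        return c.data  # new branch: inline bytes carried on the item itself


# An item with inline data never touches the fetcher.
print(resolve_media_bytes(MediaItem(data=b"\x89PNG"), fetch=lambda u: b""))
```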