Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)
fix meta-reference, test vllm
commit a30aaaa2e5 (parent b75e4eb6b9)
3 changed files with 13 additions and 3 deletions
@@ -26,6 +26,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
@@ -315,11 +316,15 @@ class MetaReferenceInferenceImpl(
         if stop_reason is None:
             stop_reason = StopReason.out_of_tokens

-        message = self.generator.formatter.decode_assistant_message(
+        raw_message = self.generator.formatter.decode_assistant_message(
             tokens, stop_reason
         )
         return ChatCompletionResponse(
-            completion_message=message,
+            completion_message=CompletionMessage(
+                content=raw_message.content,
+                stop_reason=raw_message.stop_reason,
+                tool_calls=raw_message.tool_calls,
+            ),
             logprobs=logprobs if request.logprobs else None,
         )

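The hunk above stops handing the decoder's raw message straight to ChatCompletionResponse and instead copies its fields into a CompletionMessage (hence the new import in the first hunk). A minimal sketch of that conversion, using simplified stand-in dataclasses rather than the real llama_stack models:

# Sketch of the raw-to-completion message conversion shown above.
# RawAssistantMessage and CompletionMessage here are simplified stand-ins,
# not the actual llama_stack types.
from dataclasses import dataclass, field
from typing import List


@dataclass
class RawAssistantMessage:
    content: str
    stop_reason: str
    tool_calls: List[dict] = field(default_factory=list)


@dataclass
class CompletionMessage:
    content: str
    stop_reason: str
    tool_calls: List[dict] = field(default_factory=list)


def to_completion_message(raw: RawAssistantMessage) -> CompletionMessage:
    # Copy only the fields the response schema expects, instead of
    # returning the decoder's raw message object directly.
    return CompletionMessage(
        content=raw.content,
        stop_reason=raw.stop_reason,
        tool_calls=raw.tool_calls,
    )


raw = RawAssistantMessage(content="Hello!", stop_reason="end_of_turn")
print(to_completion_message(raw))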
@@ -454,7 +459,9 @@ async def convert_request_to_raw(
         messages = []
         for m in request.messages:
             content = await interleaved_content_convert_to_raw(m.content)
-            messages.append(RawMessage(**m.model_dump(), content=content))
+            d = m.model_dump()
+            d["content"] = content
+            messages.append(RawMessage(**d))
         request.messages = messages
     else:
         request.content = await interleaved_content_convert_to_raw(request.content)
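The removed line passed content=content alongside **m.model_dump(); since the dumped dict already contains a "content" key, that call raises TypeError for a duplicate keyword argument. Overwriting the key in the dict first avoids the collision. A small self-contained reproduction using stand-in Pydantic models (not the real RawMessage):

# Sketch of the duplicate-keyword issue fixed above, using stand-in models.
from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


class RawMessage(BaseModel):
    role: str
    content: bytes


m = Message(role="user", content="hi")
raw_content = m.content.encode()  # stand-in for interleaved_content_convert_to_raw

# Old approach: model_dump() already contains "content", so passing it again
# raises TypeError: got multiple values for keyword argument 'content'.
try:
    RawMessage(**m.model_dump(), content=raw_content)
except TypeError as e:
    print("old approach fails:", e)

# Fixed approach: overwrite the key in the dict before unpacking.
d = m.model_dump()
d["content"] = raw_content
print(RawMessage(**d))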
@@ -113,6 +113,7 @@ def inference_vllm_remote() -> ProviderFixture:
             provider_type="remote::vllm",
             config=VLLMInferenceAdapterConfig(
                 url=get_env_or_fail("VLLM_URL"),
+                max_tokens=int(os.getenv("VLLM_MAX_TOKENS", 2048)),
             ).model_dump(),
         )
     ],
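The vLLM test fixture now lets VLLM_MAX_TOKENS override the default of 2048. os.getenv returns the (integer) default when the variable is unset but a string when it is set, so the surrounding int() normalizes both cases. A short sketch of the same pattern in isolation:

# Sketch of the env-var override pattern used in the fixture above.
import os


def vllm_max_tokens(default: int = 2048) -> int:
    # os.getenv yields a string when VLLM_MAX_TOKENS is set and the
    # non-string default otherwise; int() normalizes both cases.
    return int(os.getenv("VLLM_MAX_TOKENS", default))


print(vllm_max_tokens())  # 2048 unless VLLM_MAX_TOKENS is set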
@@ -103,6 +103,8 @@ async def interleaved_content_convert_to_raw(
                data = response.content
            else:
                raise ValueError("Unsupported URL type")
+        else:
+            data = c.data
         return RawMediaItem(data=data)
     else:
         raise ValueError(f"Unsupported content type: {type(c)}")
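The new else branch covers media items that carry inline data rather than a URL; previously only the URL path assigned data, so a data-only item had no value to wrap into RawMediaItem. A simplified sketch of the resulting control flow, with a stand-in MediaItem type and placeholder URL handling rather than the real llama_stack content types:

# Simplified sketch of the URL-or-inline-data handling shown above.
# MediaItem is a stand-in for the actual content item type.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MediaItem:
    url: Optional[str] = None
    data: Optional[bytes] = None


def to_raw_bytes(c: MediaItem) -> bytes:
    if c.url is not None:
        if c.url.startswith("data:"):
            data = c.url.split(",", 1)[1].encode()  # stand-in for base64 decoding
        elif c.url.startswith(("http://", "https://")):
            data = b"<downloaded bytes>"  # stand-in for an HTTP fetch
        else:
            raise ValueError("Unsupported URL type")
    else:
        # New branch: the item carries inline bytes instead of a URL.
        data = c.data
    return data


print(to_raw_bytes(MediaItem(data=b"raw image bytes")))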