Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
Fix conversion to RawMessage everywhere
parent fbca51d6da
commit b7a7caa9a8
11 changed files with 87 additions and 78 deletions
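
The pattern running through the hunks below is that the prompt_adapter helpers become coroutines (they now route through the async convert_request_to_raw / interleaved_content_convert_to_raw conversion), so every provider awaits them. A minimal sketch of the resulting call pattern, illustrative only and not part of the commit: build_prompt is a hypothetical wrapper, and request, llama_model, and formatter are assumed to be set up as in the providers changed below.

# Sketch only: after this commit the prompt helpers are async, so callers await them.
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
)


async def build_prompt(request, llama_model, formatter) -> str:
    # Hypothetical wrapper: chat requests carry messages, plain completions carry content.
    if hasattr(request, "messages"):
        return await chat_completion_request_to_prompt(request, llama_model, formatter)
    return await completion_request_to_prompt(request, formatter)
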
@@ -25,6 +25,8 @@ from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.apis.safety import * # noqa: F403

+from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
+
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

@@ -778,7 +780,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
         else:
             raise ValueError(f"Unsupported URL {url}")

-        content.append(f'# There is a file accessible to you at "{filepath}"\n')
+        content.append(
+            TextContentItem(
+                text=f'# There is a file accessible to you at "{filepath}"\n'
+            )
+        )

     return ToolResponseMessage(
         call_id="",

@@ -25,7 +25,6 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
-from llama_models.llama3.api.datatypes import RawContent, RawMessage
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (

@@ -39,6 +38,10 @@ from llama_stack.apis.inference import * # noqa: F403
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)

 from .config import (
     Fp8QuantizationConfig,

@@ -50,14 +53,6 @@ from .config import (
 log = logging.getLogger(__name__)


-class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
-    messages: List[RawMessage]
-
-
-class CompletionRequestWithRawContent(CompletionRequest):
-    content: RawContent
-
-
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))

@@ -12,7 +12,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.datatypes import Model

 from llama_models.llama3.api.datatypes import (
-    RawMessage,
     SamplingParams,
     StopReason,
     ToolDefinition,

@@ -53,14 +52,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,
     chat_completion_request_to_messages,
-    interleaved_content_convert_to_raw,
+    convert_request_to_raw,
 )
 from .config import MetaReferenceInferenceConfig
-from .generation import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-    Llama,
-)
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator

 log = logging.getLogger(__name__)

@@ -450,20 +445,3 @@ class MetaReferenceInferenceImpl(
         else:
             for x in impl():
                 yield x
-
-
-async def convert_request_to_raw(
-    request: Union[ChatCompletionRequest, CompletionRequest],
-) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
-    if isinstance(request, ChatCompletionRequest):
-        messages = []
-        for m in request.messages:
-            content = await interleaved_content_convert_to_raw(m.content)
-            d = m.model_dump()
-            d["content"] = content
-            messages.append(RawMessage(**d))
-        request.messages = messages
-    else:
-        request.content = await interleaved_content_convert_to_raw(request.content)
-
-    return request

@@ -120,15 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | CompletionResponseStreamChunk:
-        log.info("vLLM completion")
-        messages = [UserMessage(content=content)]
-        return self.chat_completion(
-            model=model_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            stream=stream,
-            logprobs=logprobs,
-        )
+        raise NotImplementedError("Completion not implemented for vLLM")

     async def chat_completion(
         self,

@@ -142,8 +134,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
-        log.info("vLLM chat completion")
-
         assert self.engine is not None

         request = ChatCompletionRequest(

@@ -160,7 +150,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()

-        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.formatter)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id

@@ -94,14 +94,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         r = await self.client.completions.create(**params)

         return process_completion_response(r, self.formatter)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         stream = await self.client.completions.create(**params)


@@ -141,7 +141,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         r = await self.client.completions.create(**params)


@@ -150,7 +150,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _stream_chat_completion(
         self, request: CompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         stream = await self.client.completions.create(**params)


@@ -159,7 +159,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     ):
         yield chunk

-    def _get_params(
+    async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
         if request.sampling_params and request.sampling_params.top_k:

@@ -167,11 +167,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = chat_completion_request_to_prompt(
+            prompt = await chat_completion_request_to_prompt(
                 request, self.get_llama_model(request.model), self.formatter
             )
         elif isinstance(request, CompletionRequest):
-            prompt = completion_request_to_prompt(request, self.formatter)
+            prompt = await completion_request_to_prompt(request, self.formatter)
         else:
             raise ValueError(f"Unknown request type {type(request)}")


@@ -241,14 +241,16 @@ class FireworksInferenceAdapter(
                     await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = chat_completion_request_to_prompt(
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
                 not media_present
             ), "Fireworks does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(
+                request, self.formatter
+            )

         # Fireworks always prepends with BOS
         if "prompt" in input_dict:

@@ -243,7 +243,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 ]
             else:
                 input_dict["raw"] = True
-                input_dict["prompt"] = chat_completion_request_to_prompt(
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
                     self.register_helper.get_llama_model(request.model),
                     self.formatter,

@@ -252,7 +252,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             assert (
                 not media_present
             ), "Ollama does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(
+                request, self.formatter
+            )
             input_dict["raw"] = True

         return {

@@ -130,8 +130,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):

         return options

-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = completion_request_to_prompt_model_input_info(
+    async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(
             request, self.formatter
         )


@@ -147,7 +147,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         )

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params_for_completion(request)

         async def _generate_and_convert_to_openai_compat():
             s = await self.client.text_generation(**params)

@@ -169,7 +169,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params_for_completion(request)
         r = await self.client.text_generation(**params)

         choice = OpenAICompatCompletionChoice(

@@ -216,7 +216,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         r = await self.client.text_generation(**params)

         choice = OpenAICompatCompletionChoice(

@@ -231,7 +231,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
             s = await self.client.text_generation(**params)

@@ -249,8 +249,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     ):
         yield chunk

-    def _get_params(self, request: ChatCompletionRequest) -> dict:
-        prompt, input_tokens = chat_completion_request_to_model_input_info(
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
+        prompt, input_tokens = await chat_completion_request_to_model_input_info(
             request, self.register_helper.get_llama_model(request.model), self.formatter
         )
         return dict(

@@ -233,14 +233,16 @@ class TogetherInferenceAdapter(
                     await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = chat_completion_request_to_prompt(
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
                 not media_present
             ), "Together does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(
+                request, self.formatter
+            )

         return {
             "model": request.model,

@@ -77,7 +77,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError()
+        raise NotImplementedError("Completion not implemented for vLLM")

     async def chat_completion(
         self,

@@ -167,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                     for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = chat_completion_request_to_prompt(
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
                     self.register_helper.get_llama_model(request.model),
                     self.formatter,

@@ -176,7 +176,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             assert (
                 not media_present
             ), "Together does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(
+            input_dict["prompt"] = await completion_request_to_prompt(
                 request,
                 self.register_helper.get_llama_model(request.model),
                 self.formatter,

@@ -20,6 +20,7 @@ from llama_models.llama3.api.datatypes import (
     RawContent,
     RawContentItem,
     RawMediaItem,
+    RawMessage,
     RawTextItem,
     Role,
     ToolPromptFormat,

@@ -58,6 +59,14 @@ from llama_stack.providers.utils.inference import supported_inference_models
 log = logging.getLogger(__name__)


+class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
+    messages: List[RawMessage]
+
+
+class CompletionRequestWithRawContent(CompletionRequest):
+    content: RawContent
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):

@@ -75,6 +84,23 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     return _process(content)


+async def convert_request_to_raw(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
+    if isinstance(request, ChatCompletionRequest):
+        messages = []
+        for m in request.messages:
+            content = await interleaved_content_convert_to_raw(m.content)
+            d = m.model_dump()
+            d["content"] = content
+            messages.append(RawMessage(**d))
+        request.messages = messages
+    else:
+        request.content = await interleaved_content_convert_to_raw(request.content)
+
+    return request
+
+
 async def interleaved_content_convert_to_raw(
     content: InterleavedContent,
 ) -> RawContent:

@@ -169,23 +195,27 @@ async def convert_image_content_to_url(
         return base64.b64encode(content).decode("utf-8")


-def completion_request_to_prompt(
+async def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
     content = augment_content_with_response_format_prompt(
         request.response_format, request.content
     )
-    model_input = formatter.encode_content(content)
+    request.content = content
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)


-def completion_request_to_prompt_model_input_info(
+async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(
         request.response_format, request.content
     )
-    model_input = formatter.encode_content(content)
+    request.content = content
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_content(request.content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))


@@ -199,19 +229,23 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-def chat_completion_request_to_prompt(
+async def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
-    model_input = formatter.encode_dialog_prompt(messages)
+    request.messages = messages
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)


-def chat_completion_request_to_model_input_info(
+async def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
-    model_input = formatter.encode_dialog_prompt(messages)
+    request.messages = messages
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
         len(model_input.tokens),
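
For reference, a minimal usage sketch of the relocated convert_request_to_raw helper (assumptions: a ChatCompletionRequest whose messages were already preprocessed by chat_completion_request_to_messages, a ChatFormat instance as in the hunk above, and encode_chat as a hypothetical name):

# Sketch only: convert messages to RawMessage before encoding, as the updated
# async chat_completion_request_to_prompt does after its preprocessing step.
async def encode_chat(request, formatter) -> str:
    raw_request = await convert_request_to_raw(request)  # messages become List[RawMessage]
    model_input = formatter.encode_dialog_prompt(raw_request.messages)
    return formatter.tokenizer.decode(model_input.tokens)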