mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 11:18:40 +00:00
Merge branch 'main' into inference_refactor
This commit is contained in:
commit
fadb7deae5
79 changed files with 1547 additions and 2026 deletions
|
|
@ -25,7 +25,10 @@ from llama_stack.apis.memory import * # noqa: F403
|
|||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .persistence import AgentPersistence
|
||||
|
|
@ -239,13 +242,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||
# final boolean (to see whether an exception happened) and then explicitly testing for it.
|
||||
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, input_messages, self.input_shields, "user-input"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
if len(self.input_shields) > 0:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, input_messages, self.input_shields, "user-input"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
session_id, turn_id, input_messages, attachments, sampling_params, stream
|
||||
|
|
@ -262,13 +266,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# for output shields run on the full input and output combination
|
||||
messages = input_messages + [final_response]
|
||||
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, messages, self.output_shields, "assistant-output"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
if len(self.output_shields) > 0:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, messages, self.output_shields, "assistant-output"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
else:
|
||||
yield res
|
||||
|
||||
yield final_response
|
||||
|
||||
|
|
@ -387,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if rag_context:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = "\n".join(rag_context)
|
||||
last_message.context = rag_context
|
||||
|
||||
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
|
||||
urls = [a.content for a in attachments if isinstance(a.content, URL)]
|
||||
|
|
@ -531,106 +536,72 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages = input_messages + [message]
|
||||
else:
|
||||
log.info(f"{str(message)}")
|
||||
try:
|
||||
tool_call = message.tool_calls[0]
|
||||
tool_call = message.tool_calls[0]
|
||||
|
||||
name = tool_call.tool_name
|
||||
if not isinstance(name, BuiltinTool):
|
||||
yield message
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
[message],
|
||||
)
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
content=result_message.content,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=str(uuid.uuid4()),
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=str(uuid.uuid4()),
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
name = tool_call.tool_name
|
||||
if not isinstance(name, BuiltinTool):
|
||||
yield message
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
[message],
|
||||
)
|
||||
assert (
|
||||
len(result_messages) == 1
|
||||
), "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[tool_call],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id=result_message.call_id,
|
||||
tool_name=result_message.tool_name,
|
||||
content=result_message.content,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||
# but that needs a lot more refactoring of Tool code potentially
|
||||
|
||||
if out_attachment := interpret_content_as_attachment(
|
||||
result_message.content
|
||||
):
|
||||
|
|
@ -687,7 +658,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _retrieve_context(
|
||||
self, session_id: str, messages: List[Message], attachments: List[Attachment]
|
||||
) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
|
||||
) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids)
|
||||
bank_ids = []
|
||||
|
||||
memory = self._memory_tool_definition()
|
||||
|
|
@ -755,11 +726,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
break
|
||||
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||
|
||||
return [
|
||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
], bank_ids
|
||||
return (
|
||||
concat_interleaved_content(
|
||||
[
|
||||
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||
*picked,
|
||||
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||
]
|
||||
),
|
||||
bank_ids,
|
||||
)
|
||||
|
||||
def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
|
|
@ -804,7 +780,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
|
|||
else:
|
||||
raise ValueError(f"Unsupported URL {url}")
|
||||
|
||||
content.append(f'# There is a file accessible to you at "{filepath}"\n')
|
||||
content.append(
|
||||
TextContentItem(
|
||||
text=f'# There is a file accessible to you at "{filepath}"\n'
|
||||
)
|
||||
)
|
||||
|
||||
return ToolResponseMessage(
|
||||
call_id="",
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ from llama_stack.apis.agents import (
|
|||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
async def generate_rag_query(
|
||||
|
|
@ -42,7 +45,7 @@ async def default_rag_query_generator(
|
|||
messages: List[Message],
|
||||
**kwargs,
|
||||
):
|
||||
return config.sep.join(interleaved_text_media_as_str(m.content) for m in messages)
|
||||
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
|
||||
|
||||
|
||||
async def llm_rag_query_generator(
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ import logging
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
|||
snippet = match.group(1)
|
||||
data = json.loads(snippet)
|
||||
return Attachment(
|
||||
content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
|
||||
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
|
||||
)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
model_parallel_is_initialized,
|
||||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
|
||||
from llama_models.llama3.api.datatypes import Model
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
|
|
@ -39,8 +39,8 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
|
|||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
|
|
@ -207,7 +207,7 @@ class Llama:
|
|||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_input: ModelInput,
|
||||
model_input: LLMInput,
|
||||
max_gen_len: int,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
|
|
@ -344,7 +344,7 @@ class Llama:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
|
|
@ -355,10 +355,7 @@ class Llama:
|
|||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
content = augment_content_with_response_format_prompt(
|
||||
request.response_format, request.content
|
||||
)
|
||||
model_input = self.formatter.encode_content(content)
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
|
|
@ -375,10 +372,8 @@ class Llama:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
messages = chat_completion_request_to_messages(request, self.llama_model)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if (
|
||||
|
|
@ -390,7 +385,7 @@ class Llama:
|
|||
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
messages,
|
||||
request.messages,
|
||||
request.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
|
|
|
|||
|
|
@ -7,24 +7,50 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
TokenLogProbs,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
convert_request_to_raw,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
|
@ -44,7 +70,8 @@ class MetaReferenceInferenceImpl(
|
|||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
self.config = config
|
||||
self.model = None
|
||||
self.model_id = None
|
||||
self.llama_model = None
|
||||
|
||||
async def initialize(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
|
|
@ -56,20 +83,21 @@ class MetaReferenceInferenceImpl(
|
|||
else:
|
||||
self.generator = Llama.build(self.config, model_id, llama_model)
|
||||
|
||||
self.model = model_id
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model is None:
|
||||
if self.model_id or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
)
|
||||
elif request.model != self.model:
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: request model: {request.model} != loaded model: {self.model}"
|
||||
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
|
||||
)
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
|
|
@ -107,7 +135,7 @@ class MetaReferenceInferenceImpl(
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
|
|
@ -116,6 +144,7 @@ class MetaReferenceInferenceImpl(
|
|||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
content = augment_content_with_response_format_prompt(response_format, content)
|
||||
request = CompletionRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
|
|
@ -125,7 +154,7 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
|
|
@ -250,7 +279,13 @@ class MetaReferenceInferenceImpl(
|
|||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await request_with_localized_media(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(
|
||||
request, self.llama_model.core_model_id.value
|
||||
)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
|
|
@ -291,11 +326,15 @@ class MetaReferenceInferenceImpl(
|
|||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
message = self.generator.formatter.decode_assistant_message(
|
||||
raw_message = self.generator.formatter.decode_assistant_message(
|
||||
tokens, stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=message,
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
|
|
@ -421,31 +460,3 @@ class MetaReferenceInferenceImpl(
|
|||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
|
||||
|
||||
async def request_with_localized_media(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
) -> Union[ChatCompletionRequest, CompletionRequest]:
|
||||
if not request_has_media(request):
|
||||
return request
|
||||
|
||||
async def _convert_single_content(content):
|
||||
if isinstance(content, ImageMedia):
|
||||
url = await convert_image_media_to_url(content, download=True)
|
||||
return ImageMedia(image=URL(uri=url))
|
||||
else:
|
||||
return content
|
||||
|
||||
async def _convert_content(content):
|
||||
if isinstance(content, list):
|
||||
return [await _convert_single_content(c) for c in content]
|
||||
else:
|
||||
return await _convert_single_content(content)
|
||||
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
for m in request.messages:
|
||||
m.content = await _convert_content(m.content)
|
||||
else:
|
||||
request.content = await _convert_content(request.content)
|
||||
|
||||
return request
|
||||
|
|
|
|||
|
|
@ -114,21 +114,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
log.info("vLLM completion")
|
||||
messages = [UserMessage(content=content)]
|
||||
return self.chat_completion(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
raise NotImplementedError("Completion not implemented for vLLM")
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
|
|
@ -142,8 +134,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
log.info("vLLM chat completion")
|
||||
|
||||
assert self.engine is not None
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
|
|
@ -160,7 +150,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
log.info("Sampling params: %s", sampling_params)
|
||||
request_id = _random_uuid()
|
||||
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
prompt = await chat_completion_request_to_prompt(request, self.formatter)
|
||||
vllm_sampling_params = self._sampling_params(request.sampling_params)
|
||||
results_generator = self.engine.generate(
|
||||
prompt, vllm_sampling_params, request_id
|
||||
|
|
@ -218,8 +208,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self, model_id: str, contents: list[InterleavedTextMedia]
|
||||
self, model_id: str, contents: List[InterleavedContent]
|
||||
) -> EmbeddingsResponse:
|
||||
log.info("vLLM embeddings")
|
||||
# TODO
|
||||
raise NotImplementedError()
|
||||
|
|
|
|||
|
|
@ -4,12 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import ChromaInlineImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
|
||||
async def get_provider_impl(
|
||||
config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
|
||||
|
||||
impl = ChromaMemoryAdapter(config)
|
||||
impl = ChromaMemoryAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ from numpy.typing import NDArray
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
|
|
@ -208,7 +209,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
index = self.cache.get(bank_id)
|
||||
|
|
|
|||
|
|
@ -7,13 +7,17 @@
|
|||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_CODE_SCANNER_MODEL_IDS = [
|
||||
"CodeScanner",
|
||||
"CodeShield",
|
||||
|
|
@ -48,7 +52,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
|
||||
from codeshield.cs import CodeShield
|
||||
|
||||
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
||||
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
||||
log.info(f"Running CodeScannerShield on {text[50:]}")
|
||||
result = await CodeShield.scan_code(text)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,13 @@ from typing import Any, Dict, List, Optional
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
|
@ -222,6 +226,8 @@ class LlamaGuardShield:
|
|||
|
||||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
for i, m in enumerate(messages):
|
||||
print(f"{i}: {m.role}: {m.content}")
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
|
||||
)
|
||||
|
|
@ -258,18 +264,18 @@ class LlamaGuardShield:
|
|||
most_recent_img = None
|
||||
|
||||
for m in messages[::-1]:
|
||||
if isinstance(m.content, str):
|
||||
if isinstance(m.content, str) or isinstance(m.content, TextContentItem):
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, ImageMedia):
|
||||
elif isinstance(m.content, ImageContentItem):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = m.content
|
||||
conversation.append(m)
|
||||
elif isinstance(m.content, list):
|
||||
content = []
|
||||
for c in m.content:
|
||||
if isinstance(c, str):
|
||||
if isinstance(c, str) or isinstance(c, TextContentItem):
|
||||
content.append(c)
|
||||
elif isinstance(c, ImageMedia):
|
||||
elif isinstance(c, ImageContentItem):
|
||||
if most_recent_img is None and m.role == Role.user.value:
|
||||
most_recent_img = c
|
||||
content.append(c)
|
||||
|
|
@ -292,7 +298,7 @@ class LlamaGuardShield:
|
|||
categories_str = "\n".join(categories)
|
||||
conversations_str = "\n\n".join(
|
||||
[
|
||||
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
|
||||
f"{m.role.capitalize()}: {interleaved_content_as_str(m.content)}"
|
||||
for m in messages
|
||||
]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ from llama_stack.apis.safety import * # noqa: F403
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
|
|
@ -83,7 +86,7 @@ class PromptGuardShield:
|
|||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
message = messages[-1]
|
||||
text = interleaved_text_media_as_str(message.content)
|
||||
text = interleaved_content_as_str(message.content)
|
||||
|
||||
# run model on messages and return response
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue