forked from phoenix-oss/llama-stack-mirror
		
	## What does this PR do?
This is a long-pending change and particularly important to get done
now.
Specifically:
- we cannot "localize" (aka download) any URLs from media attachments
anywhere near our modeling code. it must be done within llama-stack.
- `PIL.Image` is infesting all our APIs via `ImageMedia ->
InterleavedTextMedia` and that cannot be right at all. Anything in the
API surface must be "naturally serializable". We need a standard `{
type: "image", image_url: "<...>" }` which is more extensible
- `UserMessage`, `SystemMessage`, etc. are moved completely to
llama-stack from the llama-models repository.
See https://github.com/meta-llama/llama-models/pull/244 for the
corresponding PR in llama-models.
## Test Plan
```bash
cd llama_stack/providers/tests
pytest -s -v -k "fireworks or ollama or together" inference/test_vision_inference.py
pytest -s -v -k "(fireworks or ollama or together) and llama_3b" inference/test_text_inference.py
pytest -s -v -k chroma memory/test_memory.py \
  --env EMBEDDING_DIMENSION=384 --env CHROMA_DB_PATH=/tmp/foobar
pytest -s -v -k fireworks agents/test_agents.py  \
   --safety-shield=meta-llama/Llama-Guard-3-8B \
   --inference-model=meta-llama/Llama-3.1-8B-Instruct
```
Updated the client sdk (see PR ...), installed the SDK in the same
environment and then ran the SDK tests:
```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=together pytest -s -v agents/test_agents.py
LLAMA_STACK_CONFIG=ollama pytest -s -v memory/test_memory.py
# this one needed a bit of hacking in the run.yaml to ensure I could register the vision model correctly
INFERENCE_MODEL=llama3.2-vision:latest LLAMA_STACK_CONFIG=ollama pytest -s -v inference/test_inference.py
```
		
	
			
		
			
				
	
	
		
			469 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			469 lines
		
	
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import asyncio
 | |
| import logging
 | |
| 
 | |
| from typing import AsyncGenerator, List, Optional, Union
 | |
| 
 | |
| from llama_models.datatypes import Model
 | |
| 
 | |
| from llama_models.llama3.api.datatypes import (
 | |
|     RawMessage,
 | |
|     SamplingParams,
 | |
|     StopReason,
 | |
|     ToolDefinition,
 | |
|     ToolPromptFormat,
 | |
| )
 | |
| from llama_models.sku_list import resolve_model
 | |
| 
 | |
| from llama_stack.apis.inference import (
 | |
|     ChatCompletionRequest,
 | |
|     ChatCompletionResponse,
 | |
|     ChatCompletionResponseEvent,
 | |
|     ChatCompletionResponseEventType,
 | |
|     ChatCompletionResponseStreamChunk,
 | |
|     CompletionMessage,
 | |
|     CompletionRequest,
 | |
|     CompletionResponse,
 | |
|     CompletionResponseStreamChunk,
 | |
|     Inference,
 | |
|     InterleavedContent,
 | |
|     LogProbConfig,
 | |
|     Message,
 | |
|     ResponseFormat,
 | |
|     TokenLogProbs,
 | |
|     ToolCallDelta,
 | |
|     ToolCallParseStatus,
 | |
|     ToolChoice,
 | |
| )
 | |
| 
 | |
| from llama_stack.apis.models import ModelType
 | |
| from llama_stack.providers.datatypes import ModelsProtocolPrivate
 | |
| from llama_stack.providers.utils.inference.embedding_mixin import (
 | |
|     SentenceTransformerEmbeddingMixin,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.model_registry import (
 | |
|     build_model_alias,
 | |
|     ModelRegistryHelper,
 | |
| )
 | |
| from llama_stack.providers.utils.inference.prompt_adapter import (
 | |
|     augment_content_with_response_format_prompt,
 | |
|     chat_completion_request_to_messages,
 | |
|     interleaved_content_convert_to_raw,
 | |
| )
 | |
| from .config import MetaReferenceInferenceConfig
 | |
| from .generation import (
 | |
|     ChatCompletionRequestWithRawContent,
 | |
|     CompletionRequestWithRawContent,
 | |
|     Llama,
 | |
| )
 | |
| from .model_parallel import LlamaModelParallelGenerator
 | |
| 
 | |
| log = logging.getLogger(__name__)
 | |
| # there's a single model parallel process running serving the model. for now,
 | |
| # we don't support multiple concurrent requests to this process.
 | |
| SEMAPHORE = asyncio.Semaphore(1)
 | |
| 
 | |
| 
 | |
| class MetaReferenceInferenceImpl(
 | |
|     SentenceTransformerEmbeddingMixin,
 | |
|     Inference,
 | |
|     ModelsProtocolPrivate,
 | |
| ):
 | |
|     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
 | |
|         self.config = config
 | |
|         model = resolve_model(config.model)
 | |
|         if model is None:
 | |
|             raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
 | |
|         self.model_registry_helper = ModelRegistryHelper(
 | |
|             [
 | |
|                 build_model_alias(
 | |
|                     model.descriptor(),
 | |
|                     model.core_model_id.value,
 | |
|                 )
 | |
|             ],
 | |
|         )
 | |
|         self.model = model
 | |
|         # verify that the checkpoint actually is for this model lol
 | |
| 
 | |
|     async def initialize(self) -> None:
 | |
|         log.info(f"Loading model `{self.model.descriptor()}`")
 | |
|         if self.config.create_distributed_process_group:
 | |
|             self.generator = LlamaModelParallelGenerator(self.config)
 | |
|             self.generator.start()
 | |
|         else:
 | |
|             self.generator = Llama.build(self.config)
 | |
| 
 | |
|     async def shutdown(self) -> None:
 | |
|         if self.config.create_distributed_process_group:
 | |
|             self.generator.stop()
 | |
| 
 | |
|     def check_model(self, request) -> None:
 | |
|         model = resolve_model(request.model)
 | |
|         if model is None:
 | |
|             raise RuntimeError(
 | |
|                 f"Unknown model: {request.model}, Run `llama model list`"
 | |
|             )
 | |
|         elif model.descriptor() != self.model.descriptor():
 | |
|             raise RuntimeError(
 | |
|                 f"Model mismatch: {request.model} != {self.model.descriptor()}"
 | |
|             )
 | |
| 
 | |
|     async def unregister_model(self, model_id: str) -> None:
 | |
|         pass
 | |
| 
 | |
|     async def register_model(self, model: Model) -> Model:
 | |
|         model = await self.model_registry_helper.register_model(model)
 | |
|         if model.model_type == ModelType.embedding:
 | |
|             self._load_sentence_transformer_model(model.provider_resource_id)
 | |
|         return model
 | |
| 
 | |
|     async def completion(
 | |
|         self,
 | |
|         model_id: str,
 | |
|         content: InterleavedContent,
 | |
|         sampling_params: Optional[SamplingParams] = SamplingParams(),
 | |
|         response_format: Optional[ResponseFormat] = None,
 | |
|         stream: Optional[bool] = False,
 | |
|         logprobs: Optional[LogProbConfig] = None,
 | |
|     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
 | |
|         if logprobs:
 | |
|             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 | |
| 
 | |
|         content = augment_content_with_response_format_prompt(response_format, content)
 | |
|         request = CompletionRequest(
 | |
|             model=model_id,
 | |
|             content=content,
 | |
|             sampling_params=sampling_params,
 | |
|             response_format=response_format,
 | |
|             stream=stream,
 | |
|             logprobs=logprobs,
 | |
|         )
 | |
|         self.check_model(request)
 | |
|         request = await convert_request_to_raw(request)
 | |
| 
 | |
|         if request.stream:
 | |
|             return self._stream_completion(request)
 | |
|         else:
 | |
|             return await self._nonstream_completion(request)
 | |
| 
 | |
|     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
 | |
|         def impl():
 | |
|             stop_reason = None
 | |
| 
 | |
|             for token_result in self.generator.completion(request):
 | |
|                 if token_result.text == "<|eot_id|>":
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
|                     text = ""
 | |
|                 elif token_result.text == "<|eom_id|>":
 | |
|                     stop_reason = StopReason.end_of_message
 | |
|                     text = ""
 | |
|                 else:
 | |
|                     text = token_result.text
 | |
| 
 | |
|                 logprobs = None
 | |
|                 if stop_reason is None:
 | |
|                     if request.logprobs:
 | |
|                         assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                         logprobs = [
 | |
|                             TokenLogProbs(
 | |
|                                 logprobs_by_token={
 | |
|                                     token_result.text: token_result.logprobs[0]
 | |
|                                 }
 | |
|                             )
 | |
|                         ]
 | |
| 
 | |
|                 yield CompletionResponseStreamChunk(
 | |
|                     delta=text,
 | |
|                     stop_reason=stop_reason,
 | |
|                     logprobs=logprobs if request.logprobs else None,
 | |
|                 )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 yield CompletionResponseStreamChunk(
 | |
|                     delta="",
 | |
|                     stop_reason=StopReason.out_of_tokens,
 | |
|                 )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 for x in impl():
 | |
|                     yield x
 | |
|         else:
 | |
|             for x in impl():
 | |
|                 yield x
 | |
| 
 | |
|     async def _nonstream_completion(
 | |
|         self, request: CompletionRequest
 | |
|     ) -> CompletionResponse:
 | |
|         def impl():
 | |
|             tokens = []
 | |
|             logprobs = []
 | |
|             stop_reason = None
 | |
| 
 | |
|             tokenizer = self.generator.formatter.tokenizer
 | |
|             for token_result in self.generator.completion(request):
 | |
|                 tokens.append(token_result.token)
 | |
| 
 | |
|                 if token_result.token in tokenizer.stop_tokens:
 | |
|                     # not quite right semantically
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
| 
 | |
|                 if request.logprobs:
 | |
|                     assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                     logprobs.append(
 | |
|                         TokenLogProbs(
 | |
|                             logprobs_by_token={
 | |
|                                 token_result.text: token_result.logprobs[0]
 | |
|                             }
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 stop_reason = StopReason.out_of_tokens
 | |
| 
 | |
|             content = self.generator.formatter.tokenizer.decode(tokens)
 | |
|             return CompletionResponse(
 | |
|                 content=content,
 | |
|                 stop_reason=stop_reason,
 | |
|                 logprobs=logprobs if request.logprobs else None,
 | |
|             )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 return impl()
 | |
|         else:
 | |
|             return impl()
 | |
| 
 | |
|     async def chat_completion(
 | |
|         self,
 | |
|         model_id: str,
 | |
|         messages: List[Message],
 | |
|         sampling_params: Optional[SamplingParams] = SamplingParams(),
 | |
|         response_format: Optional[ResponseFormat] = None,
 | |
|         tools: Optional[List[ToolDefinition]] = None,
 | |
|         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
 | |
|         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
 | |
|         stream: Optional[bool] = False,
 | |
|         logprobs: Optional[LogProbConfig] = None,
 | |
|     ) -> AsyncGenerator:
 | |
|         if logprobs:
 | |
|             assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
 | |
| 
 | |
|         # wrapper request to make it easier to pass around (internal only, not exposed to API)
 | |
|         request = ChatCompletionRequest(
 | |
|             model=model_id,
 | |
|             messages=messages,
 | |
|             sampling_params=sampling_params,
 | |
|             tools=tools or [],
 | |
|             tool_choice=tool_choice,
 | |
|             tool_prompt_format=tool_prompt_format,
 | |
|             response_format=response_format,
 | |
|             stream=stream,
 | |
|             logprobs=logprobs,
 | |
|         )
 | |
|         self.check_model(request)
 | |
| 
 | |
|         # augment and rewrite messages depending on the model
 | |
|         request.messages = chat_completion_request_to_messages(
 | |
|             request, self.model.core_model_id.value
 | |
|         )
 | |
|         # download media and convert to raw content so we can send it to the model
 | |
|         request = await convert_request_to_raw(request)
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             if SEMAPHORE.locked():
 | |
|                 raise RuntimeError("Only one concurrent request is supported")
 | |
| 
 | |
|         if request.stream:
 | |
|             return self._stream_chat_completion(request)
 | |
|         else:
 | |
|             return await self._nonstream_chat_completion(request)
 | |
| 
 | |
|     async def _nonstream_chat_completion(
 | |
|         self, request: ChatCompletionRequest
 | |
|     ) -> ChatCompletionResponse:
 | |
|         def impl():
 | |
|             tokens = []
 | |
|             logprobs = []
 | |
|             stop_reason = None
 | |
| 
 | |
|             for token_result in self.generator.chat_completion(request):
 | |
|                 tokens.append(token_result.token)
 | |
| 
 | |
|                 if token_result.text == "<|eot_id|>":
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
|                 elif token_result.text == "<|eom_id|>":
 | |
|                     stop_reason = StopReason.end_of_message
 | |
| 
 | |
|                 if request.logprobs:
 | |
|                     assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                     logprobs.append(
 | |
|                         TokenLogProbs(
 | |
|                             logprobs_by_token={
 | |
|                                 token_result.text: token_result.logprobs[0]
 | |
|                             }
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 stop_reason = StopReason.out_of_tokens
 | |
| 
 | |
|             raw_message = self.generator.formatter.decode_assistant_message(
 | |
|                 tokens, stop_reason
 | |
|             )
 | |
|             return ChatCompletionResponse(
 | |
|                 completion_message=CompletionMessage(
 | |
|                     content=raw_message.content,
 | |
|                     stop_reason=raw_message.stop_reason,
 | |
|                     tool_calls=raw_message.tool_calls,
 | |
|                 ),
 | |
|                 logprobs=logprobs if request.logprobs else None,
 | |
|             )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 return impl()
 | |
|         else:
 | |
|             return impl()
 | |
| 
 | |
|     async def _stream_chat_completion(
 | |
|         self, request: ChatCompletionRequest
 | |
|     ) -> AsyncGenerator:
 | |
|         def impl():
 | |
|             yield ChatCompletionResponseStreamChunk(
 | |
|                 event=ChatCompletionResponseEvent(
 | |
|                     event_type=ChatCompletionResponseEventType.start,
 | |
|                     delta="",
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|             tokens = []
 | |
|             logprobs = []
 | |
|             stop_reason = None
 | |
|             ipython = False
 | |
| 
 | |
|             for token_result in self.generator.chat_completion(request):
 | |
|                 tokens.append(token_result.token)
 | |
| 
 | |
|                 if not ipython and token_result.text.startswith("<|python_tag|>"):
 | |
|                     ipython = True
 | |
|                     yield ChatCompletionResponseStreamChunk(
 | |
|                         event=ChatCompletionResponseEvent(
 | |
|                             event_type=ChatCompletionResponseEventType.progress,
 | |
|                             delta=ToolCallDelta(
 | |
|                                 content="",
 | |
|                                 parse_status=ToolCallParseStatus.started,
 | |
|                             ),
 | |
|                         )
 | |
|                     )
 | |
|                     continue
 | |
| 
 | |
|                 if token_result.text == "<|eot_id|>":
 | |
|                     stop_reason = StopReason.end_of_turn
 | |
|                     text = ""
 | |
|                 elif token_result.text == "<|eom_id|>":
 | |
|                     stop_reason = StopReason.end_of_message
 | |
|                     text = ""
 | |
|                 else:
 | |
|                     text = token_result.text
 | |
| 
 | |
|                 if ipython:
 | |
|                     delta = ToolCallDelta(
 | |
|                         content=text,
 | |
|                         parse_status=ToolCallParseStatus.in_progress,
 | |
|                     )
 | |
|                 else:
 | |
|                     delta = text
 | |
| 
 | |
|                 if stop_reason is None:
 | |
|                     if request.logprobs:
 | |
|                         assert len(token_result.logprobs) == 1
 | |
| 
 | |
|                         logprobs.append(
 | |
|                             TokenLogProbs(
 | |
|                                 logprobs_by_token={
 | |
|                                     token_result.text: token_result.logprobs[0]
 | |
|                                 }
 | |
|                             )
 | |
|                         )
 | |
|                     yield ChatCompletionResponseStreamChunk(
 | |
|                         event=ChatCompletionResponseEvent(
 | |
|                             event_type=ChatCompletionResponseEventType.progress,
 | |
|                             delta=delta,
 | |
|                             stop_reason=stop_reason,
 | |
|                             logprobs=logprobs if request.logprobs else None,
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|             if stop_reason is None:
 | |
|                 stop_reason = StopReason.out_of_tokens
 | |
| 
 | |
|             message = self.generator.formatter.decode_assistant_message(
 | |
|                 tokens, stop_reason
 | |
|             )
 | |
| 
 | |
|             parsed_tool_calls = len(message.tool_calls) > 0
 | |
|             if ipython and not parsed_tool_calls:
 | |
|                 yield ChatCompletionResponseStreamChunk(
 | |
|                     event=ChatCompletionResponseEvent(
 | |
|                         event_type=ChatCompletionResponseEventType.progress,
 | |
|                         delta=ToolCallDelta(
 | |
|                             content="",
 | |
|                             parse_status=ToolCallParseStatus.failure,
 | |
|                         ),
 | |
|                         stop_reason=stop_reason,
 | |
|                     )
 | |
|                 )
 | |
| 
 | |
|             for tool_call in message.tool_calls:
 | |
|                 yield ChatCompletionResponseStreamChunk(
 | |
|                     event=ChatCompletionResponseEvent(
 | |
|                         event_type=ChatCompletionResponseEventType.progress,
 | |
|                         delta=ToolCallDelta(
 | |
|                             content=tool_call,
 | |
|                             parse_status=ToolCallParseStatus.success,
 | |
|                         ),
 | |
|                         stop_reason=stop_reason,
 | |
|                     )
 | |
|                 )
 | |
| 
 | |
|             yield ChatCompletionResponseStreamChunk(
 | |
|                 event=ChatCompletionResponseEvent(
 | |
|                     event_type=ChatCompletionResponseEventType.complete,
 | |
|                     delta="",
 | |
|                     stop_reason=stop_reason,
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|         if self.config.create_distributed_process_group:
 | |
|             async with SEMAPHORE:
 | |
|                 for x in impl():
 | |
|                     yield x
 | |
|         else:
 | |
|             for x in impl():
 | |
|                 yield x
 | |
| 
 | |
| 
 | |
| async def convert_request_to_raw(
 | |
|     request: Union[ChatCompletionRequest, CompletionRequest],
 | |
| ) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
 | |
|     if isinstance(request, ChatCompletionRequest):
 | |
|         messages = []
 | |
|         for m in request.messages:
 | |
|             content = await interleaved_content_convert_to_raw(m.content)
 | |
|             d = m.model_dump()
 | |
|             d["content"] = content
 | |
|             messages.append(RawMessage(**d))
 | |
|         request.messages = messages
 | |
|     else:
 | |
|         request.content = await interleaved_content_convert_to_raw(request.content)
 | |
| 
 | |
|     return request
 |