mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 16:22:46 +00:00
Merge branch 'meta-llama:main' into qdrant
This commit is contained in:
commit
1575578446
101 changed files with 3310 additions and 722 deletions
|
|
@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
def completion(
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
)
|
||||
return tool_config
|
||||
|
||||
def chat_completion(
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@
|
|||
from .config import DatabricksImplConfig
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
assert isinstance(
|
||||
config, DatabricksImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -19,4 +18,4 @@ class DatabricksImplConfig(BaseModel):
|
|||
api_token: str = Field(
|
||||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -48,10 +48,17 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -77,14 +84,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
|
|
@ -98,7 +105,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def completion(
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -87,14 +87,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
|
|
@ -103,7 +103,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
stream = client.completion.acreate(**params)
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
)
|
||||
|
||||
OLLAMA_SUPPORTED_MODELS = {
|
||||
|
|
@ -33,7 +36,8 @@ OLLAMA_SUPPORTED_MODELS = {
|
|||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
|
||||
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
|
||||
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
|
||||
"Llama-Guard-3-8B": "llama-guard3:8b",
|
||||
"Llama-Guard-3-1B": "llama-guard3:1b",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -84,7 +88,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
return ret
|
||||
|
||||
def completion(
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -92,9 +96,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def chat_completion(
|
||||
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
|
||||
sampling_options = get_sampling_options(request)
|
||||
# This is needed since the Ollama API expects num_predict to be set
|
||||
# for early truncation instead of max_tokens.
|
||||
if sampling_options["max_tokens"] is not None:
|
||||
sampling_options["num_predict"] = sampling_options["max_tokens"]
|
||||
return {
|
||||
"model": OLLAMA_SUPPORTED_MODELS[request.model],
|
||||
"prompt": completion_request_to_prompt(request, self.formatter),
|
||||
"options": sampling_options,
|
||||
"raw": True,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params_for_completion(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
text=chunk["response"],
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params_for_completion(request)
|
||||
r = await self.client.generate(**params)
|
||||
assert isinstance(r, dict)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
text=r["response"],
|
||||
)
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
return process_completion_response(response, self.formatter)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -118,7 +179,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
|
|
@ -143,7 +204,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
return process_chat_completion_response(response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
|
|
@ -163,7 +224,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def completion(
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
|
|
@ -116,7 +116,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
return process_chat_completion_response(response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
|
|
@ -135,7 +135,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
|
|||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -101,14 +101,14 @@ class TogetherInferenceAdapter(
|
|||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, client)
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Together
|
||||
|
|
@ -123,7 +123,7 @@ class TogetherInferenceAdapter(
|
|||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
15
llama_stack/providers/adapters/inference/vllm/__init__.py
Normal file
15
llama_stack/providers/adapters/inference/vllm/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import VLLMImplConfig
|
||||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMImplConfig, _deps):
|
||||
assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = VLLMInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
22
llama_stack/providers/adapters/inference/vllm/config.py
Normal file
22
llama_stack/providers/adapters/inference/vllm/config.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMImplConfig(BaseModel):
|
||||
url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The API token",
|
||||
)
|
||||
152
llama_stack/providers/adapters/inference/vllm/vllm.py
Normal file
152
llama_stack/providers/adapters/inference/vllm/vllm.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMImplConfig
|
||||
|
||||
VLLM_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
|
||||
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
|
||||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
|
||||
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
|
||||
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
|
||||
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
|
||||
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
|
||||
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
|
||||
}
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, config: VLLMImplConfig) -> None:
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Model registration is not supported for vLLM models")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return [
|
||||
ModelDef(identifier=model.id, llama_model=model.id)
|
||||
for model in self.client.models.list()
|
||||
]
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, self.client)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, self.client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(request, r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
|
||||
# generator so this wrapper is not necessary?
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": VLLM_SUPPORTED_MODELS[request.model],
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stop_reason = None
|
||||
|
||||
with tracing.span("inference"):
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=self._get_tools(),
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id=session_id,
|
||||
)
|
||||
|
||||
def create_agent_turn(
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
|
|
|
|||
|
|
@ -17,13 +17,22 @@ from llama_stack.providers.utils.inference import supported_inference_models
|
|||
|
||||
class MetaReferenceInferenceConfig(BaseModel):
|
||||
model: str = Field(
|
||||
default="Llama3.1-8B-Instruct",
|
||||
default="Llama3.2-3B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int = 4096
|
||||
max_batch_size: int = 1
|
||||
|
||||
# when this is False, we assume that the distributed process group is setup by someone
|
||||
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
|
||||
# (including our testing code) who might be using llama-stack as a library.
|
||||
create_distributed_process_group: bool = True
|
||||
|
||||
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
|
||||
# can override by specifying the directory explicitly
|
||||
checkpoint_dir: Optional[str] = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
|
|
|
|||
|
|
@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
)
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
InterleavedTextMedia,
|
||||
Message,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.llama3.reference_impl.model import Transformer
|
||||
from llama_models.llama3.reference_impl.multimodal.model import (
|
||||
|
|
@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
|
|||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
|
@ -98,7 +97,10 @@ class Llama:
|
|||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
if config.checkpoint_dir:
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
|
|
@ -119,9 +121,7 @@ class Llama:
|
|||
**params,
|
||||
)
|
||||
|
||||
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
|
@ -138,7 +138,7 @@ class Llama:
|
|||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_quantized_model(model, config)
|
||||
model = convert_to_quantized_model(model, config, ckpt_dir)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
|
@ -170,14 +170,16 @@ class Llama:
|
|||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
include_stop_token: bool = False,
|
||||
print_input_tokens: bool = False,
|
||||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
# input_tokens = [
|
||||
# self.formatter.vision_token if t == 128256 else t
|
||||
# for t in model_input.tokens
|
||||
# ]
|
||||
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
|
||||
if print_input_tokens:
|
||||
input_tokens = [
|
||||
self.formatter.vision_token if t == 128256 else t
|
||||
for t in model_input.tokens
|
||||
]
|
||||
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
|
|
@ -228,8 +230,7 @@ class Llama:
|
|||
ignore_index=pad_id,
|
||||
)
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(
|
||||
|
|
@ -295,15 +296,12 @@ class Llama:
|
|||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def text_completion(
|
||||
def completion(
|
||||
self,
|
||||
content: InterleavedTextMedia,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
request: CompletionRequest,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
|
|
@ -311,26 +309,25 @@ class Llama:
|
|||
):
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
model_input = self.formatter.encode_content(content)
|
||||
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[Message],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Generator:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if (
|
||||
max_gen_len is None
|
||||
or max_gen_len == 0
|
||||
|
|
@ -341,12 +338,12 @@ class Llama:
|
|||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
messages,
|
||||
tool_prompt_format,
|
||||
request.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=logprobs,
|
||||
temperature=sampling_params.temperature,
|
||||
top_p=sampling_params.top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,11 +13,9 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
# there's a single model parallel process running serving the model. for now,
|
||||
|
|
@ -36,8 +34,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Loading model `{self.model.descriptor()}`")
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(self.config)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = Llama.build(self.config)
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError("Dynamic model registration is not supported")
|
||||
|
|
@ -51,9 +52,21 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
]
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.generator.stop()
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
def completion(
|
||||
def check_model(self, request) -> None:
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Unknown model: {request.model}, Run `llama model list`"
|
||||
)
|
||||
elif model.descriptor() != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -61,9 +74,114 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
def chat_completion(
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
def impl():
|
||||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.completion(request):
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
logprobs = None
|
||||
if stop_reason is None:
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs = [
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def _nonstream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> CompletionResponse:
|
||||
def impl():
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
for token_result in self.generator.completion(request):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.token in tokenizer.stop_tokens:
|
||||
# not quite right semantically
|
||||
stop_reason = StopReason.end_of_turn
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(
|
||||
TokenLogProbs(
|
||||
logprobs_by_token={
|
||||
token_result.text: token_result.logprobs[0]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
content = self.generator.formatter.tokenizer.decode(tokens)
|
||||
return CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
|
|
@ -88,43 +206,26 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
model = resolve_model(request.model)
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Unknown model: {request.model}, Run `llama model list`"
|
||||
)
|
||||
elif model.descriptor() != self.model.descriptor():
|
||||
raise RuntimeError(
|
||||
f"Model mismatch: {request.model} != {self.model.descriptor()}"
|
||||
)
|
||||
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request)
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
async with SEMAPHORE:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
def impl():
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
messages=messages,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
max_gen_len=request.sampling_params.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
|
|
@ -154,12 +255,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
return impl()
|
||||
else:
|
||||
return impl()
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
async with SEMAPHORE:
|
||||
messages = chat_completion_request_to_messages(request)
|
||||
|
||||
def impl():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
|
|
@ -172,14 +277,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
stop_reason = None
|
||||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(
|
||||
messages=messages,
|
||||
temperature=request.sampling_params.temperature,
|
||||
top_p=request.sampling_params.top_p,
|
||||
max_gen_len=request.sampling_params.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
tool_prompt_format=request.tool_prompt_format,
|
||||
):
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
|
|
@ -272,6 +370,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
for x in impl():
|
||||
yield x
|
||||
else:
|
||||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
|||
|
|
@ -7,16 +7,17 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Generator, List, Optional
|
||||
from typing import Any, Generator
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama, model_checkpoint_dir
|
||||
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
|
|
@ -24,15 +25,13 @@ class ModelRunner:
|
|||
self.llama = llama
|
||||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, task: InferenceArgs):
|
||||
return self.llama.chat_completion(
|
||||
task.messages,
|
||||
task.temperature,
|
||||
task.top_p,
|
||||
task.max_gen_len,
|
||||
task.logprobs,
|
||||
task.tool_prompt_format,
|
||||
)
|
||||
def __call__(self, req: Any):
|
||||
if isinstance(req, ChatCompletionRequest):
|
||||
return self.llama.chat_completion(req)
|
||||
elif isinstance(req, CompletionRequest):
|
||||
return self.llama.completion(req)
|
||||
else:
|
||||
raise ValueError(f"Unexpected task type {type(req)}")
|
||||
|
||||
|
||||
def init_model_cb(config: MetaReferenceInferenceConfig):
|
||||
|
|
@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
|
|||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self.group.stop()
|
||||
|
||||
def chat_completion(
|
||||
def completion(
|
||||
self,
|
||||
messages: List[Message],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
request: CompletionRequest,
|
||||
) -> Generator:
|
||||
req_obj = InferenceArgs(
|
||||
messages=deepcopy(messages),
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs or False,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
yield from gen
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
yield from gen
|
||||
|
|
|
|||
|
|
@ -4,6 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
|
|
@ -11,10 +17,9 @@ import tempfile
|
|||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Callable, Generator, List, Literal, Optional, Union
|
||||
from typing import Callable, Generator, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import zmq
|
||||
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
|
|
@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
get_model_parallel_src_rank,
|
||||
)
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
|
||||
from .generation import TokenResult
|
||||
|
||||
|
||||
class InferenceArgs(BaseModel):
|
||||
messages: List[Message]
|
||||
temperature: float
|
||||
top_p: float
|
||||
max_gen_len: int
|
||||
logprobs: bool
|
||||
tool_prompt_format: ToolPromptFormat
|
||||
|
||||
|
||||
class ProcessingMessageName(str, Enum):
|
||||
ready_request = "ready_request"
|
||||
ready_response = "ready_response"
|
||||
|
|
@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
|
|||
type: Literal[ProcessingMessageName.task_request] = (
|
||||
ProcessingMessageName.task_request
|
||||
)
|
||||
task: InferenceArgs
|
||||
task: Union[CompletionRequest, ChatCompletionRequest]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
|
|
@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
|
|||
self.process.join()
|
||||
self.started = False
|
||||
|
||||
def run_inference(self, inference_args: InferenceArgs) -> Generator:
|
||||
def run_inference(
|
||||
self, req: Union[CompletionRequest, ChatCompletionRequest]
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
self.running = True
|
||||
self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
|
||||
self.request_socket.send(encode_msg(TaskRequest(task=req)))
|
||||
try:
|
||||
while True:
|
||||
obj_json = self.request_socket.recv()
|
||||
|
|
|
|||
|
|
@ -13,9 +13,10 @@ from typing import Optional
|
|||
import torch
|
||||
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
|
||||
|
|
@ -39,6 +40,7 @@ def swiglu_wrapper(
|
|||
def convert_to_quantized_model(
|
||||
model: Transformer,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
checkpoint_dir: str,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
) -> Transformer:
|
||||
if config.quantization.type == QuantizationType.bf16.value:
|
||||
|
|
@ -49,12 +51,14 @@ def convert_to_quantized_model(
|
|||
|
||||
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
|
||||
|
||||
checkpoint = config.checkpoint_config.checkpoint
|
||||
llama_model = resolve_model(config.model)
|
||||
assert llama_model is not None, f"Model {config.model} not found"
|
||||
|
||||
# Move weights to GPU with quantization
|
||||
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
cprint("Loading fp8 scales...", "yellow")
|
||||
fp8_scales_path = os.path.join(
|
||||
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
|
||||
)
|
||||
assert os.path.isfile(
|
||||
fp8_scales_path
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
for i in range(1, len(messages)):
|
||||
if messages[i].role == messages[i - 1].role:
|
||||
raise ValueError(
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
|
||||
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
|
||||
)
|
||||
return messages
|
||||
|
||||
|
|
@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||
content = ""
|
||||
async for chunk in self.inference_api.chat_completion(
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=True,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
|||
if self.engine:
|
||||
self.engine.shutdown_background_loop()
|
||||
|
||||
def completion(
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
|
|
@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[Message],
|
||||
|
|
@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
|||
if stream:
|
||||
return self._stream_chat_completion(request, results_generator)
|
||||
else:
|
||||
return self._nonstream_chat_completion(request, results_generator)
|
||||
return await self._nonstream_chat_completion(request, results_generator)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
|
|
@ -207,7 +207,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
|||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
return process_chat_completion_response(request, response, self.formatter)
|
||||
return process_chat_completion_response(response, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
|
|
@ -229,7 +229,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
|||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
|||
|
|
@ -55,11 +55,20 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
pip_packages=["ollama"],
|
||||
pip_packages=["ollama", "aiohttp"],
|
||||
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.adapters.inference.ollama",
|
||||
),
|
||||
),
|
||||
# remote_provider_spec(
|
||||
# api=Api.inference,
|
||||
# adapter=AdapterSpec(
|
||||
# adapter_type="vllm",
|
||||
# pip_packages=["openai"],
|
||||
# module="llama_stack.providers.adapters.inference.vllm",
|
||||
# config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
|
||||
# ),
|
||||
# ),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
|||
|
|
@ -31,4 +31,4 @@ providers:
|
|||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: /Users/ashwin/.llama/runtime/kvstore.db
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
|
|
|
|||
|
|
@ -64,6 +64,24 @@ def search_query_messages():
|
|||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def attachment_message():
|
||||
return [
|
||||
UserMessage(
|
||||
content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def query_attachment_messages():
|
||||
return [
|
||||
UserMessage(
|
||||
content="What are the top 5 topics that were explained? Only list succinct bullet points."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn(agents_settings, sample_messages):
|
||||
agents_impl = agents_settings["impl"]
|
||||
|
|
@ -98,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
|
|||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
|
@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages):
|
|||
assert len(final_event.turn.output_message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_agent_as_attachments(
|
||||
agents_settings, attachment_message, query_attachment_messages
|
||||
):
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
||||
attachments = [
|
||||
Attachment(
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
agents_impl = agents_settings["impl"]
|
||||
|
||||
agent_config = AgentConfig(
|
||||
model=agents_settings["common_params"]["model"],
|
||||
instructions=agents_settings["common_params"]["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[
|
||||
MemoryToolDefinition(
|
||||
memory_bank_configs=[],
|
||||
query_generator_config={
|
||||
"type": "default",
|
||||
"sep": " ",
|
||||
},
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=10,
|
||||
),
|
||||
],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
# Create a second turn querying the agent
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=query_attachment_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_brave_search(
|
||||
agents_settings, search_query_messages
|
||||
|
|
@ -169,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
|
|||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
|
|
|||
|
|
@ -4,6 +4,10 @@ providers:
|
|||
config:
|
||||
host: localhost
|
||||
port: 11434
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Llama3.2-1B-Instruct
|
||||
- provider_id: test-tgi
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
|
@ -50,14 +51,17 @@ def get_expected_stop_reason(model: str):
|
|||
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
|
||||
|
||||
|
||||
if "MODEL_IDS" not in os.environ:
|
||||
MODEL_IDS = [Llama_8B, Llama_3B]
|
||||
else:
|
||||
MODEL_IDS = os.environ["MODEL_IDS"].split(",")
|
||||
|
||||
|
||||
# This is going to create multiple Stack impls without tearing down the previous one
|
||||
# Fix that!
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=[
|
||||
{"model": Llama_8B},
|
||||
{"model": Llama_3B},
|
||||
],
|
||||
params=[{"model": m} for m in MODEL_IDS],
|
||||
ids=lambda d: d["model"],
|
||||
)
|
||||
async def inference_settings(request):
|
||||
|
|
@ -122,6 +126,48 @@ async def test_model_list(inference_settings):
|
|||
assert model_def.identifier == params["model"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion(inference_settings):
|
||||
inference_impl = inference_settings["impl"]
|
||||
params = inference_settings["common_params"]
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(params["model"])
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::ollama",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
response = await inference_impl.completion(
|
||||
content="Roses are red,",
|
||||
stream=False,
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert "violets are blue" in response.content
|
||||
|
||||
chunks = [
|
||||
r
|
||||
async for r in await inference_impl.completion(
|
||||
content="Roses are red,",
|
||||
stream=True,
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
|
||||
assert len(chunks) == 51
|
||||
last = chunks[-1]
|
||||
assert last.stop_reason == StopReason.out_of_tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
inference_impl = inference_settings["impl"]
|
||||
|
|
@ -142,7 +188,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
|
|||
inference_impl = inference_settings["impl"]
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
async for r in await inference_impl.chat_completion(
|
||||
messages=sample_messages,
|
||||
stream=True,
|
||||
**inference_settings["common_params"],
|
||||
|
|
@ -213,7 +259,7 @@ async def test_chat_completion_with_tool_calling_streaming(
|
|||
|
||||
response = [
|
||||
r
|
||||
async for r in inference_impl.chat_completion(
|
||||
async for r in await inference_impl.chat_completion(
|
||||
messages=messages,
|
||||
tools=[sample_tool_definition],
|
||||
stream=True,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ providers:
|
|||
- provider_id: test-faiss
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
- provider_id: test-chroma
|
||||
provider_type: remote::chroma
|
||||
- provider_id: test-chromadb
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
host: localhost
|
||||
port: 6001
|
||||
|
|
|
|||
|
|
@ -89,6 +89,30 @@ async def test_banks_list(memory_settings):
|
|||
assert len(response) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_register(memory_settings):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
banks_impl = memory_settings["memory_banks_impl"]
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank_no_provider",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
|
||||
# register same memory bank with same id again will fail
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_settings, sample_documents):
|
||||
memory_impl = memory_settings["memory_impl"]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import yaml
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
|
||||
|
||||
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
||||
|
|
@ -36,7 +36,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
|||
providers=chosen,
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls_with_routing(run_config)
|
||||
impls = await resolve_impls(run_config)
|
||||
|
||||
if "provider_data" in config_dict:
|
||||
provider_id = chosen[api.value][0].provider_id
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
|
|||
if params := request.sampling_params:
|
||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||
if getattr(params, attr):
|
||||
if attr == "max_tokens":
|
||||
options["num_predict"] = getattr(params, attr)
|
||||
options[attr] = getattr(params, attr)
|
||||
|
||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||
|
|
@ -49,27 +51,35 @@ def text_from_choice(choice) -> str:
|
|||
return choice.text
|
||||
|
||||
|
||||
def get_stop_reason(finish_reason: str) -> StopReason:
|
||||
if finish_reason in ["stop", "eos"]:
|
||||
return StopReason.end_of_turn
|
||||
elif finish_reason == "eom":
|
||||
return StopReason.end_of_message
|
||||
elif finish_reason == "length":
|
||||
return StopReason.out_of_tokens
|
||||
|
||||
return StopReason.out_of_tokens
|
||||
|
||||
|
||||
def process_completion_response(
|
||||
response: OpenAICompatCompletionResponse, formatter: ChatFormat
|
||||
) -> CompletionResponse:
|
||||
choice = response.choices[0]
|
||||
|
||||
return CompletionResponse(
|
||||
stop_reason=get_stop_reason(choice.finish_reason),
|
||||
content=choice.text,
|
||||
)
|
||||
|
||||
|
||||
def process_chat_completion_response(
|
||||
request: ChatCompletionRequest,
|
||||
response: OpenAICompatCompletionResponse,
|
||||
formatter: ChatFormat,
|
||||
response: OpenAICompatCompletionResponse, formatter: ChatFormat
|
||||
) -> ChatCompletionResponse:
|
||||
choice = response.choices[0]
|
||||
|
||||
stop_reason = None
|
||||
if reason := choice.finish_reason:
|
||||
if reason in ["stop", "eos"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif reason == "eom":
|
||||
stop_reason = StopReason.end_of_message
|
||||
elif reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
completion_message = formatter.decode_assistant_message_from_content(
|
||||
text_from_choice(choice), stop_reason
|
||||
text_from_choice(choice), get_stop_reason(choice.finish_reason)
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=completion_message,
|
||||
|
|
@ -77,10 +87,45 @@ def process_chat_completion_response(
|
|||
)
|
||||
|
||||
|
||||
async def process_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
|
||||
) -> AsyncGenerator:
|
||||
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
if finish_reason:
|
||||
if finish_reason in ["stop", "eos", "eos_token"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = text_from_choice(choice)
|
||||
if text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
continue
|
||||
elif text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
continue
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=text,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
|
||||
async def process_chat_completion_stream_response(
|
||||
request: ChatCompletionRequest,
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
|
||||
formatter: ChatFormat,
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
|
||||
) -> AsyncGenerator:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
|
|
|||
|
|
@ -23,6 +23,13 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
def completion_request_to_prompt(
|
||||
request: CompletionRequest, formatter: ChatFormat
|
||||
) -> str:
|
||||
model_input = formatter.encode_content(request.content)
|
||||
return formatter.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
||||
def chat_completion_request_to_prompt(
|
||||
request: ChatCompletionRequest, formatter: ChatFormat
|
||||
) -> str:
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ def severity(levelname: str) -> LogSeverity:
|
|||
elif levelname == "INFO":
|
||||
return LogSeverity.INFO
|
||||
elif levelname == "WARNING":
|
||||
return LogSeverity.WARNING
|
||||
return LogSeverity.WARN
|
||||
elif levelname == "ERROR":
|
||||
return LogSeverity.ERROR
|
||||
elif levelname == "CRITICAL":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue