Merge branch 'meta-llama:main' into qdrant

This commit is contained in:
Anush 2024-10-22 21:45:31 +05:30 committed by GitHub
commit 1575578446
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
101 changed files with 3310 additions and 722 deletions

View file

@ -421,10 +421,8 @@ class Agents(Protocol):
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -67,14 +67,14 @@ class AgentsClient(Agents):
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())
def create_agent_turn(
async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
return await self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
@ -126,7 +126,7 @@ async def _run_agent(
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn(
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,

View file

@ -180,5 +180,5 @@ class EventLogger:
color="cyan",
)
preivous_event_type = event_type
previous_event_type = event_type
previous_step_type = step_type

View file

@ -42,10 +42,10 @@ class InferenceClient(Inference):
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -139,7 +139,8 @@ async def run_main(
else:
logprobs_config = None
iterator = client.chat_completion(
assert stream, "Non streaming not supported here"
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,

View file

@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Completion response."""
completion_message: CompletionMessage
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None
@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""
completion_message_batch: List[CompletionMessage]
batch: List[CompletionResponse]
@json_schema_type
@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
batch: List[ChatCompletionResponse]
@json_schema_type
@ -181,10 +182,8 @@ class ModelStore(Protocol):
class Inference(Protocol):
model_store: ModelStore
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -196,7 +195,7 @@ class Inference(Protocol):
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],

View file

@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool):
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")
# register memory bank for the first time
response = await client.register_memory_bank(
VectorMemoryBankDef(
identifier="test_bank2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
)
cprint(f"register_memory_bank response={response}", "blue")
# list again after registering
response = await client.list_memory_banks()
cprint(f"list_memory_banks response={response}", "green")
def main(host: str, port: int, stream: bool = True):
asyncio.run(run_main(host, port, stream))

View file

@ -152,27 +152,29 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
parser.error("Please provide a model id")
return
prompt_guard = prompt_guard_model_sku()
if args.model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(args.model_id)
if model is None:
parser.error(f"Model {args.model_id} not found")
return
info = llama_meta_net_info(model)
# Check if model_id is a comma-separated list
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url
if not meta_url:
meta_url = input(
"Please provide the signed URL you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
prompt_guard = prompt_guard_model_sku()
for model_id in model_ids:
if model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(model_id)
if model is None:
parser.error(f"Model {model_id} not found")
continue
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url or input(
f"Please provide the signed URL for model {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
assert meta_url is not None and "llamameta.net" in meta_url
_meta_download(model, meta_url, info)
assert "llamameta.net" in meta_url
_meta_download(model, meta_url, info)
class ModelEntry(BaseModel):

View file

@ -13,7 +13,7 @@ from functools import lru_cache
from pathlib import Path
TEMPLATES_PATH = (
Path(os.path.relpath(__file__)).parent.parent.parent / "distribution" / "templates"
Path(os.path.relpath(__file__)).parent.parent.parent.parent / "distributions"
)

View file

@ -15,7 +15,7 @@ special_pip_deps="$6"
set -euo pipefail
build_name="$1"
image_name="llamastack-$build_name"
image_name="distribution-$build_name"
docker_base=$2
build_file_path=$3
host_build_dir=$4

View file

@ -55,7 +55,7 @@ class ProviderWithSpec(Provider):
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order

View file

@ -70,7 +70,7 @@ class InferenceRouter(Inference):
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -93,11 +93,11 @@ class InferenceRouter(Inference):
)
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
return (chunk async for chunk in await provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
return await provider.chat_completion(**params)
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -114,9 +114,9 @@ class InferenceRouter(Inference):
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
return (chunk async for chunk in await provider.completion(**params))
else:
return provider.completion(**params)
return await provider.completion(**params)
async def embeddings(
self,

View file

@ -87,8 +87,21 @@ class CommonRoutingTableImpl(RoutingTable):
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, MemoryBanksRoutingTable):
return ("Memory", "memory_bank")
else:
raise ValueError("Unknown routing table type")
if routing_key not in self.registry:
raise ValueError(f"`{routing_key}` not registered")
apiname, objname = apiname_object()
raise ValueError(
f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
)
objs = self.registry[routing_key]
for obj in objs:
@ -110,10 +123,16 @@ class CommonRoutingTableImpl(RoutingTable):
async def register_object(self, obj: RoutableObjectWithProvider):
entries = self.registry.get(obj.identifier, [])
for entry in entries:
if entry.provider_id == obj.provider_id:
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
if entry.provider_id == obj.provider_id or not obj.provider_id:
print(
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
)
return
# if provider_id is not specified, we'll pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")

View file

@ -37,7 +37,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
from llama_stack.distribution.resolver import resolve_impls
from .endpoints import get_all_api_endpoints
@ -203,7 +203,7 @@ async def maybe_await(value):
async def sse_generator(event_gen):
try:
async for item in event_gen:
async for item in await event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
@ -276,7 +276,7 @@ def main(
app = FastAPI()
impls = asyncio.run(resolve_impls_with_routing(config))
impls = asyncio.run(resolve_impls(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])

View file

@ -1,15 +0,0 @@
name: local-cpu
distribution_spec:
description: remote inference + local safety/agents/memory
docker_image: null
providers:
inference:
- remote::ollama
- remote::tgi
- remote::together
- remote::fireworks
safety: meta-reference
agents: meta-reference
memory: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -1,42 +0,0 @@
version: '2'
built_at: '2024-10-08T17:42:07.505267'
image_name: local-cpu
docker_image: local-cpu
conda_env: null
apis:
- agents
- inference
- models
- memory
- safety
- shields
- memory_banks
providers:
inference:
- provider_id: remote::ollama
provider_type: remote::ollama
config:
host: localhost
port: 6000
safety:
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
memory:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}

View file

@ -1,11 +0,0 @@
name: local-gpu
distribution_spec:
description: local meta reference
docker_image: null
providers:
inference: meta-reference
safety: meta-reference
agents: meta-reference
memory: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -1,45 +0,0 @@
version: '2'
built_at: '2024-10-08T17:42:33.690666'
image_name: local-gpu
docker_image: local-gpu
conda_env: null
apis:
- memory
- inference
- agents
- shields
- safety
- models
- memory_banks
providers:
inference:
- provider_id: meta-reference
provider_type: meta-reference
config:
model: Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
safety:
- provider_id: meta-reference
provider_type: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
memory:
- provider_id: meta-reference
provider_type: meta-reference
config: {}
agents:
- provider_id: meta-reference
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta-reference
provider_type: meta-reference
config: {}

View file

@ -1,10 +0,0 @@
name: local-bedrock-conda-example
distribution_spec:
description: Use Amazon Bedrock APIs.
providers:
inference: remote::bedrock
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local
distribution_spec:
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-databricks
distribution_spec:
description: Use Databricks for running LLM inference
providers:
inference: remote::databricks
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-fireworks
distribution_spec:
description: Use Fireworks.ai for running LLM inference
providers:
inference: remote::fireworks
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-hf-endpoint
distribution_spec:
description: "Like local, but use Hugging Face Inference Endpoints for running LLM inference.\nSee https://hf.co/docs/api-endpoints."
providers:
inference: remote::hf::endpoint
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-hf-serverless
distribution_spec:
description: "Like local, but use Hugging Face Inference API (serverless) for running LLM inference.\nSee https://hf.co/docs/api-inference."
providers:
inference: remote::hf::serverless
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-ollama
distribution_spec:
description: Like local, but use ollama for running LLM inference
providers:
inference: remote::ollama
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-tgi
distribution_spec:
description: Like local, but use a TGI server for running LLM inference.
providers:
inference: remote::tgi
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-together
distribution_spec:
description: Use Together.ai for running LLM inference
providers:
inference: remote::together
memory: meta-reference
safety: remote::together
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -1,10 +0,0 @@
name: local-vllm
distribution_spec:
description: Like local, but use vLLM for running LLM inference
providers:
inference: vllm
memory: meta-reference
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: conda

View file

@ -47,7 +47,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
self.client.close()
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -283,7 +283,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
return tool_config
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],

View file

@ -7,10 +7,11 @@
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl
return impl

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -19,4 +18,4 @@ class DatabricksImplConfig(BaseModel):
api_token: str = Field(
default=None,
description="The Databricks API token",
)
)

View file

@ -48,10 +48,17 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -77,14 +84,14 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
@ -98,7 +105,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -61,7 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -87,14 +87,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await client.completion.acreate(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
@ -103,7 +103,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
stream = client.completion.acreate(**params)
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -23,9 +23,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
OLLAMA_SUPPORTED_MODELS = {
@ -33,7 +36,8 @@ OLLAMA_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
"Llama-Guard-3-8B": "xe/llamaguard3:latest",
"Llama-Guard-3-8B": "llama-guard3:8b",
"Llama-Guard-3-1B": "llama-guard3:1b",
}
@ -84,7 +88,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return ret
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -92,9 +96,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def chat_completion(
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
sampling_options = get_sampling_options(request)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options["max_tokens"] is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": completion_request_to_prompt(request, self.formatter),
"options": sampling_options,
"raw": True,
"stream": request.stream,
}
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
r = await self.client.generate(**params)
assert isinstance(r, dict)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_completion_response(response, self.formatter)
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -118,7 +179,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
@ -143,7 +204,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
@ -163,7 +224,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -66,7 +66,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -76,7 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -101,7 +101,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
@ -116,7 +116,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
@ -135,7 +135,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -64,7 +64,7 @@ class TogetherInferenceAdapter(
) -> AsyncGenerator:
raise NotImplementedError()
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -101,14 +101,14 @@ class TogetherInferenceAdapter(
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Together
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Together
@ -123,7 +123,7 @@ class TogetherInferenceAdapter(
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import VLLMImplConfig
from .vllm import VLLMInferenceAdapter
async def get_adapter_impl(config: VLLMImplConfig, _deps):
assert isinstance(config, VLLMImplConfig), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class VLLMImplConfig(BaseModel):
url: Optional[str] = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
)
api_token: Optional[str] = Field(
default=None,
description="The API token",
)

View file

@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMImplConfig
VLLM_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
}
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMImplConfig) -> None:
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
async def initialize(self) -> None:
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Model registration is not supported for vLLM models")
async def shutdown(self) -> None:
pass
async def list_models(self) -> List[ModelDef]:
return [
ModelDef(identifier=model.id, llama_model=model.id)
for model in self.client.models.list()
]
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request, self.client)
else:
return self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(request, r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
params = self._get_params(request)
# TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
# generator so this wrapper is not necessary?
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": VLLM_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request),
}
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -424,7 +424,7 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = None
with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion(
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),

View file

@ -105,7 +105,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id=session_id,
)
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,

View file

@ -17,13 +17,22 @@ from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceInferenceConfig(BaseModel):
model: str = Field(
default="Llama3.1-8B-Instruct",
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
# (including our testing code) who might be using llama-stack as a library.
create_distributed_process_group: bool = True
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:

View file

@ -23,11 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@ -38,7 +33,11 @@ from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
@ -98,7 +97,10 @@ class Llama:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = model_checkpoint_dir(model)
if config.checkpoint_dir:
ckpt_dir = config.checkpoint_dir
else:
ckpt_dir = model_checkpoint_dir(model)
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@ -119,9 +121,7 @@ class Llama:
**params,
)
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
tokenizer = Tokenizer(model_path=tokenizer_path)
tokenizer = Tokenizer.get_instance()
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
@ -138,7 +138,7 @@ class Llama:
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config)
model = convert_to_quantized_model(model, config, ckpt_dir)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@ -170,14 +170,16 @@ class Llama:
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
print_input_tokens: bool = False,
) -> Generator:
params = self.model.params
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
if print_input_tokens:
input_tokens = [
self.formatter.vision_token if t == 128256 else t
for t in model_input.tokens
]
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
@ -228,8 +230,7 @@ class Llama:
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
@ -295,15 +296,12 @@ class Llama:
if all(eos_reached):
break
def text_completion(
def completion(
self,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
request: CompletionRequest,
) -> Generator:
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@ -311,26 +309,25 @@ class Llama:
):
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
echo=False,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: ChatCompletionRequest,
) -> Generator:
messages = chat_completion_request_to_messages(request)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@ -341,12 +338,12 @@ class Llama:
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
request.tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
logprobs=bool(request.logprobs),
include_stop_token=True,
)

View file

@ -13,11 +13,9 @@ from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
# there's a single model parallel process running serving the model. for now,
@ -36,8 +34,11 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
print(f"Loading model `{self.model.descriptor()}`")
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(self.config)
self.generator.start()
else:
self.generator = Llama.build(self.config)
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Dynamic model registration is not supported")
@ -51,9 +52,21 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
]
async def shutdown(self) -> None:
self.generator.stop()
if self.config.create_distributed_process_group:
self.generator.stop()
def completion(
def check_model(self, request) -> None:
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -61,9 +74,114 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
def chat_completion(
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
def impl():
stop_reason = None
for token_result in self.generator.completion(request):
if token_result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.token in tokenizer.stop_tokens:
# not quite right semantically
stop_reason = StopReason.end_of_turn
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
content = self.generator.formatter.tokenizer.decode(tokens)
return CompletionResponse(
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def chat_completion(
self,
model: str,
messages: List[Message],
@ -88,43 +206,26 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
model = resolve_model(request.model)
if model is None:
raise RuntimeError(
f"Unknown model: {request.model}, Run `llama model list`"
)
elif model.descriptor() != self.model.descriptor():
raise RuntimeError(
f"Model mismatch: {request.model} != {self.model.descriptor()}"
)
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async with SEMAPHORE:
messages = chat_completion_request_to_messages(request)
def impl():
tokens = []
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
@ -154,12 +255,16 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async with SEMAPHORE:
messages = chat_completion_request_to_messages(request)
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@ -172,14 +277,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(
messages=messages,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
max_gen_len=request.sampling_params.max_tokens,
logprobs=request.logprobs,
tool_prompt_format=request.tool_prompt_format,
):
for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
@ -272,6 +370,14 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def embeddings(
self,
model: str,

View file

@ -7,16 +7,17 @@
import os
from copy import deepcopy
from functools import partial
from typing import Generator, List, Optional
from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
from .parallel_utils import ModelParallelProcessGroup
class ModelRunner:
@ -24,15 +25,13 @@ class ModelRunner:
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, task: InferenceArgs):
return self.llama.chat_completion(
task.messages,
task.temperature,
task.top_p,
task.max_gen_len,
task.logprobs,
task.tool_prompt_format,
)
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")
def init_model_cb(config: MetaReferenceInferenceConfig):
@ -77,23 +76,18 @@ class LlamaModelParallelGenerator:
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
def chat_completion(
def completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
request: CompletionRequest,
) -> Generator:
req_obj = InferenceArgs(
messages=deepcopy(messages),
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs or False,
tool_prompt_format=tool_prompt_format,
)
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen
def chat_completion(
self,
request: ChatCompletionRequest,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen

View file

@ -4,6 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import multiprocessing
import os
@ -11,10 +17,9 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, List, Literal, Optional, Union
from typing import Callable, Generator, Literal, Optional, Union
import torch
import zmq
from fairscale.nn.model_parallel.initialize import (
@ -23,25 +28,16 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult
class InferenceArgs(BaseModel):
messages: List[Message]
temperature: float
top_p: float
max_gen_len: int
logprobs: bool
tool_prompt_format: ToolPromptFormat
class ProcessingMessageName(str, Enum):
ready_request = "ready_request"
ready_response = "ready_response"
@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: InferenceArgs
task: Union[CompletionRequest, ChatCompletionRequest]
class TaskResponse(BaseModel):
@ -349,11 +345,13 @@ class ModelParallelProcessGroup:
self.process.join()
self.started = False
def run_inference(self, inference_args: InferenceArgs) -> Generator:
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
) -> Generator:
assert not self.running, "inference already running"
self.running = True
self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
self.request_socket.send(encode_msg(TaskRequest(task=req)))
try:
while True:
obj_json = self.request_socket.recv()

View file

@ -13,9 +13,10 @@ from typing import Optional
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import Tensor
@ -39,6 +40,7 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
@ -49,12 +51,14 @@ def convert_to_quantized_model(
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
checkpoint = config.checkpoint_config.checkpoint
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
cprint("Loading fp8 scales...", "yellow")
fp8_scales_path = os.path.join(
checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
)
assert os.path.isfile(
fp8_scales_path

View file

@ -170,7 +170,7 @@ class LlamaGuardShield(ShieldBase):
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages
@ -184,7 +184,7 @@ class LlamaGuardShield(ShieldBase):
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
async for chunk in await self.inference_api.chat_completion(
model=self.model,
messages=[shield_input_message],
stream=True,

View file

@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import VLLMConfig

View file

@ -134,7 +134,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if self.engine:
self.engine.shutdown_background_loop()
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@ -152,7 +152,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
logprobs=logprobs,
)
def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[Message],
@ -189,7 +189,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if stream:
return self._stream_chat_completion(request, results_generator)
else:
return self._nonstream_chat_completion(request, results_generator)
return await self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@ -207,7 +207,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(request, response, self.formatter)
return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
@ -229,7 +229,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(
request, stream, self.formatter
stream, self.formatter
):
yield chunk

View file

@ -55,11 +55,20 @@ def available_providers() -> List[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama"],
pip_packages=["ollama", "aiohttp"],
config_class="llama_stack.providers.adapters.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.adapters.inference.ollama",
),
),
# remote_provider_spec(
# api=Api.inference,
# adapter=AdapterSpec(
# adapter_type="vllm",
# pip_packages=["openai"],
# module="llama_stack.providers.adapters.inference.vllm",
# config_class="llama_stack.providers.adapters.inference.vllm.VLLMImplConfig",
# ),
# ),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -31,4 +31,4 @@ providers:
persistence_store:
namespace: null
type: sqlite
db_path: /Users/ashwin/.llama/runtime/kvstore.db
db_path: ~/.llama/runtime/kvstore.db

View file

@ -64,6 +64,24 @@ def search_query_messages():
]
@pytest.fixture
def attachment_message():
return [
UserMessage(
content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
),
]
@pytest.fixture
def query_attachment_messages():
return [
UserMessage(
content="What are the top 5 topics that were explained? Only list succinct bullet points."
),
]
@pytest.mark.asyncio
async def test_create_agent_turn(agents_settings, sample_messages):
agents_impl = agents_settings["impl"]
@ -98,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages):
assert len(final_event.turn.output_message.content) > 0
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
agents_settings, attachment_message, query_attachment_messages
):
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
attachments = [
Attachment(
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
agents_impl = agents_settings["impl"]
agent_config = AgentConfig(
model=agents_settings["common_params"]["model"],
instructions=agents_settings["common_params"]["instructions"],
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[
MemoryToolDefinition(
memory_bank_configs=[],
query_generator_config={
"type": "default",
"sep": " ",
},
max_tokens_in_context=4096,
max_chunks=10,
),
],
max_infer_iters=5,
)
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
# Create and execute a turn
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=attachment_message,
attachments=attachments,
stream=True,
)
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
# Create a second turn querying the agent
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=query_attachment_messages,
stream=True,
)
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@pytest.mark.asyncio
async def test_create_agent_turn_with_brave_search(
agents_settings, search_query_messages
@ -169,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
)
turn_response = [
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0

View file

@ -4,6 +4,10 @@ providers:
config:
host: localhost
port: 11434
- provider_id: meta-reference
provider_type: meta-reference
config:
model: Llama3.2-1B-Instruct
- provider_id: test-tgi
provider_type: remote::tgi
config:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import itertools
import os
import pytest
import pytest_asyncio
@ -50,14 +51,17 @@ def get_expected_stop_reason(model: str):
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
if "MODEL_IDS" not in os.environ:
MODEL_IDS = [Llama_8B, Llama_3B]
else:
MODEL_IDS = os.environ["MODEL_IDS"].split(",")
# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
scope="session",
params=[
{"model": Llama_8B},
{"model": Llama_3B},
],
params=[{"model": m} for m in MODEL_IDS],
ids=lambda d: d["model"],
)
async def inference_settings(request):
@ -122,6 +126,48 @@ async def test_model_list(inference_settings):
assert model_def.identifier == params["model"]
@pytest.mark.asyncio
async def test_completion(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::ollama",
):
pytest.skip("Other inference providers don't support completion() yet")
response = await inference_impl.completion(
content="Roses are red,",
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
)
assert isinstance(response, CompletionResponse)
assert "violets are blue" in response.content
chunks = [
r
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
)
]
assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
assert len(chunks) == 51
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
@ -142,7 +188,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = [
r
async for r in inference_impl.chat_completion(
async for r in await inference_impl.chat_completion(
messages=sample_messages,
stream=True,
**inference_settings["common_params"],
@ -213,7 +259,7 @@ async def test_chat_completion_with_tool_calling_streaming(
response = [
r
async for r in inference_impl.chat_completion(
async for r in await inference_impl.chat_completion(
messages=messages,
tools=[sample_tool_definition],
stream=True,

View file

@ -2,8 +2,8 @@ providers:
- provider_id: test-faiss
provider_type: meta-reference
config: {}
- provider_id: test-chroma
provider_type: remote::chroma
- provider_id: test-chromadb
provider_type: remote::chromadb
config:
host: localhost
port: 6001

View file

@ -89,6 +89,30 @@ async def test_banks_list(memory_settings):
assert len(response) == 0
@pytest.mark.asyncio
async def test_banks_register(memory_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
banks_impl = memory_settings["memory_banks_impl"]
bank = VectorMemoryBankDef(
identifier="test_bank_no_provider",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
await banks_impl.register_memory_bank(bank)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1
# register same memory bank with same id again will fail
await banks_impl.register_memory_bank(bank)
response = await banks_impl.list_memory_banks()
assert isinstance(response, list)
assert len(response) == 1
@pytest.mark.asyncio
async def test_query_documents(memory_settings, sample_documents):
memory_impl = memory_settings["memory_impl"]

View file

@ -14,7 +14,7 @@ import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
from llama_stack.distribution.resolver import resolve_impls
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
@ -36,7 +36,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls_with_routing(run_config)
impls = await resolve_impls(run_config)
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id

View file

@ -34,6 +34,8 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
if params := request.sampling_params:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(params, attr):
if attr == "max_tokens":
options["num_predict"] = getattr(params, attr)
options[attr] = getattr(params, attr)
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
@ -49,27 +51,35 @@ def text_from_choice(choice) -> str:
return choice.text
def get_stop_reason(finish_reason: str) -> StopReason:
if finish_reason in ["stop", "eos"]:
return StopReason.end_of_turn
elif finish_reason == "eom":
return StopReason.end_of_message
elif finish_reason == "length":
return StopReason.out_of_tokens
return StopReason.out_of_tokens
def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
choice = response.choices[0]
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
)
def process_chat_completion_response(
request: ChatCompletionRequest,
response: OpenAICompatCompletionResponse,
formatter: ChatFormat,
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> ChatCompletionResponse:
choice = response.choices[0]
stop_reason = None
if reason := choice.finish_reason:
if reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif reason == "eom":
stop_reason = StopReason.end_of_message
elif reason == "length":
stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
completion_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), stop_reason
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
return ChatCompletionResponse(
completion_message=completion_message,
@ -77,10 +87,45 @@ def process_chat_completion_response(
)
async def process_completion_stream_response(
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
stop_reason = None
async for chunk in stream:
choice = chunk.choices[0]
finish_reason = choice.finish_reason
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = text_from_choice(choice)
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
)
yield CompletionResponseStreamChunk(
delta="",
stop_reason=stop_reason,
)
async def process_chat_completion_stream_response(
request: ChatCompletionRequest,
stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
formatter: ChatFormat,
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(

View file

@ -23,6 +23,13 @@ from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
def completion_request_to_prompt(
request: CompletionRequest, formatter: ChatFormat
) -> str:
model_input = formatter.encode_content(request.content)
return formatter.tokenizer.decode(model_input.tokens)
def chat_completion_request_to_prompt(
request: ChatCompletionRequest, formatter: ChatFormat
) -> str:

View file

@ -152,7 +152,7 @@ def severity(levelname: str) -> LogSeverity:
elif levelname == "INFO":
return LogSeverity.INFO
elif levelname == "WARNING":
return LogSeverity.WARNING
return LogSeverity.WARN
elif levelname == "ERROR":
return LogSeverity.ERROR
elif levelname == "CRITICAL":