mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-29 01:51:59 +00:00
Merge branch 'main' into register_custom_model
This commit is contained in:
commit
afb792b9c1
69 changed files with 8875 additions and 890 deletions
|
|
@ -18,7 +18,7 @@ from typing import (
|
|||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.models import Model
|
||||
|
|
@ -442,6 +442,37 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: List[List[float]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIImageURL(BaseModel):
|
||||
url: str
|
||||
detail: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
||||
type: Literal["image_url"] = "image_url"
|
||||
image_url: OpenAIImageURL
|
||||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
Union[
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
|
||||
|
||||
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIUserMessageParam(BaseModel):
|
||||
"""A message from the user in an OpenAI-compatible chat completion request.
|
||||
|
|
@ -452,7 +483,7 @@ class OpenAIUserMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["user"] = "user"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -466,10 +497,24 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["system"] = "system"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCall(BaseModel):
|
||||
index: Optional[int] = None
|
||||
id: Optional[str] = None
|
||||
type: Literal["function"] = "function"
|
||||
function: Optional[OpenAIChatCompletionToolCallFunction] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIAssistantMessageParam(BaseModel):
|
||||
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
|
||||
|
|
@ -477,13 +522,13 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
:param role: Must be "assistant" to identify this as the model's response
|
||||
:param content: The content of the model's response
|
||||
:param name: (Optional) The name of the assistant message participant.
|
||||
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
|
||||
:param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
|
||||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -497,7 +542,7 @@ class OpenAIToolMessageParam(BaseModel):
|
|||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -510,7 +555,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -527,6 +572,46 @@ OpenAIMessageParam = Annotated[
|
|||
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatText(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIJSONSchema(TypedDict, total=False):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
strict: Optional[bool] = None
|
||||
|
||||
# Pydantic BaseModel cannot be used with a schema param, since it already
|
||||
# has one. And, we don't want to alias here because then have to handle
|
||||
# that alias when converting to OpenAI params. So, to support schema,
|
||||
# we use a TypedDict.
|
||||
schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONSchema(BaseModel):
|
||||
type: Literal["json_schema"] = "json_schema"
|
||||
json_schema: OpenAIJSONSchema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONObject(BaseModel):
|
||||
type: Literal["json_object"] = "json_object"
|
||||
|
||||
|
||||
OpenAIResponseFormatParam = Annotated[
|
||||
Union[
|
||||
OpenAIResponseFormatText,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAITopLogProb(BaseModel):
|
||||
"""The top log probability for a token from an OpenAI-compatible chat completion response.
|
||||
|
|
@ -561,22 +646,54 @@ class OpenAITokenLogProb(BaseModel):
|
|||
class OpenAIChoiceLogprobs(BaseModel):
|
||||
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
|
||||
|
||||
:content: (Optional) The log probabilities for the tokens in the message
|
||||
:refusal: (Optional) The log probabilities for the tokens in the message
|
||||
:param content: (Optional) The log probabilities for the tokens in the message
|
||||
:param refusal: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
content: Optional[List[OpenAITokenLogProb]] = None
|
||||
refusal: Optional[List[OpenAITokenLogProb]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChoiceDelta(BaseModel):
|
||||
"""A delta from an OpenAI-compatible chat completion streaming response.
|
||||
|
||||
:param content: (Optional) The content of the delta
|
||||
:param refusal: (Optional) The refusal of the delta
|
||||
:param role: (Optional) The role of the delta
|
||||
:param tool_calls: (Optional) The tool calls of the delta
|
||||
"""
|
||||
|
||||
content: Optional[str] = None
|
||||
refusal: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChunkChoice(BaseModel):
|
||||
"""A chunk choice from an OpenAI-compatible chat completion streaming response.
|
||||
|
||||
:param delta: The delta from the chunk
|
||||
:param finish_reason: The reason the model stopped generating
|
||||
:param index: The index of the choice
|
||||
:param logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
delta: OpenAIChoiceDelta
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChoice(BaseModel):
|
||||
"""A choice from an OpenAI-compatible chat completion response.
|
||||
|
||||
:param message: The message from the model
|
||||
:param finish_reason: The reason the model stopped generating
|
||||
:index: The index of the choice
|
||||
:logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
:param index: The index of the choice
|
||||
:param logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
message: OpenAIMessageParam
|
||||
|
|
@ -603,6 +720,24 @@ class OpenAIChatCompletion(BaseModel):
|
|||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionChunk(BaseModel):
|
||||
"""Chunk from a streaming response to an OpenAI-compatible chat completion request.
|
||||
|
||||
:param id: The ID of the chat completion
|
||||
:param choices: List of choices
|
||||
:param object: The object type, which will be "chat.completion.chunk"
|
||||
:param created: The Unix timestamp in seconds when the chat completion was created
|
||||
:param model: The model that was used to generate the chat completion
|
||||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAIChunkChoice]
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAICompletionLogprobs(BaseModel):
|
||||
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
|
||||
|
|
@ -872,7 +1007,7 @@ class Inference(Protocol):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -883,7 +1018,7 @@ class Inference(Protocol):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -20,8 +21,7 @@ class RouteInfo(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class HealthInfo(BaseModel):
|
||||
status: str
|
||||
# TODO: add a provider level status
|
||||
status: HealthStatus
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import HealthResponse
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
|
@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
|
|||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
health: HealthResponse
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -89,6 +89,43 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
elif args.providers:
|
||||
providers = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
f"{api} is not a valid API.",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider in providers_for_api:
|
||||
providers.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
distribution_spec = DistributionSpec(
|
||||
providers=providers,
|
||||
description=",".join(args.providers),
|
||||
)
|
||||
if not args.image_type:
|
||||
cprint(
|
||||
f"Please specify a image-type (container | conda | venv) for {args.template}",
|
||||
color="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
|
||||
elif not args.config and not args.template:
|
||||
name = prompt(
|
||||
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
|
||||
|
|
|
|||
|
|
@ -75,6 +75,12 @@ the build. If not specified, currently active environment will be used if found.
|
|||
default=False,
|
||||
help="Run the stack after building using the same image type, name, and other applicable arguments",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
|
||||
)
|
||||
|
||||
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
|
||||
# always keep implementation completely silo-ed away from CLI so CLI
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
|
|||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
|
|
@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
|
|||
return ListRoutesResponse(data=ret)
|
||||
|
||||
async def health(self) -> HealthInfo:
|
||||
return HealthInfo(status="OK")
|
||||
return HealthInfo(status=HealthStatus.OK)
|
||||
|
||||
async def version(self) -> VersionInfo:
|
||||
return VersionInfo(version=version("llama-stack"))
|
||||
|
|
|
|||
|
|
@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
|
|||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.distribution.utils.exec import in_notebook
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
|
|
|
|||
|
|
@ -4,14 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .stack import redact_sensitive_fields
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
@ -41,19 +44,24 @@ class ProviderImpl(Providers):
|
|||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.run_config
|
||||
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
providers_health = await self.get_providers_health()
|
||||
ret = []
|
||||
for api, providers in safe_config.providers.items():
|
||||
ret.extend(
|
||||
[
|
||||
for p in providers:
|
||||
ret.append(
|
||||
ProviderInfo(
|
||||
api=api,
|
||||
provider_id=p.provider_id,
|
||||
provider_type=p.provider_type,
|
||||
config=p.config,
|
||||
health=providers_health.get(api, {}).get(
|
||||
p.provider_id,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
),
|
||||
)
|
||||
for p in providers
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return ListProvidersResponse(data=ret)
|
||||
|
||||
|
|
@ -64,3 +72,57 @@ class ProviderImpl(Providers):
|
|||
return p
|
||||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
|
||||
"""Get health status for all providers.
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
|
||||
Each API maps to a dictionary of provider IDs to their health responses.
|
||||
"""
|
||||
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
|
||||
timeout = 1.0
|
||||
|
||||
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
|
||||
# Skip special implementations (inspect/providers) that don't have provider specs
|
||||
if not hasattr(impl, "__provider_spec__"):
|
||||
return None
|
||||
api_name = impl.__provider_spec__.api.name
|
||||
if not hasattr(impl, "health"):
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
return api_name, health
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
|
||||
)
|
||||
|
||||
# Create tasks for all providers
|
||||
tasks = [check_provider_health(impl) for impl in self.deps.values()]
|
||||
|
||||
# Wait for all health checks to complete
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Organize results by API and provider ID
|
||||
for result in results:
|
||||
if result is None: # Skip special implementations
|
||||
continue
|
||||
api_name, health_response = result
|
||||
providers_health[api_name] = health_response
|
||||
|
||||
return providers_health
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
|
|||
Api,
|
||||
BenchmarksProtocolPrivate,
|
||||
DatasetsProtocolPrivate,
|
||||
InlineProviderSpec,
|
||||
ModelsProtocolPrivate,
|
||||
ProviderSpec,
|
||||
RemoteProviderConfig,
|
||||
|
|
@ -230,46 +229,6 @@ def sort_providers_by_deps(
|
|||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||
)
|
||||
|
||||
# Append built-in "inspect" provider
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
"inspect",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.inspect,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
sorted_providers.append(
|
||||
(
|
||||
"providers",
|
||||
ProviderWithSpec(
|
||||
provider_id="__builtin__",
|
||||
provider_type="__builtin__",
|
||||
config={"run_config": run_config.model_dump()},
|
||||
spec=InlineProviderSpec(
|
||||
api=Api.providers,
|
||||
provider_type="__builtin__",
|
||||
config_class="llama_stack.distribution.providers.ProviderImplConfig",
|
||||
module="llama_stack.distribution.providers",
|
||||
api_dependencies=apis,
|
||||
deps__=[x.value for x in apis],
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Resolved {len(sorted_providers)} providers")
|
||||
for api_str, provider in sorted_providers:
|
||||
logger.debug(f" {api_str} => {provider.provider_id}")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
|
|
@ -37,7 +38,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
|
|
@ -60,7 +67,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
|
@ -530,7 +537,7 @@ class InferenceRouter(Inference):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -541,7 +548,7 @@ class InferenceRouter(Inference):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
|
|
@ -580,6 +587,29 @@ class InferenceRouter(Inference):
|
|||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_chat_completion(**params)
|
||||
|
||||
async def health(self) -> Dict[str, HealthResponse]:
|
||||
health_statuses = {}
|
||||
timeout = 0.5
|
||||
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||
try:
|
||||
# check if the provider has a health method
|
||||
if not hasattr(impl, "health"):
|
||||
continue
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
health_statuses[provider_id] = health
|
||||
except asyncio.TimeoutError:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"Health check timed out after {timeout} seconds",
|
||||
)
|
||||
except NotImplementedError:
|
||||
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||
except Exception as e:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
|
|||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
|
@ -229,15 +229,30 @@ class TracingMiddleware:
|
|||
def __init__(self, app, impls):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "endpoint_impls"):
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
|
||||
try:
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
|
||||
|
|
@ -388,7 +403,12 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
)
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
|
@ -119,26 +121,6 @@ class EnvVarError(Exception):
|
|||
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
|
||||
elif any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
|
|
@ -215,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
Args:
|
||||
impls: Dictionary of API implementations
|
||||
run_config: Stack run configuration
|
||||
"""
|
||||
inspect_impl = DistributionInspectImpl(
|
||||
DistributionInspectConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.inspect] = inspect_impl
|
||||
|
||||
providers_impl = ProviderImpl(
|
||||
ProviderImplConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.providers] = providers_impl
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
|
|
@ -222,6 +224,10 @@ async def construct_stack(
|
|||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,17 @@ def tool_chat_page():
|
|||
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
|
||||
st.json(active_tool_list)
|
||||
|
||||
st.subheader("Chat Configurations")
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
help="The maximum number of tokens to generate",
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
|
|
@ -63,9 +74,7 @@ def tool_chat_page():
|
|||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={
|
||||
"strategy": {"type": "greedy"},
|
||||
},
|
||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||
)
|
||||
|
||||
agent = create_agent()
|
||||
|
|
|
|||
30
llama_stack/distribution/utils/config.py
Normal file
30
llama_stack/distribution/utils/config.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_value(v: Any) -> Any:
|
||||
if isinstance(v, dict):
|
||||
return _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
return [_redact_value(i) for i in v]
|
||||
return v
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = _redact_value(v)
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
||||
|
|
@ -204,7 +204,9 @@ class ToolUtils:
|
|||
return None
|
||||
elif is_json(message_body):
|
||||
response = json.loads(message_body)
|
||||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
if ("type" in response and response["type"] == "function") or (
|
||||
"name" in response and "parameters" in response
|
||||
):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
|
@ -201,3 +202,12 @@ def remote_provider_spec(
|
|||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
OK = "OK"
|
||||
ERROR = "Error"
|
||||
NOT_IMPLEMENTED = "Not Implemented"
|
||||
|
||||
|
||||
HealthResponse = dict[str, Any]
|
||||
|
|
|
|||
|
|
@ -59,8 +59,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
|
|
@ -83,8 +83,8 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
|
|||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
|
|||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
|
@ -35,8 +35,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
|
|||
|
|
@ -66,10 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
|
|
@ -176,8 +176,8 @@ def _convert_sampling_params(
|
|||
|
||||
class VLLMInferenceImpl(
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
|
|
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
RESOURCES_STATS = "resources_stats"
|
||||
|
||||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
|
||||
|
||||
class TorchtunePostTrainingImpl:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
|
|||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
self._scheduler = Scheduler()
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
metadata=resources_stats,
|
||||
)
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
|
|
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
|
|||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
) -> PostTrainingJob:
|
||||
if job_uuid in self.jobs:
|
||||
raise ValueError(f"Job {job_uuid} already exists")
|
||||
|
||||
post_training_job = PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
job_status_response = PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
try:
|
||||
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
on_log_message_cb("Starting Lora finetuning")
|
||||
|
||||
recipe = LoraFinetuningSingleDevice(
|
||||
self.config,
|
||||
job_uuid,
|
||||
|
|
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
|
|||
self.datasetio_api,
|
||||
self.datasets_api,
|
||||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
job_status_response.started_at = datetime.now(timezone.utc)
|
||||
|
||||
await recipe.setup()
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
||||
self.checkpoints_dict[job_uuid] = checkpoints
|
||||
job_status_response.resources_allocated = resources_allocated
|
||||
job_status_response.checkpoints = checkpoints
|
||||
job_status_response.status = JobStatus.completed
|
||||
job_status_response.completed_at = datetime.now(timezone.utc)
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
|
||||
except Exception:
|
||||
job_status_response.status = JobStatus.failed
|
||||
raise
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("Lora finetuning completed")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return post_training_job
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
|
|
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
|
|||
) -> PostTrainingJob: ...
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||
return self.jobs.get(job_uuid, None)
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
|
||||
if job_uuid in self.checkpoints_dict:
|
||||
checkpoints = self.checkpoints_dict.get(job_uuid, [])
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
|
||||
return None
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
|
|
|||
|
|
@ -36,10 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_strategy_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -56,8 +56,8 @@ from .models import MODEL_ENTRIES
|
|||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -54,8 +54,8 @@ from .models import MODEL_ENTRIES
|
|||
class CerebrasInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -61,8 +61,8 @@ model_entries = [
|
|||
class DatabricksInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=model_entries)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from openai import AsyncOpenAI
|
||||
|
|
@ -32,13 +32,20 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
|
|
@ -301,6 +308,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
|
|
@ -320,6 +332,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
|
@ -336,7 +349,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -347,10 +360,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
|
@ -374,4 +386,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
|
||||
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,24 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChoiceDelta,
|
||||
OpenAIChunkChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
|
@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
self._openai_client = None
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self.config.url}/openai/v1",
|
||||
api_key=self.config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Groq does not support json_schema response format, so we need to convert it to json_object
|
||||
if response_format and response_format.type == "json_schema":
|
||||
response_format.type = "json_object"
|
||||
schema = response_format.json_schema.get("schema", {})
|
||||
response_format.json_schema = None
|
||||
json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
|
||||
if messages and messages[0].role == "system":
|
||||
messages[0].content = messages[0].content + json_instructions
|
||||
else:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
|
||||
|
||||
# Groq returns a 400 error if tools are provided but none are called
|
||||
# So, set tool_choice to "required" to attempt to force a call
|
||||
if tools and (not tool_choice or tool_choice == "auto"):
|
||||
tool_choice = "required"
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Groq does not support streaming requests that set response_format
|
||||
fake_stream = False
|
||||
if stream and response_format:
|
||||
params["stream"] = False
|
||||
fake_stream = True
|
||||
|
||||
response = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
if fake_stream:
|
||||
chunk_choices = []
|
||||
for choice in response.choices:
|
||||
delta = OpenAIChoiceDelta(
|
||||
content=choice.message.content,
|
||||
role=choice.message.role,
|
||||
tool_calls=choice.message.tool_calls,
|
||||
)
|
||||
chunk_choice = OpenAIChunkChoice(
|
||||
delta=delta,
|
||||
finish_reason=choice.finish_reason,
|
||||
index=choice.index,
|
||||
logprobs=None,
|
||||
)
|
||||
chunk_choices.append(chunk_choice)
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=response.id,
|
||||
choices=chunk_choices,
|
||||
object="chat.completion.chunk",
|
||||
created=response.created,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
async def _fake_stream_generator():
|
||||
yield chunk
|
||||
|
||||
return _fake_stream_generator()
|
||||
else:
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -39,8 +39,16 @@ MODEL_ENTRIES = [
|
|||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -34,15 +34,18 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
@ -335,7 +338,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -346,7 +349,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
provider_model_id = self.get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
|
|
@ -39,10 +39,20 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
@ -87,8 +97,19 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
await self.health()
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
This method is used by initialize() and the Provider API to verify that the service is running
|
||||
correctly.
|
||||
Returns:
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except httpx.ConnectError as e:
|
||||
raise RuntimeError(
|
||||
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
|
||||
|
|
@ -322,6 +343,12 @@ class OllamaInferenceAdapter(
|
|||
response = await self.client.list()
|
||||
available_models = [m["model"] for m in response["models"]]
|
||||
if model.provider_resource_id not in available_models:
|
||||
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
|
||||
if model.provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
|
||||
)
|
||||
|
|
@ -393,7 +420,7 @@ class OllamaInferenceAdapter(
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -404,7 +431,7 @@ class OllamaInferenceAdapter(
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self._get_model(model)
|
||||
params = {
|
||||
k: v
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
|
|
@ -26,7 +26,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
@ -266,7 +272,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -277,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -43,8 +43,8 @@ RUNPOD_SUPPORTED_MODELS = {
|
|||
class RunpodInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: RunpodImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
|
|
@ -57,8 +57,8 @@ from .models import MODEL_ENTRIES
|
|||
class SambaNovaInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: SambaNovaImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -73,8 +73,8 @@ def build_hf_repo_model_entries():
|
|||
|
||||
class _HfAdapter(
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
client: AsyncInferenceClient
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from together import AsyncTogether
|
||||
|
|
@ -31,7 +31,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
@ -315,7 +321,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -326,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
@ -353,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", True):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# together.ai sometimes adds usage data to the stream, even if include_usage is False
|
||||
# This causes an unexpected final chunk with empty choices array to be sent
|
||||
# to clients that may not handle it gracefully.
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
|
@ -45,7 +45,12 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
|
@ -369,7 +374,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
options["max_tokens"] = self.config.max_tokens
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
|
||||
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
|
||||
if isinstance(request, ChatCompletionRequest) and request.tools:
|
||||
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
|
||||
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
|
|
@ -487,7 +493,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -498,7 +504,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -270,7 +276,7 @@ class LiteLLMOpenAIMixin(
|
|||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self._get_model(model)
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
|
|
@ -292,7 +298,7 @@ class LiteLLMOpenAIMixin(
|
|||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
return litellm.text_completion(**params)
|
||||
return await litellm.atext_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -308,7 +314,7 @@ class LiteLLMOpenAIMixin(
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -319,8 +325,8 @@ class LiteLLMOpenAIMixin(
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
model_obj = await self._get_model(model)
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
|
|
@ -346,7 +352,7 @@ class LiteLLMOpenAIMixin(
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return litellm.completion(**params)
|
||||
return await litellm.acompletion(**params)
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import (
|
||||
|
|
@ -50,6 +50,18 @@ from openai.types.chat.chat_completion import (
|
|||
from openai.types.chat.chat_completion import (
|
||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
Choice as OpenAIChatCompletionChunkChoice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDelta as OpenAIChoiceDelta,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||
ImageURL as OpenAIImageURL,
|
||||
)
|
||||
|
|
@ -59,6 +71,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
|
|
@ -85,12 +98,24 @@ from llama_stack.apis.inference import (
|
|||
TopPSamplingStrategy,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
|
||||
from llama_stack.apis.inference.inference import (
|
||||
JsonSchemaResponseFormat,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChoice as OpenAIChatCompletionChoice,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
|
|
@ -751,6 +776,17 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
return out
|
||||
|
||||
|
||||
def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
|
||||
"""
|
||||
Convert a StopReason to an OpenAI chat completion finish_reason.
|
||||
"""
|
||||
return {
|
||||
StopReason.end_of_turn: "stop",
|
||||
StopReason.end_of_message: "tool_calls",
|
||||
StopReason.out_of_tokens: "length",
|
||||
}.get(stop_reason, "stop")
|
||||
|
||||
|
||||
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
||||
"""
|
||||
Convert an OpenAI chat completion finish_reason to a StopReason.
|
||||
|
|
@ -776,6 +812,56 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
|||
}.get(finish_reason, StopReason.end_of_turn)
|
||||
|
||||
|
||||
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
|
||||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
tool_config.tool_choice = tool_choice
|
||||
return tool_config
|
||||
|
||||
|
||||
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
|
||||
lls_tools = []
|
||||
if not tools:
|
||||
return lls_tools
|
||||
|
||||
for tool in tools:
|
||||
tool_fn = tool.get("function", {})
|
||||
tool_name = tool_fn.get("name", None)
|
||||
tool_desc = tool_fn.get("description", None)
|
||||
|
||||
tool_params = tool_fn.get("parameters", None)
|
||||
lls_tool_params = {}
|
||||
if tool_params is not None:
|
||||
tool_param_properties = tool_params.get("properties", {})
|
||||
for tool_param_key, tool_param_value in tool_param_properties.items():
|
||||
tool_param_def = ToolParamDefinition(
|
||||
param_type=tool_param_value.get("type", None),
|
||||
description=tool_param_value.get("description", None),
|
||||
)
|
||||
lls_tool_params[tool_param_key] = tool_param_def
|
||||
|
||||
lls_tool = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool_desc,
|
||||
parameters=lls_tool_params,
|
||||
)
|
||||
lls_tools.append(lls_tool)
|
||||
return lls_tools
|
||||
|
||||
|
||||
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
|
||||
if not response_format:
|
||||
return None
|
||||
# response_format can be a dict or a pydantic model
|
||||
response_format = dict(response_format)
|
||||
if response_format.get("type", "") == "json_schema":
|
||||
return JsonSchemaResponseFormat(
|
||||
type="json_schema",
|
||||
json_schema=response_format.get("json_schema", {}).get("schema", ""),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _convert_openai_tool_calls(
|
||||
tool_calls: List[OpenAIChatCompletionMessageToolCall],
|
||||
) -> List[ToolCall]:
|
||||
|
|
@ -871,6 +957,40 @@ def _convert_openai_sampling_params(
|
|||
return sampling_params
|
||||
|
||||
|
||||
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
|
||||
# Llama Stack messages and OpenAI messages are similar, but not identical.
|
||||
lls_messages = []
|
||||
for message in messages:
|
||||
lls_message = dict(message)
|
||||
|
||||
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
|
||||
tool_call_id = lls_message.pop("tool_call_id", None)
|
||||
if tool_call_id:
|
||||
lls_message["call_id"] = tool_call_id
|
||||
|
||||
content = lls_message.get("content", None)
|
||||
if isinstance(content, list):
|
||||
lls_content = []
|
||||
for item in content:
|
||||
# items can either by pydantic models or dicts here...
|
||||
item = dict(item)
|
||||
if item.get("type", "") == "image_url":
|
||||
lls_item = ImageContentItem(
|
||||
type="image",
|
||||
image=URL(uri=item.get("image_url", {}).get("url", "")),
|
||||
)
|
||||
elif item.get("type", "") == "text":
|
||||
lls_item = TextContentItem(
|
||||
type="text",
|
||||
text=item.get("text", ""),
|
||||
)
|
||||
lls_content.append(lls_item)
|
||||
lls_message["content"] = lls_content
|
||||
lls_messages.append(lls_message)
|
||||
|
||||
return lls_messages
|
||||
|
||||
|
||||
def convert_openai_chat_completion_choice(
|
||||
choice: OpenAIChoice,
|
||||
) -> ChatCompletionResponse:
|
||||
|
|
@ -1080,11 +1200,24 @@ async def convert_openai_chat_completion_stream(
|
|||
|
||||
|
||||
async def prepare_openai_completion_params(**params):
|
||||
completion_params = {k: v for k, v in params.items() if v is not None}
|
||||
async def _prepare_value(value: Any) -> Any:
|
||||
new_value = value
|
||||
if isinstance(value, list):
|
||||
new_value = [await _prepare_value(v) for v in value]
|
||||
elif isinstance(value, dict):
|
||||
new_value = {k: await _prepare_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, BaseModel):
|
||||
new_value = value.model_dump(exclude_none=True)
|
||||
return new_value
|
||||
|
||||
completion_params = {}
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
completion_params[k] = await _prepare_value(v)
|
||||
return completion_params
|
||||
|
||||
|
||||
class OpenAICompletionUnsupportedMixin:
|
||||
class OpenAICompletionToLlamaStackMixin:
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1122,6 +1255,7 @@ class OpenAICompletionUnsupportedMixin:
|
|||
|
||||
choices = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
n = n or 1
|
||||
for _i in range(0, n):
|
||||
# and we may have multiple prompts, if batching was used
|
||||
|
||||
|
|
@ -1134,7 +1268,7 @@ class OpenAICompletionUnsupportedMixin:
|
|||
|
||||
index = len(choices)
|
||||
text = result.content
|
||||
finish_reason = _convert_openai_finish_reason(result.stop_reason)
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(result.stop_reason)
|
||||
|
||||
choice = OpenAICompletionChoice(
|
||||
index=index,
|
||||
|
|
@ -1152,7 +1286,7 @@ class OpenAICompletionUnsupportedMixin:
|
|||
)
|
||||
|
||||
|
||||
class OpenAIChatCompletionUnsupportedMixin:
|
||||
class OpenAIChatCompletionToLlamaStackMixin:
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -1167,7 +1301,7 @@ class OpenAIChatCompletionUnsupportedMixin:
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -1178,5 +1312,103 @@ class OpenAIChatCompletionUnsupportedMixin:
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
messages = _convert_openai_request_messages(messages)
|
||||
response_format = _convert_openai_request_response_format(response_format)
|
||||
sampling_params = _convert_openai_sampling_params(
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
tool_config = _convert_openai_request_tool_config(tool_choice)
|
||||
tools = _convert_openai_request_tools(tools)
|
||||
|
||||
outstanding_responses = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
n = n or 1
|
||||
for _i in range(0, n):
|
||||
response = self.chat_completion(
|
||||
model_id=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
tool_config=tool_config,
|
||||
tools=tools,
|
||||
)
|
||||
outstanding_responses.append(response)
|
||||
|
||||
if stream:
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
|
||||
|
||||
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
|
||||
self, model, outstanding_responses
|
||||
)
|
||||
|
||||
async def _process_stream_response(
|
||||
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
|
||||
):
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
for outstanding_response in outstanding_responses:
|
||||
response = await outstanding_response
|
||||
i = 0
|
||||
async for chunk in response:
|
||||
event = chunk.event
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
|
||||
|
||||
if isinstance(event.delta, TextDelta):
|
||||
text_delta = event.delta.text
|
||||
delta = OpenAIChoiceDelta(content=text_delta)
|
||||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
elif isinstance(event.delta, ToolCallDelta):
|
||||
if event.delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
tool_call = event.delta.tool_call
|
||||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id=tool_call.call_id,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name=tool_call.tool_name, arguments=tool_call.arguments_json
|
||||
),
|
||||
)
|
||||
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||
],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
i = i + 1
|
||||
|
||||
async def _process_non_stream_response(
|
||||
self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
|
||||
) -> OpenAIChatCompletion:
|
||||
raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
|
||||
choices = []
|
||||
for outstanding_response in outstanding_responses:
|
||||
response = await outstanding_response
|
||||
completion_message = response.completion_message
|
||||
message = await convert_message_to_openai_dict_new(completion_message)
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
|
||||
|
||||
choice = OpenAIChatCompletionChoice(
|
||||
index=len(choices),
|
||||
message=message,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
return OpenAIChatCompletion(
|
||||
id=f"chatcmpl-{uuid.uuid4()}",
|
||||
choices=choices,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
)
|
||||
|
|
|
|||
265
llama_stack/providers/utils/scheduler.py
Normal file
265
llama_stack/providers/utils/scheduler.py
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import functools
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="scheduler")
|
||||
|
||||
|
||||
# TODO: revisit the list of possible statuses when defining a more coherent
|
||||
# Jobs API for all API flows; e.g. do we need new vs scheduled?
|
||||
class JobStatus(Enum):
|
||||
new = "new"
|
||||
scheduled = "scheduled"
|
||||
running = "running"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
|
||||
|
||||
JobID: TypeAlias = str
|
||||
JobType: TypeAlias = str
|
||||
|
||||
|
||||
class JobArtifact(BaseModel):
|
||||
type: JobType
|
||||
name: str
|
||||
# TODO: uri should be a reference to /files API; revisit when /files is implemented
|
||||
uri: str | None = None
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
JobHandler = Callable[
|
||||
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
|
||||
]
|
||||
|
||||
|
||||
LogMessage: TypeAlias = Tuple[datetime, str]
|
||||
|
||||
|
||||
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
|
||||
|
||||
|
||||
class Job:
|
||||
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
|
||||
super().__init__()
|
||||
self.id = job_id
|
||||
self._type = job_type
|
||||
self._handler = handler
|
||||
self._artifacts: list[JobArtifact] = []
|
||||
self._logs: list[LogMessage] = []
|
||||
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
|
||||
|
||||
@property
|
||||
def handler(self) -> JobHandler:
|
||||
return self._handler
|
||||
|
||||
@property
|
||||
def status(self) -> JobStatus:
|
||||
return self._state_transitions[-1][1]
|
||||
|
||||
@status.setter
|
||||
def status(self, status: JobStatus):
|
||||
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
|
||||
raise ValueError(f"Job is already in a completed state ({self.status})")
|
||||
if self.status == status:
|
||||
return
|
||||
self._state_transitions.append((datetime.now(timezone.utc), status))
|
||||
|
||||
@property
|
||||
def artifacts(self) -> list[JobArtifact]:
|
||||
return self._artifacts
|
||||
|
||||
def register_artifact(self, artifact: JobArtifact) -> None:
|
||||
self._artifacts.append(artifact)
|
||||
|
||||
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
|
||||
for date, s in reversed(self._state_transitions):
|
||||
if s in status:
|
||||
return date
|
||||
return None
|
||||
|
||||
@property
|
||||
def scheduled_at(self) -> datetime | None:
|
||||
return self._find_state_transition_date([JobStatus.scheduled])
|
||||
|
||||
@property
|
||||
def started_at(self) -> datetime | None:
|
||||
return self._find_state_transition_date([JobStatus.running])
|
||||
|
||||
@property
|
||||
def completed_at(self) -> datetime | None:
|
||||
return self._find_state_transition_date(_COMPLETED_STATUSES)
|
||||
|
||||
@property
|
||||
def logs(self) -> list[LogMessage]:
|
||||
return self._logs[:]
|
||||
|
||||
def append_log(self, message: LogMessage) -> None:
|
||||
self._logs.append(message)
|
||||
|
||||
# TODO: implement
|
||||
def cancel(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _SchedulerBackend(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def shutdown(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def schedule(
|
||||
self,
|
||||
job: Job,
|
||||
on_log_message_cb: Callable[[str], None],
|
||||
on_status_change_cb: Callable[[JobStatus], None],
|
||||
on_artifact_collected_cb: Callable[[JobArtifact], None],
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _NaiveSchedulerBackend(_SchedulerBackend):
|
||||
def __init__(self, timeout: int = 5):
|
||||
self._timeout = timeout
|
||||
self._loop = asyncio.new_event_loop()
|
||||
# There may be performance implications of using threads due to Python
|
||||
# GIL; may need to measure if it's a real problem though
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
# When stopping the loop, give tasks a chance to finish
|
||||
# TODO: should we explicitly inform jobs of pending stoppage?
|
||||
for task in asyncio.all_tasks(self._loop):
|
||||
self._loop.run_until_complete(task)
|
||||
self._loop.close()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
self._thread.join()
|
||||
|
||||
# TODO: decouple scheduling and running the job
|
||||
def schedule(
|
||||
self,
|
||||
job: Job,
|
||||
on_log_message_cb: Callable[[str], None],
|
||||
on_status_change_cb: Callable[[JobStatus], None],
|
||||
on_artifact_collected_cb: Callable[[JobArtifact], None],
|
||||
) -> None:
|
||||
async def do():
|
||||
try:
|
||||
job.status = JobStatus.running
|
||||
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
|
||||
except Exception as e:
|
||||
on_log_message_cb(str(e))
|
||||
job.status = JobStatus.failed
|
||||
logger.exception(f"Job {job.id} failed.")
|
||||
|
||||
asyncio.run_coroutine_threadsafe(do(), self._loop)
|
||||
|
||||
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
|
||||
pass
|
||||
|
||||
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
|
||||
pass
|
||||
|
||||
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
|
||||
pass
|
||||
|
||||
|
||||
_BACKENDS = {
|
||||
"naive": _NaiveSchedulerBackend,
|
||||
}
|
||||
|
||||
|
||||
def _get_backend_impl(backend: str) -> _SchedulerBackend:
|
||||
try:
|
||||
return _BACKENDS[backend]()
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unknown backend {backend}") from e
|
||||
|
||||
|
||||
class Scheduler:
|
||||
def __init__(self, backend: str = "naive"):
|
||||
# TODO: if server crashes, job states are lost; we need to persist jobs on disc
|
||||
self._jobs: dict[JobID, Job] = {}
|
||||
self._backend = _get_backend_impl(backend)
|
||||
|
||||
def _on_log_message_cb(self, job: Job, message: str) -> None:
|
||||
msg = (datetime.now(timezone.utc), message)
|
||||
# At least for the time being, until there's a better way to expose
|
||||
# logs to users, log messages on console
|
||||
logger.info(f"Job {job.id}: {message}")
|
||||
job.append_log(msg)
|
||||
self._backend.on_log_message_cb(job, msg)
|
||||
|
||||
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
|
||||
job.status = status
|
||||
self._backend.on_status_change_cb(job, status)
|
||||
|
||||
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
|
||||
job.register_artifact(artifact)
|
||||
self._backend.on_artifact_collected_cb(job, artifact)
|
||||
|
||||
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
|
||||
job = Job(type_, job_id, handler)
|
||||
if job.id in self._jobs:
|
||||
raise ValueError(f"Job {job.id} already exists")
|
||||
|
||||
self._jobs[job.id] = job
|
||||
job.status = JobStatus.scheduled
|
||||
self._backend.schedule(
|
||||
job,
|
||||
functools.partial(self._on_log_message_cb, job),
|
||||
functools.partial(self._on_status_change_cb, job),
|
||||
functools.partial(self._on_artifact_collected_cb, job),
|
||||
)
|
||||
|
||||
return job.id
|
||||
|
||||
def cancel(self, job_id: JobID) -> None:
|
||||
self.get_job(job_id).cancel()
|
||||
|
||||
def get_job(self, job_id: JobID) -> Job:
|
||||
try:
|
||||
return self._jobs[job_id]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Job {job_id} not found") from e
|
||||
|
||||
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
|
||||
jobs = list(self._jobs.values())
|
||||
if type_:
|
||||
jobs = [job for job in jobs if job._type == type_]
|
||||
return jobs
|
||||
|
||||
async def shutdown(self):
|
||||
# TODO: also cancel jobs once implemented
|
||||
await self._backend.shutdown()
|
||||
|
|
@ -386,6 +386,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
|
|
@ -396,6 +406,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
|
|
|
|||
|
|
@ -158,6 +158,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
|
|
@ -168,6 +178,16 @@ models:
|
|||
provider_id: groq
|
||||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Setting up vLLM server
|
||||
|
||||
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
|
||||
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
|
||||
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
||||
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
||||
that we only use GPUs here for demonstration purposes.
|
||||
|
|
@ -149,6 +149,55 @@ docker run \
|
|||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
### Setting up vLLM server on Intel GPU
|
||||
|
||||
Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to vLLM side setup which guides towards installing vLLM from sources orself-building vLLM Docker container, Intel provides prebuilt vLLM container to use on systems with Intel GPUs supported by PyTorch XPU backend:
|
||||
- [intel/vllm](https://hub.docker.com/r/intel/vllm)
|
||||
|
||||
Here is a sample script to start a vLLM server locally via Docker using Intel provided container:
|
||||
|
||||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
|
||||
export ZE_AFFINITY_MASK=0
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||
--ipc=host \
|
||||
intel/vllm:xpu \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--model $INFERENCE_MODEL \
|
||||
--port $INFERENCE_PORT
|
||||
```
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
|
||||
|
||||
```bash
|
||||
export SAFETY_PORT=8081
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
export ZE_AFFINITY_MASK=1
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||
--ipc=host \
|
||||
intel/vllm:xpu \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--model $SAFETY_MODEL \
|
||||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
## Running Llama Stack
|
||||
|
||||
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
|
||||
|
|
|
|||
|
|
@ -474,6 +474,16 @@ models:
|
|||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq-openai-compat
|
||||
|
|
@ -484,6 +494,16 @@ models:
|
|||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: groq-openai-compat
|
||||
provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: Meta-Llama-3.1-8B-Instruct
|
||||
provider_id: sambanova-openai-compat
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue