mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 06:21:59 +00:00
Merge branch 'main' into feat/litellm_sambanova_usage
This commit is contained in:
commit
488eb8f249
39 changed files with 2102 additions and 164 deletions
|
|
@ -18,7 +18,7 @@ from typing import (
|
|||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.models import Model
|
||||
|
|
@ -442,6 +442,37 @@ class EmbeddingsResponse(BaseModel):
|
|||
embeddings: List[List[float]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIImageURL(BaseModel):
|
||||
url: str
|
||||
detail: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionContentPartImageParam(BaseModel):
|
||||
type: Literal["image_url"] = "image_url"
|
||||
image_url: OpenAIImageURL
|
||||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
Union[
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
|
||||
|
||||
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIUserMessageParam(BaseModel):
|
||||
"""A message from the user in an OpenAI-compatible chat completion request.
|
||||
|
|
@ -452,7 +483,7 @@ class OpenAIUserMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["user"] = "user"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -466,10 +497,24 @@ class OpenAISystemMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["system"] = "system"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCallFunction(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionToolCall(BaseModel):
|
||||
index: Optional[int] = None
|
||||
id: Optional[str] = None
|
||||
type: Literal["function"] = "function"
|
||||
function: Optional[OpenAIChatCompletionToolCallFunction] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIAssistantMessageParam(BaseModel):
|
||||
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
|
||||
|
|
@ -477,13 +522,13 @@ class OpenAIAssistantMessageParam(BaseModel):
|
|||
:param role: Must be "assistant" to identify this as the model's response
|
||||
:param content: The content of the model's response
|
||||
:param name: (Optional) The name of the assistant message participant.
|
||||
:param tool_calls: List of tool calls. Each tool call is a ToolCall object.
|
||||
:param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
|
||||
"""
|
||||
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -497,7 +542,7 @@ class OpenAIToolMessageParam(BaseModel):
|
|||
|
||||
role: Literal["tool"] = "tool"
|
||||
tool_call_id: str
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -510,7 +555,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
|
|||
"""
|
||||
|
||||
role: Literal["developer"] = "developer"
|
||||
content: InterleavedContent
|
||||
content: OpenAIChatCompletionMessageContent
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -527,6 +572,46 @@ OpenAIMessageParam = Annotated[
|
|||
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatText(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIJSONSchema(TypedDict, total=False):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
strict: Optional[bool] = None
|
||||
|
||||
# Pydantic BaseModel cannot be used with a schema param, since it already
|
||||
# has one. And, we don't want to alias here because then have to handle
|
||||
# that alias when converting to OpenAI params. So, to support schema,
|
||||
# we use a TypedDict.
|
||||
schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONSchema(BaseModel):
|
||||
type: Literal["json_schema"] = "json_schema"
|
||||
json_schema: OpenAIJSONSchema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseFormatJSONObject(BaseModel):
|
||||
type: Literal["json_object"] = "json_object"
|
||||
|
||||
|
||||
OpenAIResponseFormatParam = Annotated[
|
||||
Union[
|
||||
OpenAIResponseFormatText,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAITopLogProb(BaseModel):
|
||||
"""The top log probability for a token from an OpenAI-compatible chat completion response.
|
||||
|
|
@ -561,22 +646,54 @@ class OpenAITokenLogProb(BaseModel):
|
|||
class OpenAIChoiceLogprobs(BaseModel):
|
||||
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
|
||||
|
||||
:content: (Optional) The log probabilities for the tokens in the message
|
||||
:refusal: (Optional) The log probabilities for the tokens in the message
|
||||
:param content: (Optional) The log probabilities for the tokens in the message
|
||||
:param refusal: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
content: Optional[List[OpenAITokenLogProb]] = None
|
||||
refusal: Optional[List[OpenAITokenLogProb]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChoiceDelta(BaseModel):
|
||||
"""A delta from an OpenAI-compatible chat completion streaming response.
|
||||
|
||||
:param content: (Optional) The content of the delta
|
||||
:param refusal: (Optional) The refusal of the delta
|
||||
:param role: (Optional) The role of the delta
|
||||
:param tool_calls: (Optional) The tool calls of the delta
|
||||
"""
|
||||
|
||||
content: Optional[str] = None
|
||||
refusal: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChunkChoice(BaseModel):
|
||||
"""A chunk choice from an OpenAI-compatible chat completion streaming response.
|
||||
|
||||
:param delta: The delta from the chunk
|
||||
:param finish_reason: The reason the model stopped generating
|
||||
:param index: The index of the choice
|
||||
:param logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
delta: OpenAIChoiceDelta
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[OpenAIChoiceLogprobs] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChoice(BaseModel):
|
||||
"""A choice from an OpenAI-compatible chat completion response.
|
||||
|
||||
:param message: The message from the model
|
||||
:param finish_reason: The reason the model stopped generating
|
||||
:index: The index of the choice
|
||||
:logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
:param index: The index of the choice
|
||||
:param logprobs: (Optional) The log probabilities for the tokens in the message
|
||||
"""
|
||||
|
||||
message: OpenAIMessageParam
|
||||
|
|
@ -603,6 +720,24 @@ class OpenAIChatCompletion(BaseModel):
|
|||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIChatCompletionChunk(BaseModel):
|
||||
"""Chunk from a streaming response to an OpenAI-compatible chat completion request.
|
||||
|
||||
:param id: The ID of the chat completion
|
||||
:param choices: List of choices
|
||||
:param object: The object type, which will be "chat.completion.chunk"
|
||||
:param created: The Unix timestamp in seconds when the chat completion was created
|
||||
:param model: The model that was used to generate the chat completion
|
||||
"""
|
||||
|
||||
id: str
|
||||
choices: List[OpenAIChunkChoice]
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAICompletionLogprobs(BaseModel):
|
||||
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
|
||||
|
|
@ -872,7 +1007,7 @@ class Inference(Protocol):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -883,7 +1018,7 @@ class Inference(Protocol):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
|
|
|
|||
|
|
@ -38,7 +38,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
|
|
@ -531,7 +537,7 @@ class InferenceRouter(Inference):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -542,7 +548,7 @@ class InferenceRouter(Inference):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -204,7 +204,9 @@ class ToolUtils:
|
|||
return None
|
||||
elif is_json(message_body):
|
||||
response = json.loads(message_body)
|
||||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
if ("type" in response and response["type"] == "function") or (
|
||||
"name" in response and "parameters" in response
|
||||
):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
|
|
|
|||
|
|
@ -59,8 +59,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
|
|
@ -83,8 +83,8 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
|
|||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
|
|||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
|
@ -35,8 +35,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
|
|
|
|||
|
|
@ -66,10 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
|
|
@ -176,8 +176,8 @@ def _convert_sampling_params(
|
|||
|
||||
class VLLMInferenceImpl(
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,13 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
|
|
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
|
||||
LoraFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
RESOURCES_STATS = "resources_stats"
|
||||
|
||||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
|
||||
|
||||
class TorchtunePostTrainingImpl:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
|
|||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
self._scheduler = Scheduler()
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
metadata=resources_stats,
|
||||
)
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
|
|
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
|
|||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig],
|
||||
) -> PostTrainingJob:
|
||||
if job_uuid in self.jobs:
|
||||
raise ValueError(f"Job {job_uuid} already exists")
|
||||
|
||||
post_training_job = PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
job_status_response = PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=JobStatus.scheduled,
|
||||
scheduled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
self.jobs[job_uuid] = job_status_response
|
||||
|
||||
if isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
try:
|
||||
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
on_log_message_cb("Starting Lora finetuning")
|
||||
|
||||
recipe = LoraFinetuningSingleDevice(
|
||||
self.config,
|
||||
job_uuid,
|
||||
|
|
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
|
|||
self.datasetio_api,
|
||||
self.datasets_api,
|
||||
)
|
||||
|
||||
job_status_response.status = JobStatus.in_progress
|
||||
job_status_response.started_at = datetime.now(timezone.utc)
|
||||
|
||||
await recipe.setup()
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train()
|
||||
|
||||
self.checkpoints_dict[job_uuid] = checkpoints
|
||||
job_status_response.resources_allocated = resources_allocated
|
||||
job_status_response.checkpoints = checkpoints
|
||||
job_status_response.status = JobStatus.completed
|
||||
job_status_response.completed_at = datetime.now(timezone.utc)
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
|
||||
except Exception:
|
||||
job_status_response.status = JobStatus.failed
|
||||
raise
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("Lora finetuning completed")
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return post_training_job
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
|
|
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
|
|||
) -> PostTrainingJob: ...
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
|
||||
return self.jobs.get(job_uuid, None)
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
|
||||
if job_uuid in self.checkpoints_dict:
|
||||
checkpoints = self.checkpoints_dict.get(job_uuid, [])
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
|
||||
return None
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
||||
|
|
|
|||
|
|
@ -36,10 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_strategy_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -56,8 +56,8 @@ from .models import MODEL_ENTRIES
|
|||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -54,8 +54,8 @@ from .models import MODEL_ENTRIES
|
|||
class CerebrasInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
|
|
|
|||
|
|
@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -61,8 +61,8 @@ model_entries = [
|
|||
class DatabricksInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=model_entries)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from fireworks.client import Fireworks
|
||||
from openai import AsyncOpenAI
|
||||
|
|
@ -32,13 +32,20 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
|
|
@ -301,6 +308,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
|
|
@ -320,6 +332,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
return await self._get_openai_client().completions.create(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
|
@ -336,7 +349,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -347,10 +360,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
|
@ -374,4 +386,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
|
||||
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,24 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChoiceDelta,
|
||||
OpenAIChunkChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
|
@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
self.config = config
|
||||
self._openai_client = None
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
|
||||
def _get_openai_client(self) -> AsyncOpenAI:
|
||||
if not self._openai_client:
|
||||
self._openai_client = AsyncOpenAI(
|
||||
base_url=f"{self.config.url}/openai/v1",
|
||||
api_key=self.config.api_key,
|
||||
)
|
||||
return self._openai_client
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Groq does not support json_schema response format, so we need to convert it to json_object
|
||||
if response_format and response_format.type == "json_schema":
|
||||
response_format.type = "json_object"
|
||||
schema = response_format.json_schema.get("schema", {})
|
||||
response_format.json_schema = None
|
||||
json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
|
||||
if messages and messages[0].role == "system":
|
||||
messages[0].content = messages[0].content + json_instructions
|
||||
else:
|
||||
messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
|
||||
|
||||
# Groq returns a 400 error if tools are provided but none are called
|
||||
# So, set tool_choice to "required" to attempt to force a call
|
||||
if tools and (not tool_choice or tool_choice == "auto"):
|
||||
tool_choice = "required"
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Groq does not support streaming requests that set response_format
|
||||
fake_stream = False
|
||||
if stream and response_format:
|
||||
params["stream"] = False
|
||||
fake_stream = True
|
||||
|
||||
response = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
if fake_stream:
|
||||
chunk_choices = []
|
||||
for choice in response.choices:
|
||||
delta = OpenAIChoiceDelta(
|
||||
content=choice.message.content,
|
||||
role=choice.message.role,
|
||||
tool_calls=choice.message.tool_calls,
|
||||
)
|
||||
chunk_choice = OpenAIChunkChoice(
|
||||
delta=delta,
|
||||
finish_reason=choice.finish_reason,
|
||||
index=choice.index,
|
||||
logprobs=None,
|
||||
)
|
||||
chunk_choices.append(chunk_choice)
|
||||
chunk = OpenAIChatCompletionChunk(
|
||||
id=response.id,
|
||||
choices=chunk_choices,
|
||||
object="chat.completion.chunk",
|
||||
created=response.created,
|
||||
model=response.model,
|
||||
)
|
||||
|
||||
async def _fake_stream_generator():
|
||||
yield chunk
|
||||
|
||||
return _fake_stream_generator()
|
||||
else:
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -39,8 +39,16 @@ MODEL_ENTRIES = [
|
|||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -35,7 +35,13 @@ from llama_stack.apis.inference import (
|
|||
ToolConfig,
|
||||
ToolDefinition,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
|
|
@ -329,7 +335,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -340,7 +346,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
provider_model_id = self.get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
|
|
@ -39,7 +39,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
|
|
@ -337,6 +343,12 @@ class OllamaInferenceAdapter(
|
|||
response = await self.client.list()
|
||||
available_models = [m["model"] for m in response["models"]]
|
||||
if model.provider_resource_id not in available_models:
|
||||
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
|
||||
if model.provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
|
||||
)
|
||||
|
|
@ -408,7 +420,7 @@ class OllamaInferenceAdapter(
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -419,7 +431,7 @@ class OllamaInferenceAdapter(
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self._get_model(model)
|
||||
params = {
|
||||
k: v
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
|
|
@ -26,7 +26,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
@ -266,7 +272,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -277,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -43,8 +43,8 @@ RUNPOD_SUPPORTED_MODELS = {
|
|||
class RunpodInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: RunpodImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
|
|
@ -73,8 +73,8 @@ def build_hf_repo_model_entries():
|
|||
|
||||
class _HfAdapter(
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
client: AsyncInferenceClient
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from together import AsyncTogether
|
||||
|
|
@ -31,7 +31,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
@ -315,7 +321,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -326,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
@ -353,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", True):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||
# together.ai sometimes adds usage data to the stream, even if include_usage is False
|
||||
# This causes an unexpected final chunk with empty choices array to be sent
|
||||
# to clients that may not handle it gracefully.
|
||||
include_usage = False
|
||||
if params.get("stream_options", None):
|
||||
include_usage = params["stream_options"].get("include_usage", False)
|
||||
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||
|
||||
seen_finish_reason = False
|
||||
async for chunk in stream:
|
||||
# Final usage chunk with no choices that the user didn't request, so discard
|
||||
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||
break
|
||||
yield chunk
|
||||
for choice in chunk.choices:
|
||||
if choice.finish_reason:
|
||||
seen_finish_reason = True
|
||||
break
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
|
@ -45,7 +45,12 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
|
|
@ -487,7 +492,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -498,7 +503,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -270,7 +276,7 @@ class LiteLLMOpenAIMixin(
|
|||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
model_obj = await self._get_model(model)
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
|
|
@ -292,7 +298,7 @@ class LiteLLMOpenAIMixin(
|
|||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
return litellm.text_completion(**params)
|
||||
return await litellm.atext_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -308,7 +314,7 @@ class LiteLLMOpenAIMixin(
|
|||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
|
|
@ -319,8 +325,8 @@ class LiteLLMOpenAIMixin(
|
|||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
model_obj = await self._get_model(model)
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
|
|
@ -346,7 +352,7 @@ class LiteLLMOpenAIMixin(
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return litellm.completion(**params)
|
||||
return await litellm.acompletion(**params)
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import (
|
||||
|
|
@ -50,6 +50,18 @@ from openai.types.chat.chat_completion import (
|
|||
from openai.types.chat.chat_completion import (
|
||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
Choice as OpenAIChatCompletionChunkChoice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDelta as OpenAIChoiceDelta,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||
ImageURL as OpenAIImageURL,
|
||||
)
|
||||
|
|
@ -59,6 +71,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
TextContentItem,
|
||||
|
|
@ -85,12 +98,24 @@ from llama_stack.apis.inference import (
|
|||
TopPSamplingStrategy,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
|
||||
from llama_stack.apis.inference.inference import (
|
||||
JsonSchemaResponseFormat,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChoice as OpenAIChatCompletionChoice,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_content_to_url,
|
||||
|
|
@ -751,6 +776,17 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
return out
|
||||
|
||||
|
||||
def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
|
||||
"""
|
||||
Convert a StopReason to an OpenAI chat completion finish_reason.
|
||||
"""
|
||||
return {
|
||||
StopReason.end_of_turn: "stop",
|
||||
StopReason.end_of_message: "tool_calls",
|
||||
StopReason.out_of_tokens: "length",
|
||||
}.get(stop_reason, "stop")
|
||||
|
||||
|
||||
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
||||
"""
|
||||
Convert an OpenAI chat completion finish_reason to a StopReason.
|
||||
|
|
@ -776,6 +812,56 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
|||
}.get(finish_reason, StopReason.end_of_turn)
|
||||
|
||||
|
||||
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
|
||||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
tool_config.tool_choice = tool_choice
|
||||
return tool_config
|
||||
|
||||
|
||||
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
|
||||
lls_tools = []
|
||||
if not tools:
|
||||
return lls_tools
|
||||
|
||||
for tool in tools:
|
||||
tool_fn = tool.get("function", {})
|
||||
tool_name = tool_fn.get("name", None)
|
||||
tool_desc = tool_fn.get("description", None)
|
||||
|
||||
tool_params = tool_fn.get("parameters", None)
|
||||
lls_tool_params = {}
|
||||
if tool_params is not None:
|
||||
tool_param_properties = tool_params.get("properties", {})
|
||||
for tool_param_key, tool_param_value in tool_param_properties.items():
|
||||
tool_param_def = ToolParamDefinition(
|
||||
param_type=tool_param_value.get("type", None),
|
||||
description=tool_param_value.get("description", None),
|
||||
)
|
||||
lls_tool_params[tool_param_key] = tool_param_def
|
||||
|
||||
lls_tool = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool_desc,
|
||||
parameters=lls_tool_params,
|
||||
)
|
||||
lls_tools.append(lls_tool)
|
||||
return lls_tools
|
||||
|
||||
|
||||
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
    if not response_format:
        return None
    # response_format can be a dict or a pydantic model
    response_format = dict(response_format)
    if response_format.get("type", "") == "json_schema":
        return JsonSchemaResponseFormat(
            type="json_schema",
            json_schema=response_format.get("json_schema", {}).get("schema", ""),
        )
    return None
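A quick sketch (not part of the diff) of the json_schema path handled above, as if run in the same module; the schema payload is invented:

    rf = {
        "type": "json_schema",
        "json_schema": {"name": "person", "schema": {"type": "object", "properties": {"name": {"type": "string"}}}},
    }
    converted = _convert_openai_request_response_format(rf)
    # -> JsonSchemaResponseFormat(type="json_schema", json_schema={"type": "object", ...})
    # any other response_format type (e.g. "text") falls through and returns None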
def _convert_openai_tool_calls(
    tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
@@ -871,6 +957,40 @@ def _convert_openai_sampling_params(
    return sampling_params
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
    # Llama Stack messages and OpenAI messages are similar, but not identical.
    lls_messages = []
    for message in messages:
        lls_message = dict(message)

        # Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
        tool_call_id = lls_message.pop("tool_call_id", None)
        if tool_call_id:
            lls_message["call_id"] = tool_call_id

        content = lls_message.get("content", None)
        if isinstance(content, list):
            lls_content = []
            for item in content:
                # items can either be pydantic models or dicts here...
                item = dict(item)
                if item.get("type", "") == "image_url":
                    lls_item = ImageContentItem(
                        type="image",
                        image=URL(uri=item.get("image_url", {}).get("url", "")),
                    )
                elif item.get("type", "") == "text":
                    lls_item = TextContentItem(
                        type="text",
                        text=item.get("text", ""),
                    )
                lls_content.append(lls_item)
            lls_message["content"] = lls_content
        lls_messages.append(lls_message)

    return lls_messages
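For illustration only (not part of the diff), a user message with an image part is rewritten by the helper above roughly as follows, as if run in the same module; the URL is a placeholder:

    openai_messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ]
    lls_messages = _convert_openai_request_messages(openai_messages)
    # the content list becomes [TextContentItem(text="What is in this image?"),
    #                           ImageContentItem(image=URL(uri="https://example.com/cat.png"))]
    # and any "tool_call_id" key would be renamed to "call_id"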
def convert_openai_chat_completion_choice(
    choice: OpenAIChoice,
) -> ChatCompletionResponse:
@@ -1080,11 +1200,24 @@ async def convert_openai_chat_completion_stream(
async def prepare_openai_completion_params(**params):
    completion_params = {k: v for k, v in params.items() if v is not None}
    async def _prepare_value(value: Any) -> Any:
        new_value = value
        if isinstance(value, list):
            new_value = [await _prepare_value(v) for v in value]
        elif isinstance(value, dict):
            new_value = {k: await _prepare_value(v) for k, v in value.items()}
        elif isinstance(value, BaseModel):
            new_value = value.model_dump(exclude_none=True)
        return new_value

    completion_params = {}
    for k, v in params.items():
        if v is not None:
            completion_params[k] = await _prepare_value(v)
    return completion_params
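A small sketch (not part of the diff) of what the rewritten helper now does with None values and pydantic models, run inside an async function in the same module; _Opts is a made-up model for the example:

    from pydantic import BaseModel

    class _Opts(BaseModel):  # hypothetical model, only for this sketch
        detail: str | None = None
        level: int = 1

    params = await prepare_openai_completion_params(
        model="some-model",
        temperature=None,         # dropped entirely
        options=_Opts(level=2),   # dumped to {"level": 2}; None fields excluded
    )
    # -> {"model": "some-model", "options": {"level": 2}}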
class OpenAICompletionUnsupportedMixin:
class OpenAICompletionToLlamaStackMixin:
    async def openai_completion(
        self,
        model: str,
@@ -1122,6 +1255,7 @@ class OpenAICompletionUnsupportedMixin:

        choices = []
        # "n" is the number of completions to generate per prompt
        n = n or 1
        for _i in range(0, n):
            # and we may have multiple prompts, if batching was used
@@ -1134,7 +1268,7 @@ class OpenAICompletionUnsupportedMixin:
                index = len(choices)
                text = result.content
                finish_reason = _convert_openai_finish_reason(result.stop_reason)
                finish_reason = _convert_stop_reason_to_openai_finish_reason(result.stop_reason)

                choice = OpenAICompletionChoice(
                    index=index,
@@ -1152,7 +1286,7 @@ class OpenAICompletionUnsupportedMixin:
        )


class OpenAIChatCompletionUnsupportedMixin:
class OpenAIChatCompletionToLlamaStackMixin:
    async def openai_chat_completion(
        self,
        model: str,
@@ -1167,7 +1301,7 @@ class OpenAIChatCompletionUnsupportedMixin:
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[Dict[str, str]] = None,
        response_format: Optional[OpenAIResponseFormatParam] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
@@ -1178,5 +1312,103 @@ class OpenAIChatCompletionUnsupportedMixin:
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
        messages = _convert_openai_request_messages(messages)
        response_format = _convert_openai_request_response_format(response_format)
        sampling_params = _convert_openai_sampling_params(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        tool_config = _convert_openai_request_tool_config(tool_choice)
        tools = _convert_openai_request_tools(tools)

        outstanding_responses = []
        # "n" is the number of completions to generate per prompt
        n = n or 1
        for _i in range(0, n):
            response = self.chat_completion(
                model_id=model,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                stream=stream,
                tool_config=tool_config,
                tools=tools,
            )
            outstanding_responses.append(response)

        if stream:
            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)

        return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
            self, model, outstanding_responses
        )

    async def _process_stream_response(
        self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
    ):
        id = f"chatcmpl-{uuid.uuid4()}"
        for outstanding_response in outstanding_responses:
            response = await outstanding_response
            i = 0
            async for chunk in response:
                event = chunk.event
                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)

                if isinstance(event.delta, TextDelta):
                    text_delta = event.delta.text
                    delta = OpenAIChoiceDelta(content=text_delta)
                    yield OpenAIChatCompletionChunk(
                        id=id,
                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
                        created=int(time.time()),
                        model=model,
                        object="chat.completion.chunk",
                    )
                elif isinstance(event.delta, ToolCallDelta):
                    if event.delta.parse_status == ToolCallParseStatus.succeeded:
                        tool_call = event.delta.tool_call
                        openai_tool_call = OpenAIChoiceDeltaToolCall(
                            index=0,
                            id=tool_call.call_id,
                            function=OpenAIChoiceDeltaToolCallFunction(
                                name=tool_call.tool_name, arguments=tool_call.arguments_json
                            ),
                        )
                        delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
                        yield OpenAIChatCompletionChunk(
                            id=id,
                            choices=[
                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
                            ],
                            created=int(time.time()),
                            model=model,
                            object="chat.completion.chunk",
                        )
                i = i + 1

    async def _process_non_stream_response(
        self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
    ) -> OpenAIChatCompletion:
        raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
        choices = []
        for outstanding_response in outstanding_responses:
            response = await outstanding_response
            completion_message = response.completion_message
            message = await convert_message_to_openai_dict_new(completion_message)
            finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)

            choice = OpenAIChatCompletionChoice(
                index=len(choices),
                message=message,
                finish_reason=finish_reason,
            )
            choices.append(choice)

        return OpenAIChatCompletion(
            id=f"chatcmpl-{uuid.uuid4()}",
            choices=choices,
            created=int(time.time()),
            model=model,
            object="chat.completion",
        )
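As a rough sketch (not part of the diff), a provider that already implements the native chat_completion can pick up the OpenAI-compatible surface by mixing the class in; MyProvider and its trimmed chat_completion signature below are hypothetical, with keyword names taken only from the calls shown above:

    class MyProvider(OpenAIChatCompletionToLlamaStackMixin, OpenAICompletionToLlamaStackMixin):
        async def chat_completion(self, model_id, messages, sampling_params=None,
                                  response_format=None, stream=False, tool_config=None, tools=None):
            ...  # provider-specific inference

    # openai_chat_completion() then converts the OpenAI-style arguments, fans out `n`
    # inner chat_completion calls, and re-wraps the results as an OpenAIChatCompletion
    # (or a stream of OpenAIChatCompletionChunk objects when stream=True).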
265  llama_stack/providers/utils/scheduler.py  Normal file
@@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import abc
import asyncio
import functools
import threading
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias

from pydantic import BaseModel

from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="scheduler")


# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
    new = "new"
    scheduled = "scheduled"
    running = "running"
    failed = "failed"
    completed = "completed"


JobID: TypeAlias = str
JobType: TypeAlias = str


class JobArtifact(BaseModel):
    type: JobType
    name: str
    # TODO: uri should be a reference to /files API; revisit when /files is implemented
    uri: str | None = None
    metadata: Dict[str, Any]


JobHandler = Callable[
    [Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]


LogMessage: TypeAlias = Tuple[datetime, str]


_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}


class Job:
    def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
        super().__init__()
        self.id = job_id
        self._type = job_type
        self._handler = handler
        self._artifacts: list[JobArtifact] = []
        self._logs: list[LogMessage] = []
        self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]

    @property
    def handler(self) -> JobHandler:
        return self._handler

    @property
    def status(self) -> JobStatus:
        return self._state_transitions[-1][1]

    @status.setter
    def status(self, status: JobStatus):
        if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
            raise ValueError(f"Job is already in a completed state ({self.status})")
        if self.status == status:
            return
        self._state_transitions.append((datetime.now(timezone.utc), status))

    @property
    def artifacts(self) -> list[JobArtifact]:
        return self._artifacts

    def register_artifact(self, artifact: JobArtifact) -> None:
        self._artifacts.append(artifact)

    def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
        for date, s in reversed(self._state_transitions):
            if s in status:
                return date
        return None

    @property
    def scheduled_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.scheduled])

    @property
    def started_at(self) -> datetime | None:
        return self._find_state_transition_date([JobStatus.running])

    @property
    def completed_at(self) -> datetime | None:
        return self._find_state_transition_date(_COMPLETED_STATUSES)

    @property
    def logs(self) -> list[LogMessage]:
        return self._logs[:]

    def append_log(self, message: LogMessage) -> None:
        self._logs.append(message)

    # TODO: implement
    def cancel(self) -> None:
        raise NotImplementedError
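A tiny sketch (not part of the diff) of the transition rules encoded in the status setter above; the job type and id are placeholders, and the handler is not exercised:

    job = Job("dummy-type", "job-1", handler=None)  # handler unused in this sketch
    assert job.status == JobStatus.new
    job.status = JobStatus.running
    job.status = JobStatus.running      # same status: silently ignored
    job.status = JobStatus.completed
    try:
        job.status = JobStatus.failed   # completed -> failed: rejected
    except ValueError:
        pass
    assert job.completed_at is not None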
class _SchedulerBackend(abc.ABC):
    @abc.abstractmethod
    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    async def shutdown(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        raise NotImplementedError


class _NaiveSchedulerBackend(_SchedulerBackend):
    def __init__(self, timeout: int = 5):
        self._timeout = timeout
        self._loop = asyncio.new_event_loop()
        # There may be performance implications of using threads due to Python
        # GIL; may need to measure if it's a real problem though
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

        # When stopping the loop, give tasks a chance to finish
        # TODO: should we explicitly inform jobs of pending stoppage?
        for task in asyncio.all_tasks(self._loop):
            self._loop.run_until_complete(task)
        self._loop.close()

    async def shutdown(self) -> None:
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()

    # TODO: decouple scheduling and running the job
    def schedule(
        self,
        job: Job,
        on_log_message_cb: Callable[[str], None],
        on_status_change_cb: Callable[[JobStatus], None],
        on_artifact_collected_cb: Callable[[JobArtifact], None],
    ) -> None:
        async def do():
            try:
                job.status = JobStatus.running
                await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
            except Exception as e:
                on_log_message_cb(str(e))
                job.status = JobStatus.failed
                logger.exception(f"Job {job.id} failed.")

        asyncio.run_coroutine_threadsafe(do(), self._loop)

    def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
        pass

    def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        pass

    def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        pass


_BACKENDS = {
    "naive": _NaiveSchedulerBackend,
}


def _get_backend_impl(backend: str) -> _SchedulerBackend:
    try:
        return _BACKENDS[backend]()
    except KeyError as e:
        raise ValueError(f"Unknown backend {backend}") from e


class Scheduler:
    def __init__(self, backend: str = "naive"):
        # TODO: if server crashes, job states are lost; we need to persist jobs on disc
        self._jobs: dict[JobID, Job] = {}
        self._backend = _get_backend_impl(backend)

    def _on_log_message_cb(self, job: Job, message: str) -> None:
        msg = (datetime.now(timezone.utc), message)
        # At least for the time being, until there's a better way to expose
        # logs to users, log messages on console
        logger.info(f"Job {job.id}: {message}")
        job.append_log(msg)
        self._backend.on_log_message_cb(job, msg)

    def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
        job.status = status
        self._backend.on_status_change_cb(job, status)

    def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
        job.register_artifact(artifact)
        self._backend.on_artifact_collected_cb(job, artifact)

    def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
        job = Job(type_, job_id, handler)
        if job.id in self._jobs:
            raise ValueError(f"Job {job.id} already exists")

        self._jobs[job.id] = job
        job.status = JobStatus.scheduled
        self._backend.schedule(
            job,
            functools.partial(self._on_log_message_cb, job),
            functools.partial(self._on_status_change_cb, job),
            functools.partial(self._on_artifact_collected_cb, job),
        )

        return job.id

    def cancel(self, job_id: JobID) -> None:
        self.get_job(job_id).cancel()

    def get_job(self, job_id: JobID) -> Job:
        try:
            return self._jobs[job_id]
        except KeyError as e:
            raise ValueError(f"Job {job_id} not found") from e

    def get_jobs(self, type_: JobType | None = None) -> list[Job]:
        jobs = list(self._jobs.values())
        if type_:
            jobs = [job for job in jobs if job._type == type_]
        return jobs

    async def shutdown(self):
        # TODO: also cancel jobs once implemented
        await self._backend.shutdown()
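A short end-to-end sketch (not part of the diff) of how the Scheduler above is meant to be driven; the job type, id, and artifact values are made up, and the brief sleep only gives the background loop time to run:

    import asyncio
    from llama_stack.providers.utils.scheduler import JobArtifact, JobStatus, Scheduler

    async def main():
        scheduler = Scheduler()  # "naive" backend by default

        async def handler(on_log, on_status_change, on_artifact_collected):
            on_log("starting work")
            on_artifact_collected(JobArtifact(type="demo", name="result", uri=None, metadata={}))
            on_status_change(JobStatus.completed)

        job_id = scheduler.schedule("demo-job-type", "job-1", handler)
        await asyncio.sleep(0.1)  # let the backend thread pick the job up
        job = scheduler.get_job(job_id)
        print(job.status, [a.name for a in job.artifacts], len(job.logs))
        await scheduler.shutdown()

    asyncio.run(main())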
@@ -391,6 +391,16 @@ models:
  provider_id: groq
  provider_model_id: groq/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/llama-4-maverick-17b-128e-instruct
  provider_id: groq
@@ -401,6 +411,16 @@ models:
  provider_id: groq
  provider_model_id: groq/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: sambanova/Meta-Llama-3.1-8B-Instruct
  provider_id: sambanova
@@ -158,6 +158,16 @@ models:
  provider_id: groq
  provider_model_id: groq/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/llama-4-maverick-17b-128e-instruct
  provider_id: groq
@@ -168,6 +178,16 @@ models:
  provider_id: groq
  provider_model_id: groq/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
  provider_id: groq
  provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
@@ -474,6 +474,16 @@ models:
  provider_id: groq-openai-compat
  provider_model_id: groq/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  provider_id: groq-openai-compat
  provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
  provider_id: groq-openai-compat
  provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/llama-4-maverick-17b-128e-instruct
  provider_id: groq-openai-compat
@@ -484,6 +494,16 @@ models:
  provider_id: groq-openai-compat
  provider_model_id: groq/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  provider_id: groq-openai-compat
  provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
  provider_id: groq-openai-compat
  provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct
  model_type: llm
- metadata: {}
  model_id: sambanova/Meta-Llama-3.1-8B-Instruct
  provider_id: sambanova-openai-compat
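For context (not part of the diff), the entries above register several aliases for the same provider model, so any of them can be passed as model_id once a distribution with this config is running. A rough client-side sketch, assuming a local server and the llama_stack_client package; the base URL and response attribute names are assumptions:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # default port is an assumption
    response = client.inference.chat_completion(
        model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # or one of the groq/... aliases above
        messages=[{"role": "user", "content": "Say hello"}],
    )
    print(response.completion_message.content)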