Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-04-14 12:15:44 -05:00 committed by GitHub
commit 488eb8f249
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 2102 additions and 164 deletions

View file

@ -59,8 +59,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
@ -83,8 +83,8 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl(
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,

View file

@ -25,8 +25,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@ -35,8 +35,8 @@ log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,

View file

@ -66,10 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
get_stop_reason,
process_chat_completion_stream_response,
)
@ -176,8 +176,8 @@ def _convert_sampling_params(
class VLLMInferenceImpl(
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
"""

View file

@ -3,13 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
def __init__(
self,
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self) -> None:
await self._scheduler.shutdown()
async def shutdown(self):
pass
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
if job_uuid in self.jobs:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(timezone.utc),
)
self.jobs[job_uuid] = job_status_response
if isinstance(algorithm_config, LoraFinetuningConfig):
try:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now(timezone.utc)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
except Exception:
job_status_response.status = JobStatus.failed
raise
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
else:
raise NotImplementedError()
return post_training_job
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
return self.jobs.get(job_uuid, None)
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -36,10 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -56,8 +56,8 @@ from .models import MODEL_ENTRIES
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)

View file

@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -54,8 +54,8 @@ from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(

View file

@ -34,8 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -61,8 +61,8 @@ model_entries = [
class DatabricksInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from fireworks.client import Fireworks
from openai import AsyncOpenAI
@ -32,13 +32,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
@ -301,6 +308,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
# Fireworks always prepends with BOS
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
@ -320,6 +332,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
@ -336,7 +349,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -347,10 +360,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
@ -374,4 +386,12 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(**params)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -4,8 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChoiceDelta,
OpenAIChunkChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAISystemMessageParam,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from .models import MODEL_ENTRIES
@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="groq_api_key",
)
self.config = config
self._openai_client = None
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self.config.url}/openai/v1",
api_key=self.config.api_key,
)
return self._openai_client
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
# Groq does not support json_schema response format, so we need to convert it to json_object
if response_format and response_format.type == "json_schema":
response_format.type = "json_object"
schema = response_format.json_schema.get("schema", {})
response_format.json_schema = None
json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
if messages and messages[0].role == "system":
messages[0].content = messages[0].content + json_instructions
else:
messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
# Groq returns a 400 error if tools are provided but none are called
# So, set tool_choice to "required" to attempt to force a call
if tools and (not tool_choice or tool_choice == "auto"):
tool_choice = "required"
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id.replace("groq/", ""),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
# Groq does not support streaming requests that set response_format
fake_stream = False
if stream and response_format:
params["stream"] = False
fake_stream = True
response = await self._get_openai_client().chat.completions.create(**params)
if fake_stream:
chunk_choices = []
for choice in response.choices:
delta = OpenAIChoiceDelta(
content=choice.message.content,
role=choice.message.role,
tool_calls=choice.message.tool_calls,
)
chunk_choice = OpenAIChunkChoice(
delta=delta,
finish_reason=choice.finish_reason,
index=choice.index,
logprobs=None,
)
chunk_choices.append(chunk_choice)
chunk = OpenAIChatCompletionChunk(
id=response.id,
choices=chunk_choices,
object="chat.completion.chunk",
created=response.created,
model=response.model,
)
async def _fake_stream_generator():
yield chunk
return _fake_stream_generator()
else:
return response

View file

@ -39,8 +39,16 @@ MODEL_ENTRIES = [
"groq/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]

View file

@ -35,7 +35,13 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@ -329,7 +335,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -340,7 +346,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
provider_model_id = self.get_provider_model_id(model)
params = await prepare_openai_completion_params(

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from ollama import AsyncClient
@ -39,7 +39,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
@ -337,6 +343,12 @@ class OllamaInferenceAdapter(
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
if model.provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
return model
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
@ -408,7 +420,7 @@ class OllamaInferenceAdapter(
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -419,7 +431,7 @@ class OllamaInferenceAdapter(
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
params = {
k: v

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack_client import AsyncLlamaStackClient
@ -26,7 +26,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -266,7 +272,7 @@ class PassthroughInferenceAdapter(Inference):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -277,7 +283,7 @@ class PassthroughInferenceAdapter(Inference):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)

View file

@ -12,8 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -43,8 +43,8 @@ RUNPOD_SUPPORTED_MODELS = {
class RunpodInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)

View file

@ -40,10 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionUnsupportedMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -73,8 +73,8 @@ def build_hf_repo_model_entries():
class _HfAdapter(
Inference,
OpenAIChatCompletionUnsupportedMixin,
OpenAICompletionUnsupportedMixin,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from together import AsyncTogether
@ -31,7 +31,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -315,7 +321,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -326,7 +332,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
@ -353,4 +359,26 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_p=top_p,
user=user,
)
if params.get("stream", True):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# together.ai sometimes adds usage data to the stream, even if include_usage is False
# This causes an unexpected final chunk with empty choices array to be sent
# to clients that may not handle it gracefully.
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from openai import AsyncOpenAI
@ -45,7 +45,12 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
@ -487,7 +492,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -498,7 +503,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,

View file

@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@ -270,7 +276,7 @@ class LiteLLMOpenAIMixin(
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self._get_model(model)
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
@ -292,7 +298,7 @@ class LiteLLMOpenAIMixin(
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return litellm.text_completion(**params)
return await litellm.atext_completion(**params)
async def openai_chat_completion(
self,
@ -308,7 +314,7 @@ class LiteLLMOpenAIMixin(
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -319,8 +325,8 @@ class LiteLLMOpenAIMixin(
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> OpenAIChatCompletion:
model_obj = await self._get_model(model)
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
@ -346,7 +352,7 @@ class LiteLLMOpenAIMixin(
top_p=top_p,
user=user,
)
return litellm.completion(**params)
return await litellm.acompletion(**params)
async def batch_completion(
self,

View file

@ -8,7 +8,7 @@ import logging
import time
import uuid
import warnings
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
@ -50,6 +50,18 @@ from openai.types.chat.chat_completion import (
from openai.types.chat.chat_completion import (
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_chunk import (
Choice as OpenAIChatCompletionChunkChoice,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
)
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
)
from openai.types.chat.chat_completion_content_part_image_param import (
ImageURL as OpenAIImageURL,
)
@ -59,6 +71,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
from pydantic import BaseModel
from llama_stack.apis.common.content_types import (
URL,
ImageContentItem,
InterleavedContent,
TextContentItem,
@ -85,12 +98,24 @@ from llama_stack.apis.inference import (
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
from llama_stack.apis.inference.inference import (
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAICompletion,
OpenAICompletionChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolConfig,
)
from llama_stack.apis.inference.inference import (
OpenAIChoice as OpenAIChatCompletionChoice,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
@ -751,6 +776,17 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
return out
def _convert_stop_reason_to_openai_finish_reason(stop_reason: StopReason) -> str:
"""
Convert a StopReason to an OpenAI chat completion finish_reason.
"""
return {
StopReason.end_of_turn: "stop",
StopReason.end_of_message: "tool_calls",
StopReason.out_of_tokens: "length",
}.get(stop_reason, "stop")
def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
"""
Convert an OpenAI chat completion finish_reason to a StopReason.
@ -776,6 +812,56 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
}.get(finish_reason, StopReason.end_of_turn)
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
tool_config = ToolConfig()
if tool_choice:
tool_config.tool_choice = tool_choice
return tool_config
def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None) -> List[ToolDefinition]:
lls_tools = []
if not tools:
return lls_tools
for tool in tools:
tool_fn = tool.get("function", {})
tool_name = tool_fn.get("name", None)
tool_desc = tool_fn.get("description", None)
tool_params = tool_fn.get("parameters", None)
lls_tool_params = {}
if tool_params is not None:
tool_param_properties = tool_params.get("properties", {})
for tool_param_key, tool_param_value in tool_param_properties.items():
tool_param_def = ToolParamDefinition(
param_type=tool_param_value.get("type", None),
description=tool_param_value.get("description", None),
)
lls_tool_params[tool_param_key] = tool_param_def
lls_tool = ToolDefinition(
tool_name=tool_name,
description=tool_desc,
parameters=lls_tool_params,
)
lls_tools.append(lls_tool)
return lls_tools
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
if not response_format:
return None
# response_format can be a dict or a pydantic model
response_format = dict(response_format)
if response_format.get("type", "") == "json_schema":
return JsonSchemaResponseFormat(
type="json_schema",
json_schema=response_format.get("json_schema", {}).get("schema", ""),
)
return None
def _convert_openai_tool_calls(
tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
@ -871,6 +957,40 @@ def _convert_openai_sampling_params(
return sampling_params
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
# Llama Stack messages and OpenAI messages are similar, but not identical.
lls_messages = []
for message in messages:
lls_message = dict(message)
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
tool_call_id = lls_message.pop("tool_call_id", None)
if tool_call_id:
lls_message["call_id"] = tool_call_id
content = lls_message.get("content", None)
if isinstance(content, list):
lls_content = []
for item in content:
# items can either by pydantic models or dicts here...
item = dict(item)
if item.get("type", "") == "image_url":
lls_item = ImageContentItem(
type="image",
image=URL(uri=item.get("image_url", {}).get("url", "")),
)
elif item.get("type", "") == "text":
lls_item = TextContentItem(
type="text",
text=item.get("text", ""),
)
lls_content.append(lls_item)
lls_message["content"] = lls_content
lls_messages.append(lls_message)
return lls_messages
def convert_openai_chat_completion_choice(
choice: OpenAIChoice,
) -> ChatCompletionResponse:
@ -1080,11 +1200,24 @@ async def convert_openai_chat_completion_stream(
async def prepare_openai_completion_params(**params):
completion_params = {k: v for k, v in params.items() if v is not None}
async def _prepare_value(value: Any) -> Any:
new_value = value
if isinstance(value, list):
new_value = [await _prepare_value(v) for v in value]
elif isinstance(value, dict):
new_value = {k: await _prepare_value(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
new_value = value.model_dump(exclude_none=True)
return new_value
completion_params = {}
for k, v in params.items():
if v is not None:
completion_params[k] = await _prepare_value(v)
return completion_params
class OpenAICompletionUnsupportedMixin:
class OpenAICompletionToLlamaStackMixin:
async def openai_completion(
self,
model: str,
@ -1122,6 +1255,7 @@ class OpenAICompletionUnsupportedMixin:
choices = []
# "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n):
# and we may have multiple prompts, if batching was used
@ -1134,7 +1268,7 @@ class OpenAICompletionUnsupportedMixin:
index = len(choices)
text = result.content
finish_reason = _convert_openai_finish_reason(result.stop_reason)
finish_reason = _convert_stop_reason_to_openai_finish_reason(result.stop_reason)
choice = OpenAICompletionChoice(
index=index,
@ -1152,7 +1286,7 @@ class OpenAICompletionUnsupportedMixin:
)
class OpenAIChatCompletionUnsupportedMixin:
class OpenAIChatCompletionToLlamaStackMixin:
async def openai_chat_completion(
self,
model: str,
@ -1167,7 +1301,7 @@ class OpenAIChatCompletionUnsupportedMixin:
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[Dict[str, str]] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
@ -1178,5 +1312,103 @@ class OpenAIChatCompletionUnsupportedMixin:
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages = _convert_openai_request_messages(messages)
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
tool_config = _convert_openai_request_tool_config(tool_choice)
tools = _convert_openai_request_tools(tools)
outstanding_responses = []
# "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n):
response = self.chat_completion(
model_id=model,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
tool_config=tool_config,
tools=tools,
)
outstanding_responses.append(response)
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
self, model, outstanding_responses
)
async def _process_stream_response(
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
):
id = f"chatcmpl-{uuid.uuid4()}"
for outstanding_response in outstanding_responses:
response = await outstanding_response
i = 0
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
if isinstance(event.delta, TextDelta):
text_delta = event.delta.text
delta = OpenAIChoiceDelta(content=text_delta)
yield OpenAIChatCompletionChunk(
id=id,
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
created=int(time.time()),
model=model,
object="chat.completion.chunk",
)
elif isinstance(event.delta, ToolCallDelta):
if event.delta.parse_status == ToolCallParseStatus.succeeded:
tool_call = event.delta.tool_call
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
id=tool_call.call_id,
function=OpenAIChoiceDeltaToolCallFunction(
name=tool_call.tool_name, arguments=tool_call.arguments_json
),
)
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
],
created=int(time.time()),
model=model,
object="chat.completion.chunk",
)
i = i + 1
async def _process_non_stream_response(
self, model: str, outstanding_responses: List[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
raise ValueError(f"{self.__class__.__name__} doesn't support openai chat completion")
choices = []
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
message = await convert_message_to_openai_dict_new(completion_message)
finish_reason = _convert_stop_reason_to_openai_finish_reason(completion_message.stop_reason)
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
finish_reason=finish_reason,
)
choices.append(choice)
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
created=int(time.time()),
model=model,
object="chat.completion",
)

View file

@ -0,0 +1,265 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import abc
import asyncio
import functools
import threading
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Dict, Iterable, Tuple, TypeAlias
from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
# TODO: revisit the list of possible statuses when defining a more coherent
# Jobs API for all API flows; e.g. do we need new vs scheduled?
class JobStatus(Enum):
new = "new"
scheduled = "scheduled"
running = "running"
failed = "failed"
completed = "completed"
JobID: TypeAlias = str
JobType: TypeAlias = str
class JobArtifact(BaseModel):
type: JobType
name: str
# TODO: uri should be a reference to /files API; revisit when /files is implemented
uri: str | None = None
metadata: Dict[str, Any]
JobHandler = Callable[
[Callable[[str], None], Callable[[JobStatus], None], Callable[[JobArtifact], None]], Coroutine[Any, Any, None]
]
LogMessage: TypeAlias = Tuple[datetime, str]
_COMPLETED_STATUSES = {JobStatus.completed, JobStatus.failed}
class Job:
def __init__(self, job_type: JobType, job_id: JobID, handler: JobHandler):
super().__init__()
self.id = job_id
self._type = job_type
self._handler = handler
self._artifacts: list[JobArtifact] = []
self._logs: list[LogMessage] = []
self._state_transitions: list[Tuple[datetime, JobStatus]] = [(datetime.now(timezone.utc), JobStatus.new)]
@property
def handler(self) -> JobHandler:
return self._handler
@property
def status(self) -> JobStatus:
return self._state_transitions[-1][1]
@status.setter
def status(self, status: JobStatus):
if status in _COMPLETED_STATUSES and self.status in _COMPLETED_STATUSES:
raise ValueError(f"Job is already in a completed state ({self.status})")
if self.status == status:
return
self._state_transitions.append((datetime.now(timezone.utc), status))
@property
def artifacts(self) -> list[JobArtifact]:
return self._artifacts
def register_artifact(self, artifact: JobArtifact) -> None:
self._artifacts.append(artifact)
def _find_state_transition_date(self, status: Iterable[JobStatus]) -> datetime | None:
for date, s in reversed(self._state_transitions):
if s in status:
return date
return None
@property
def scheduled_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.scheduled])
@property
def started_at(self) -> datetime | None:
return self._find_state_transition_date([JobStatus.running])
@property
def completed_at(self) -> datetime | None:
return self._find_state_transition_date(_COMPLETED_STATUSES)
@property
def logs(self) -> list[LogMessage]:
return self._logs[:]
def append_log(self, message: LogMessage) -> None:
self._logs.append(message)
# TODO: implement
def cancel(self) -> None:
raise NotImplementedError
class _SchedulerBackend(abc.ABC):
@abc.abstractmethod
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
raise NotImplementedError
@abc.abstractmethod
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
raise NotImplementedError
@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
raise NotImplementedError
class _NaiveSchedulerBackend(_SchedulerBackend):
def __init__(self, timeout: int = 5):
self._timeout = timeout
self._loop = asyncio.new_event_loop()
# There may be performance implications of using threads due to Python
# GIL; may need to measure if it's a real problem though
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
# When stopping the loop, give tasks a chance to finish
# TODO: should we explicitly inform jobs of pending stoppage?
for task in asyncio.all_tasks(self._loop):
self._loop.run_until_complete(task)
self._loop.close()
async def shutdown(self) -> None:
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# TODO: decouple scheduling and running the job
def schedule(
self,
job: Job,
on_log_message_cb: Callable[[str], None],
on_status_change_cb: Callable[[JobStatus], None],
on_artifact_collected_cb: Callable[[JobArtifact], None],
) -> None:
async def do():
try:
job.status = JobStatus.running
await job.handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb)
except Exception as e:
on_log_message_cb(str(e))
job.status = JobStatus.failed
logger.exception(f"Job {job.id} failed.")
asyncio.run_coroutine_threadsafe(do(), self._loop)
def on_log_message_cb(self, job: Job, message: LogMessage) -> None:
pass
def on_status_change_cb(self, job: Job, status: JobStatus) -> None:
pass
def on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
pass
_BACKENDS = {
"naive": _NaiveSchedulerBackend,
}
def _get_backend_impl(backend: str) -> _SchedulerBackend:
try:
return _BACKENDS[backend]()
except KeyError as e:
raise ValueError(f"Unknown backend {backend}") from e
class Scheduler:
def __init__(self, backend: str = "naive"):
# TODO: if server crashes, job states are lost; we need to persist jobs on disc
self._jobs: dict[JobID, Job] = {}
self._backend = _get_backend_impl(backend)
def _on_log_message_cb(self, job: Job, message: str) -> None:
msg = (datetime.now(timezone.utc), message)
# At least for the time being, until there's a better way to expose
# logs to users, log messages on console
logger.info(f"Job {job.id}: {message}")
job.append_log(msg)
self._backend.on_log_message_cb(job, msg)
def _on_status_change_cb(self, job: Job, status: JobStatus) -> None:
job.status = status
self._backend.on_status_change_cb(job, status)
def _on_artifact_collected_cb(self, job: Job, artifact: JobArtifact) -> None:
job.register_artifact(artifact)
self._backend.on_artifact_collected_cb(job, artifact)
def schedule(self, type_: JobType, job_id: JobID, handler: JobHandler) -> JobID:
job = Job(type_, job_id, handler)
if job.id in self._jobs:
raise ValueError(f"Job {job.id} already exists")
self._jobs[job.id] = job
job.status = JobStatus.scheduled
self._backend.schedule(
job,
functools.partial(self._on_log_message_cb, job),
functools.partial(self._on_status_change_cb, job),
functools.partial(self._on_artifact_collected_cb, job),
)
return job.id
def cancel(self, job_id: JobID) -> None:
self.get_job(job_id).cancel()
def get_job(self, job_id: JobID) -> Job:
try:
return self._jobs[job_id]
except KeyError as e:
raise ValueError(f"Job {job_id} not found") from e
def get_jobs(self, type_: JobType | None = None) -> list[Job]:
jobs = list(self._jobs.values())
if type_:
jobs = [job for job in jobs if job._type == type_]
return jobs
async def shutdown(self):
# TODO: also cancel jobs once implemented
await self._backend.shutdown()