Honglin Cao 2025-01-21 17:02:28 -05:00
parent e20228a7ca
commit 6c1b1722b4
2 changed files with 35 additions and 11 deletions


@@ -64,7 +64,9 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class CentMLInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -143,7 +145,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -152,7 +156,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(
+            stream, self.formatter
+        ):
             yield chunk
 
     #
@@ -242,12 +248,15 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
             input_dict["messages"] = [
-                await convert_message_to_openai_dict(m) for m in request.messages
+                await convert_message_to_openai_dict(m)
+                for m in request.messages
             ]
         else:
             # Non-chat (CompletionRequest)
-            assert not media_present, "CentML does not support media for completions"
+            assert not media_present, (
+                "CentML does not support media for completions"
+            )
             input_dict["prompt"] = await completion_request_to_prompt(
                 request, self.formatter
             )
 
@@ -256,7 +265,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format
+            ),
         }
 
     def _build_options(
@@ -277,7 +288,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
                     "schema": fmt.json_schema,
                 }
             elif fmt.type == ResponseFormatType.grammar.value:
-                raise NotImplementedError("Grammar response format not supported yet")
+                raise NotImplementedError(
+                    "Grammar response format not supported yet"
+                )
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 


@@ -13,12 +13,17 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
-from llama_stack.providers.remote.inference.centml.config import CentMLImplConfig
+from llama_stack.providers.remote.inference.centml.config import (
+    CentMLImplConfig,
+)
 
 # If your CentML adapter has a MODEL_ALIASES constant with known model mappings:
 from llama_stack.providers.remote.inference.centml.centml import MODEL_ALIASES
 
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+)
 
 
 def get_distribution_template() -> DistributionTemplate:
@@ -33,7 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
         "telemetry": ["inline::meta-reference"],
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
-        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "scoring": [
+            "inline::basic",
+            "inline::llm-as-judge",
+            "inline::braintrust",
+        ],
     }
 
     name = "centml"
@@ -94,7 +103,9 @@ def get_distribution_template() -> DistributionTemplate:
                     "memory": [memory_provider],
                 },
                 default_models=default_models + [embedding_model],
-                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+                default_shields=[
+                    ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
+                ],
             ),
         },
         run_config_env_vars={