This commit is contained in:
Honglin Cao 2025-01-21 17:02:28 -05:00
parent e20228a7ca
commit 6c1b1722b4
2 changed files with 35 additions and 11 deletions

View file

@@ -64,7 +64,9 @@ MODEL_ALIASES = [
 ]
-class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class CentMLInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -143,7 +145,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
         async def _to_async_generator():
@@ -152,7 +156,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
                 yield chunk
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(
+            stream, self.formatter
+        ):
             yield chunk
     #
@@ -242,12 +248,15 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
             input_dict["messages"] = [
-                await convert_message_to_openai_dict(m) for m in request.messages
+                await convert_message_to_openai_dict(m)
+                for m in request.messages
             ]
         else:
             # Non-chat (CompletionRequest)
-            assert not media_present, "CentML does not support media for completions"
+            assert not media_present, (
+                "CentML does not support media for completions"
+            )
             input_dict["prompt"] = await completion_request_to_prompt(
                 request, self.formatter
             )
@@ -256,7 +265,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format
+            ),
         }
     def _build_options(
@@ -277,7 +288,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
                     "schema": fmt.json_schema,
                 }
             elif fmt.type == ResponseFormatType.grammar.value:
-                raise NotImplementedError("Grammar response format not supported yet")
+                raise NotImplementedError(
+                    "Grammar response format not supported yet"
+                )
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
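
The hunks above are line-length reformatting only; the behavior they touch centers on the streaming path, where a blocking client call is wrapped in a local async generator and its chunks are forwarded through process_completion_stream_response. Below is a minimal, self-contained sketch of that wrap-and-forward pattern, assuming a hypothetical fake_sync_stream stand-in rather than the real CentML client or llama_stack helpers.

import asyncio
from typing import AsyncGenerator, Iterable


def fake_sync_stream(prompt: str) -> Iterable[str]:
    # Hypothetical stand-in for a blocking SDK call that yields text chunks.
    for token in prompt.split():
        yield token


async def stream_completion(prompt: str) -> AsyncGenerator[str, None]:
    # Wrap the synchronous iterator in a local async generator, mirroring
    # the _to_async_generator() pattern visible in the diff above.
    async def _to_async_generator():
        for chunk in fake_sync_stream(prompt):
            yield chunk

    async for chunk in _to_async_generator():
        yield chunk


async def main() -> None:
    async for chunk in stream_completion("streamed completion chunks"):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())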

View file

@@ -13,12 +13,17 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
-from llama_stack.providers.remote.inference.centml.config import CentMLImplConfig
+from llama_stack.providers.remote.inference.centml.config import (
+    CentMLImplConfig,
+)
 # If your CentML adapter has a MODEL_ALIASES constant with known model mappings:
 from llama_stack.providers.remote.inference.centml.centml import MODEL_ALIASES
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+)
 def get_distribution_template() -> DistributionTemplate:
@@ -33,7 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
         "telemetry": ["inline::meta-reference"],
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
-        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "scoring": [
+            "inline::basic",
+            "inline::llm-as-judge",
+            "inline::braintrust",
+        ],
     }
     name = "centml"
@@ -94,7 +103,9 @@ def get_distribution_template() -> DistributionTemplate:
                     "memory": [memory_provider],
                 },
                 default_models=default_models + [embedding_model],
-                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+                default_shields=[
+                    ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
+                ],
             ),
         },
         run_config_env_vars={
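
The template hunks likewise only rewrap long lines in the providers mapping and the default shields. For orientation, that mapping is a plain dict of API name to provider-type strings of the form "<scope>::<name>"; the sketch below copies the keys and values visible in the diff, while the validate() helper is a hypothetical illustration and not part of llama_stack.

from typing import Dict, List

# Provider mapping as shown in the reformatted hunk above.
providers: Dict[str, List[str]] = {
    "telemetry": ["inline::meta-reference"],
    "eval": ["inline::meta-reference"],
    "datasetio": ["remote::huggingface", "inline::localfs"],
    "scoring": [
        "inline::basic",
        "inline::llm-as-judge",
        "inline::braintrust",
    ],
}


def validate(mapping: Dict[str, List[str]]) -> None:
    # Illustrative check only: every provider type should look like
    # "<inline|remote>::<name>".
    for api, provider_types in mapping.items():
        for provider_type in provider_types:
            scope, _, name = provider_type.partition("::")
            assert scope in ("inline", "remote") and name, (
                f"malformed provider type {provider_type!r} for api {api!r}"
            )


if __name__ == "__main__":
    validate(providers)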