Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)
ruff fix
parent e20228a7ca
commit 6c1b1722b4

2 changed files with 35 additions and 11 deletions
First changed file: the CentML inference adapter.

@@ -64,7 +64,9 @@ MODEL_ALIASES = [
 ]


-class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
+class CentMLInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
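For orientation, the class being re-wrapped here is the provider adapter itself; a minimal skeleton of the same shape is sketched below. The base-class import paths are assumptions based on where upstream llama-stack keeps these helpers and may differ in this fork.

# Sketch only: the import locations below are assumed, not taken from this diff.
from llama_stack.apis.inference import Inference
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper


class CentMLInferenceAdapter(
    ModelRegistryHelper, Inference, NeedsRequestProviderData
):
    """
    Adapter to use CentML's serverless inference endpoints,
    which adhere to the OpenAI chat/completions API spec.
    """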
@@ -143,7 +145,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(
+        self, request: CompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)

         async def _to_async_generator():
@@ -152,7 +156,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
                 yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(
+            stream, self.formatter
+        ):
             yield chunk

     #
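The two hunks above re-wrap `_stream_completion`, whose core trick is re-yielding a synchronous SDK stream from an async generator before handing it to `process_completion_stream_response`. A self-contained sketch of that pattern, with a fake stream standing in for the real CentML client:

import asyncio
from typing import AsyncGenerator, Iterable


def fake_sync_stream() -> Iterable[str]:
    # Stand-in for the blocking iterator the CentML client would return from
    # completions.create(stream=True); not the real SDK object.
    yield from ["Hel", "lo", "!"]


async def to_async_generator(chunks: Iterable[str]) -> AsyncGenerator[str, None]:
    # Same shape as the adapter's inner _to_async_generator helper: re-yield
    # each chunk of a synchronous iterator so callers can use `async for`.
    for chunk in chunks:
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator(fake_sync_stream()):
        print(chunk, end="")
    print()


asyncio.run(main())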
@@ -242,12 +248,15 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
             input_dict["messages"] = [
-                await convert_message_to_openai_dict(m) for m in request.messages
+                await convert_message_to_openai_dict(m)
+                for m in request.messages
             ]

         else:
             # Non-chat (CompletionRequest)
-            assert not media_present, "CentML does not support media for completions"
+            assert not media_present, (
+                "CentML does not support media for completions"
+            )
             input_dict["prompt"] = await completion_request_to_prompt(
                 request, self.formatter
             )
@@ -256,7 +265,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            **self._build_options(
+                request.sampling_params, request.response_format
+            ),
         }

     def _build_options(
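The two hunks above re-wrap parts of the request-to-params conversion: chat requests contribute a "messages" list, plain completion requests contribute a "prompt", and the result is merged with the model name, stream flag, and sampling options. A standalone sketch of that branching with stand-in request types (these are not the llama-stack classes, and the real adapter converts messages with convert_message_to_openai_dict):

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class FakeChatRequest:
    model: str
    messages: List[Dict[str, str]]
    stream: bool = False


@dataclass
class FakePromptRequest:
    model: str
    prompt: str
    stream: bool = False


def get_params(request: Union[FakeChatRequest, FakePromptRequest]) -> Dict[str, Any]:
    input_dict: Dict[str, Any] = {}
    if isinstance(request, FakeChatRequest):
        # Chat path: forward the messages as OpenAI-style dicts.
        input_dict["messages"] = list(request.messages)
    else:
        # Completion path: a single rendered prompt string.
        input_dict["prompt"] = request.prompt
    return {"model": request.model, **input_dict, "stream": request.stream}


print(get_params(FakePromptRequest(model="some-model", prompt="Hello")))
print(get_params(FakeChatRequest(model="some-model",
                                 messages=[{"role": "user", "content": "Hi"}])))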
@@ -277,7 +288,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvide
                     "schema": fmt.json_schema,
                 }
             elif fmt.type == ResponseFormatType.grammar.value:
-                raise NotImplementedError("Grammar response format not supported yet")
+                raise NotImplementedError(
+                    "Grammar response format not supported yet"
+                )
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")

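The hunk above is the response-format branch of `_build_options`: JSON-schema formats are passed through, grammar is not implemented, and anything else is rejected. A standalone sketch of that mapping; the "json_object" wrapper and the helper name are assumptions, while the schema pass-through, the NotImplementedError, and the ValueError come from the diff:

from enum import Enum
from typing import Any, Dict, Optional


class ResponseFormatType(Enum):
    json_schema = "json_schema"
    grammar = "grammar"


def build_response_format_options(fmt: Optional[Any]) -> Dict[str, Any]:
    # Helper name and exact payload shape are illustrative assumptions.
    if fmt is None:
        return {}
    if fmt.type == ResponseFormatType.json_schema.value:
        return {
            "response_format": {
                "type": "json_object",
                "schema": fmt.json_schema,
            }
        }
    elif fmt.type == ResponseFormatType.grammar.value:
        raise NotImplementedError("Grammar response format not supported yet")
    else:
        raise ValueError(f"Unknown response format {fmt.type}")


class _JsonSchemaFormat:
    # Minimal stand-in for a json_schema response format object.
    type = ResponseFormatType.json_schema.value
    json_schema = {"type": "object", "properties": {"answer": {"type": "string"}}}


print(build_response_format_options(_JsonSchemaFormat()))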
Second changed file: the CentML distribution template.

@@ -13,12 +13,17 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
-from llama_stack.providers.remote.inference.centml.config import CentMLImplConfig
+from llama_stack.providers.remote.inference.centml.config import (
+    CentMLImplConfig,
+)

 # If your CentML adapter has a MODEL_ALIASES constant with known model mappings:
 from llama_stack.providers.remote.inference.centml.centml import MODEL_ALIASES

-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+)


 def get_distribution_template() -> DistributionTemplate:
@@ -33,7 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
         "telemetry": ["inline::meta-reference"],
         "eval": ["inline::meta-reference"],
         "datasetio": ["remote::huggingface", "inline::localfs"],
-        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
+        "scoring": [
+            "inline::basic",
+            "inline::llm-as-judge",
+            "inline::braintrust",
+        ],
     }
     name = "centml"

@@ -94,7 +103,9 @@ def get_distribution_template() -> DistributionTemplate:
                 "memory": [memory_provider],
             },
             default_models=default_models + [embedding_model],
-            default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+            default_shields=[
+                ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
+            ],
         ),
     },
     run_config_env_vars={
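If this fork is installed, a quick smoke test of the template file touched above might look like the following; the module path is an assumption based on the name = "centml" convention for llama-stack templates, and the printed attributes are assumed fields of DistributionTemplate.

# Hypothetical smoke test; assumes the centml template module lives at the
# conventional path llama_stack/templates/centml/centml.py in this fork.
from llama_stack.templates.centml.centml import get_distribution_template

template = get_distribution_template()
print(template.name)                      # expected: "centml"
print(sorted(template.providers.keys()))  # inference, memory, scoring, ...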