Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 04:50:39 +00:00
ruff fix and format
Commit e20228a7ca (parent dc1ff40413)
2 changed files with 26 additions and 26 deletions
@@ -8,9 +8,11 @@ from pydantic import BaseModel
 
 from .config import CentMLImplConfig
 
+
 class CentMLProviderDataValidator(BaseModel):
     centml_api_key: str
 
+
 async def get_adapter_impl(config: CentMLImplConfig, _deps):
     """
     Factory function to construct and initialize the CentML adapter.
@@ -21,9 +23,9 @@ async def get_adapter_impl(config: CentMLImplConfig, _deps):
     from .centml import CentMLInferenceAdapter
 
     # Ensure the provided config is indeed a CentMLImplConfig
-    assert isinstance(
-        config, CentMLImplConfig
-    ), f"Unexpected config type: {type(config)}"
+    assert isinstance(config, CentMLImplConfig), (
+        f"Unexpected config type: {type(config)}"
+    )
 
     # Instantiate and initialize the adapter
     adapter = CentMLInferenceAdapter(config)
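The only edit of note in the hunks above is cosmetic: ruff format folds the isinstance(...) call back onto one line and wraps the assert message in parentheses instead. A minimal sketch with generic names (not from the repo) showing that the two spellings behave identically:

# Both forms raise the same AssertionError when the check fails; only the
# line wrapping differs. The names here are illustrative, not from llama-stack.
value = {"api_key": "dummy"}

# old wrapping: the isinstance(...) call is split across lines
assert isinstance(
    value, dict
), f"Unexpected config type: {type(value)}"

# new wrapping (ruff format): the message expression is parenthesized instead
assert isinstance(value, dict), (
    f"Unexpected config type: {type(value)}"
)

The remaining hunks below belong to the adapter module itself, the file that defines CentMLInferenceAdapter.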
@@ -42,7 +42,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
     completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
@@ -65,8 +64,7 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
-                             NeedsRequestProviderData):
+class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
@@ -138,14 +136,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-            self, request: CompletionRequest) -> ChatCompletionResponse:
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
         return process_completion_response(response, self.formatter)
 
-    async def _stream_completion(self,
-                                 request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)
 
        async def _to_async_generator():
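The signature edits in the hunk above follow the formatter's usual rule for def headers: a header that fits within the configured line length (88 columns by default; this repo's setting is not shown here, so that is an assumption) is joined onto one line, while a longer one gets its parameters on an indented line with the closing parenthesis and return annotation on a line of their own. A generic, self-contained sketch of both outcomes:

from typing import Any

# Fits in the line length: ruff format keeps the whole header on one line,
# as happened to _stream_completion above.
async def short_header(self: Any, request: Any) -> Any:
    ...

# Does not fit: parameters move to their own indented line, and the closing
# parenthesis plus the return annotation get a separate line, as happened to
# _nonstream_completion above. (Function and parameter names here are made up.)
async def long_header_with_a_deliberately_verbose_name(
    self: Any, request_argument_with_a_long_name: Any
) -> Any:
    ...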
@@ -154,8 +152,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(
-                stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream, self.formatter):
             yield chunk
 
     #
@@ -195,7 +192,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-            self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         # For chat requests, if "messages" is in params -> .chat.completions
@@ -208,7 +206,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return process_chat_completion_response(response, self.formatter)
 
     async def _stream_chat_completion(
-            self, request: ChatCompletionRequest) -> AsyncGenerator:
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -221,7 +220,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
 
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-                stream, self.formatter):
+            stream, self.formatter
+        ):
             yield chunk
 
     #
@@ -229,8 +229,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     #
 
     async def _get_params(
-            self, request: Union[ChatCompletionRequest,
-                                 CompletionRequest]) -> dict:
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
         """
         Build the 'params' dict that the OpenAI (CentML) client expects.
         For chat requests, we always prefer "messages" so that it calls
@@ -242,22 +242,20 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         if isinstance(request, ChatCompletionRequest):
             # For chat requests, always build "messages" from the user messages
             input_dict["messages"] = [
-                await convert_message_to_openai_dict(m)
-                for m in request.messages
+                await convert_message_to_openai_dict(m) for m in request.messages
             ]
 
         else:
             # Non-chat (CompletionRequest)
             assert not media_present, "CentML does not support media for completions"
             input_dict["prompt"] = await completion_request_to_prompt(
-                request, self.formatter)
+                request, self.formatter
+            )
 
         return {
-            "model":
-            request.model,
+            "model": request.model,
             **input_dict,
-            "stream":
-            request.stream,
+            "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
         }
 
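For orientation, the return value assembled in the hunk above has roughly this shape for a chat request. The model name, message, and the extra keys contributed by _build_options are illustrative assumptions, not values taken from the repository:

# Sketch of the params dict _get_params builds for a streaming chat request,
# based on the keys visible in the diff above. Example values only.
params = {
    "model": "example-model-id",                      # request.model
    "messages": [{"role": "user", "content": "Hi"}],  # via convert_message_to_openai_dict
    "stream": True,                                   # request.stream
    # ...plus whatever _build_options(request.sampling_params,
    # request.response_format) adds, e.g. sampling options.
}

For a plain CompletionRequest the "messages" key is replaced by a "prompt" string built by completion_request_to_prompt, as the else branch above shows.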
@@ -279,8 +277,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                 "schema": fmt.json_schema,
             }
         elif fmt.type == ResponseFormatType.grammar.value:
-            raise NotImplementedError(
-                "Grammar response format not supported yet")
+            raise NotImplementedError("Grammar response format not supported yet")
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
@@ -297,8 +294,9 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         # CentML does not support media
-        assert all(not content_has_media(c) for c in contents), \
-            "CentML does not support media for embeddings"
+        assert all(not content_has_media(c) for c in contents), (
+            "CentML does not support media for embeddings"
+        )
 
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
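The embeddings hunk shows the other recurring cleanup in this commit: the backslash line continuation after the assert is replaced by a parenthesized message. A generic sketch of the pattern (names are illustrative, not from the repo):

items = ["a", "b"]

# old: a backslash continuation carries the assert message to the next line
assert all(isinstance(i, str) for i in items), \
    "all items must be strings"

# new (ruff format): the message is parenthesized, so no backslash is needed
assert all(isinstance(i, str) for i in items), (
    "all items must be strings"
)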