Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)
fix top k, add in comments
This commit is contained in:
parent 941d5f1b18
commit d1f67d90ca
1 changed file with 25 additions and 23 deletions
```diff
@@ -64,9 +64,8 @@ MODEL_ALIASES = [
 ]
 
 
-class CentMLInferenceAdapter(
-    ModelRegistryHelper, Inference, NeedsRequestProviderData
-):
+class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
+                             NeedsRequestProviderData):
     """
     Adapter to use CentML's serverless inference endpoints,
     which adhere to the OpenAI chat/completions API spec,
```
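For context, the docstring above means the backend the adapter wraps speaks the OpenAI chat/completions API, so any OpenAI-compatible client can exercise it directly. A minimal sketch, assuming a hypothetical CentML base URL and an illustrative model id (none of these values appear in this diff):

```python
# Minimal sketch, not part of the commit: calling an OpenAI-compatible
# endpoint with the `openai` client. Base URL, key, and model id are
# placeholders/assumptions, not values taken from this diff.
from openai import OpenAI

client = OpenAI(
    base_url="https://api.centml.example/v1",  # hypothetical endpoint
    api_key="YOUR_CENTML_API_KEY",             # placeholder credential
)

resp = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.choices[0].message.content)
```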
```diff
@@ -138,16 +137,14 @@ class CentMLInferenceAdapter(
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-        self, request: CompletionRequest
-    ) -> ChatCompletionResponse:
+            self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
         return process_completion_response(response)
 
-    async def _stream_completion(
-        self, request: CompletionRequest
-    ) -> AsyncGenerator:
+    async def _stream_completion(self,
+                                 request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
```
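The comment in the non-streaming path points at the legacy `completions` route, which takes a raw `prompt` rather than a `messages` list. A hedged usage sketch, with the same placeholder endpoint and model as above:

```python
# Sketch of the legacy "completions" route used by _nonstream_completion:
# a raw prompt in, a single text choice out. All identifiers below are
# placeholders, not values from the diff.
from openai import OpenAI

client = OpenAI(base_url="https://api.centml.example/v1",  # hypothetical
                api_key="YOUR_CENTML_API_KEY")

resp = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    prompt="Once upon a time",
    stream=False,
)
print(resp.choices[0].text)
```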
```diff
@@ -156,9 +153,7 @@ class CentMLInferenceAdapter(
             yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(
-            stream
-        ):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk
 
     #
```
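`_to_async_generator` exists because the underlying SDK call returns a synchronous iterator of chunks while the adapter's streaming API is async. A self-contained sketch of that wrapping pattern (the names here are illustrative, not the adapter's):

```python
# Standalone illustration of the _to_async_generator pattern: re-expose a
# blocking iterator of chunks as an async generator so `async for` works.
import asyncio

def sync_chunks():
    # Stand-in for client.completions.create(..., stream=True).
    yield from ["Hel", "lo, ", "world\n"]

async def to_async_generator(sync_iter):
    # Simple wrapper: drives the synchronous iterator from inside the
    # event loop, mirroring the adapter's inner helper.
    for chunk in sync_iter:
        yield chunk

async def main():
    async for chunk in to_async_generator(sync_chunks()):
        print(chunk, end="")

asyncio.run(main())
```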
```diff
@@ -200,8 +195,7 @@ class CentMLInferenceAdapter(
         return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
+            self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         # For chat requests, if "messages" is in params -> .chat.completions
```
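The comment above describes a simple dispatch rule: params containing `messages` go to `.chat.completions`, params containing `prompt` go to the legacy route. A compact sketch of that rule (a hypothetical helper, not the adapter's code):

```python
# Hypothetical helper illustrating the dispatch described in the comment:
# "messages" in params selects the chat route, otherwise the legacy
# completions route is used. `client` is any OpenAI-compatible client.
def create_completion(client, params: dict):
    if "messages" in params:
        return client.chat.completions.create(**params)
    return client.completions.create(**params)
```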
```diff
@@ -214,8 +208,7 @@ class CentMLInferenceAdapter(
         return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
+            self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
```
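For the streaming chat path, the server emits incremental chunks whose deltas carry the generated text. A hedged client-side sketch (placeholder endpoint and model id again):

```python
# Usage sketch for streaming chat completions against an OpenAI-compatible
# server; endpoint, key, and model id are placeholders.
from openai import OpenAI

client = OpenAI(base_url="https://api.centml.example/v1",  # hypothetical
                api_key="YOUR_CENTML_API_KEY")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    messages=[{"role": "user", "content": "Count to three."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # the final chunk's delta may carry no content
        print(delta, end="")
```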
```diff
@@ -228,29 +221,37 @@ class CentMLInferenceAdapter(
 
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
                 stream, request):
             yield chunk
 
     #
     # HELPER METHODS
     #
 
-    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
+    async def _get_params(
+            self, request: Union[ChatCompletionRequest,
+                                 CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
-                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
+                input_dict["messages"] = [
+                    await convert_message_to_openai_dict(m)
+                    for m in request.messages
+                ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
+                    request, llama_model)
         else:
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
         params = {
-            "model": request.model,
+            "model":
+            request.model,
             **input_dict,
-            "stream": request.stream,
+            "stream":
+            request.stream,
             **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
         logcat.debug("inference", f"params to centml: {params}")
```
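`_get_params` is the heart of the adapter: chat requests carrying media (or targeting a model with no registered Llama alias) are sent as OpenAI-style `messages`, everything else is flattened to a single `prompt`. A standalone sketch of that branching, with llama-stack request objects replaced by plain arguments so it runs on its own:

```python
# Self-contained sketch of the _get_params branching; llama-stack types
# are replaced with plain values, so this mirrors, but is not, the code.
def build_params(model: str, stream: bool, is_chat: bool,
                 media_present: bool, llama_model: str | None,
                 messages: list | None = None, prompt: str = "") -> dict:
    input_dict = {}
    if is_chat and (media_present or not llama_model):
        # Media (or an unregistered model) forces OpenAI-style messages.
        input_dict["messages"] = messages or []
    else:
        # Otherwise the request is rendered down to a raw prompt string.
        input_dict["prompt"] = prompt
    return {"model": model, **input_dict, "stream": stream}

print(build_params("llama-3.1-8b", False, True, False, None,
                   messages=[{"role": "user", "content": "hi"}]))
```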
```diff
@@ -267,6 +268,8 @@ class CentMLInferenceAdapter(
         if fmt.type == ResponseFormatType.json_schema.value:
             options["response_format"] = {
                 "type": "json_object",
+                # CentML currently does not support guided decoding,
+                # the following setting is currently ignored by the server.
                 "schema": fmt.json_schema,
             }
         elif fmt.type == ResponseFormatType.grammar.value:
```
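The two added lines are the "add in comments" half of this commit: the schema stays attached to the request, but the server currently ignores it since CentML has no guided decoding. An illustrative shape of the option being built (the schema contents below are invented for the example):

```python
# Illustrative response_format payload as assembled by _build_options;
# the schema is made up for this example. Per the new comments in the
# diff, the "schema" key is currently ignored by the CentML server.
example_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

options = {
    "response_format": {
        "type": "json_object",
        "schema": example_schema,  # accepted but not enforced server-side
    }
}
print(options)
```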
```diff
@@ -295,8 +298,7 @@ class CentMLInferenceAdapter(
         model = await self.model_store.get_model(model_id)
         # CentML does not support media for embeddings.
         assert all(not content_has_media(c) for c in contents), (
-            "CentML does not support media for embeddings"
-        )
+            "CentML does not support media for embeddings")
 
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
```
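The embeddings path asserts its inputs are text-only before calling the OpenAI-compatible embeddings route. A hedged usage sketch (endpoint and model id are placeholders):

```python
# Usage sketch for the embeddings route; all identifiers are placeholders.
from openai import OpenAI

client = OpenAI(base_url="https://api.centml.example/v1",  # hypothetical
                api_key="YOUR_CENTML_API_KEY")

resp = client.embeddings.create(
    model="BAAI/bge-large-en-v1.5",  # illustrative embedding model id
    input=["first document", "second document"],  # text only, no media
)
print(len(resp.data), len(resp.data[0].embedding))
```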