mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 13:00:52 +00:00
chore: remove /v1/inference/completion and implementations (#3622)
# What does this PR do? the /inference/completion route is gone. this removes the implementations. ## Test Plan ci
This commit is contained in:
parent
ea15f2a270
commit
f7c5ef4ec0
75 changed files with 16141 additions and 17056 deletions
|
@ -8,14 +8,9 @@ from collections.abc import AsyncGenerator
|
|||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -37,13 +32,10 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -94,79 +86,6 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
|
|||
return prompt[len("<|begin_of_text|>") :]
|
||||
return prompt
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
# Wrapper for async generator similar
|
||||
async def _to_async_generator():
|
||||
stream = self._get_client().completion.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: SamplingParams | None,
|
||||
fmt: ResponseFormat,
|
||||
logprobs: LogProbConfig | None,
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if logprobs and logprobs.top_k:
|
||||
options["logprobs"] = logprobs.top_k
|
||||
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
|
||||
raise ValueError("Required range: 0 < top_k < 5")
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -222,22 +141,46 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
|
|||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: SamplingParams | None,
|
||||
fmt: ResponseFormat | None,
|
||||
logprobs: LogProbConfig | None,
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
if logprobs and logprobs.top_k:
|
||||
options["logprobs"] = logprobs.top_k
|
||||
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
|
||||
raise ValueError("Required range: 0 < top_k < 5")
|
||||
|
||||
return options
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO: tools are never added to the request, so we need to add them here
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict(m, download=True) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
# TODO: tools are never added to the request, so we need to add them here
|
||||
if media_present or not llama_model:
|
||||
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
|
||||
else:
|
||||
assert not media_present, "Fireworks does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if "prompt" in input_dict:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue