mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Code refactoring and removing dead code
This commit is contained in:
parent
ef0736527d
commit
f6080040da
6 changed files with 302 additions and 137 deletions
|
@ -9,12 +9,10 @@ from typing import Any
|
|||
|
||||
from ibm_watsonx_ai.foundation_models import Model
|
||||
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
GreedySamplingStrategy,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
|
@ -48,6 +46,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
completion_request_to_prompt,
|
||||
request_has_media,
|
||||
)
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from . import WatsonXConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
@ -85,7 +84,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
pass
|
||||
|
||||
def _get_client(self, model_id) -> Model:
|
||||
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||
config_api_key = (
|
||||
self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||
)
|
||||
config_url = self._config.url
|
||||
project_id = self._config.project_id
|
||||
credentials = {"url": config_url, "apikey": config_api_key}
|
||||
|
@ -132,14 +133,18 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = self._get_client(request.model).generate(**params)
|
||||
choices = []
|
||||
if "results" in r:
|
||||
for result in r["results"]:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||
finish_reason=(
|
||||
result["stop_reason"] if result["stop_reason"] else None
|
||||
),
|
||||
text=result["generated_text"],
|
||||
)
|
||||
choices.append(choice)
|
||||
|
@ -148,7 +153,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
return process_chat_completion_response(response, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
model_id = request.model
|
||||
|
||||
|
@ -168,28 +175,44 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {"params": {}}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||
input_dict["prompt"] = await chat_completion_request_to_prompt(
|
||||
request, llama_model
|
||||
)
|
||||
else:
|
||||
assert not media_present, "Together does not support media for Completion requests"
|
||||
assert (
|
||||
not media_present
|
||||
), "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||
if request.sampling_params:
|
||||
if request.sampling_params.strategy:
|
||||
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
|
||||
input_dict["params"][
|
||||
GenParams.DECODING_METHOD
|
||||
] = request.sampling_params.strategy.type
|
||||
if request.sampling_params.max_tokens:
|
||||
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
|
||||
input_dict["params"][
|
||||
GenParams.MAX_NEW_TOKENS
|
||||
] = request.sampling_params.max_tokens
|
||||
if request.sampling_params.repetition_penalty:
|
||||
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
|
||||
input_dict["params"][
|
||||
GenParams.REPETITION_PENALTY
|
||||
] = request.sampling_params.repetition_penalty
|
||||
|
||||
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
|
||||
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
|
||||
input_dict["params"][
|
||||
GenParams.TOP_P
|
||||
] = request.sampling_params.strategy.top_p
|
||||
input_dict["params"][
|
||||
GenParams.TEMPERATURE
|
||||
] = request.sampling_params.strategy.temperature
|
||||
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
|
||||
input_dict["params"][
|
||||
GenParams.TOP_K
|
||||
] = request.sampling_params.strategy.top_k
|
||||
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
|
||||
input_dict["params"][GenParams.TEMPERATURE] = 0.0
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue