mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 12:42:17 +00:00
Update Strategy in SamplingParams to be a union
This commit is contained in:
parent
300e6e2702
commit
dea575c994
28 changed files with 600 additions and 377 deletions
|
|
@ -9,6 +9,7 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
from llama_models.datatypes import CoreModelId
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import TopKSamplingStrategy
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
|
@ -172,7 +173,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
if request.sampling_params and request.sampling_params.top_k:
|
||||
if request.sampling_params and isinstance(
|
||||
request.sampling_params.strategy, TopKSamplingStrategy
|
||||
):
|
||||
raise ValueError("`top_k` not supported by Cerebras")
|
||||
|
||||
prompt = ""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue