Memory tests pass now

This commit is contained in:
Ashwin Bharambe 2024-12-15 20:55:06 -08:00
parent e51154964f
commit 59ce047aea
23 changed files with 122 additions and 81 deletions

View file

@ -69,7 +69,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -166,11 +166,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if type(request) == ChatCompletionRequest:
if isinstance(request, ChatCompletionRequest):
prompt = chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
elif type(request) == CompletionRequest:
elif isinstance(request, CompletionRequest):
prompt = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")
@ -185,6 +185,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()