Memory tests pass now

Ashwin Bharambe 2024-12-15 20:55:06 -08:00
parent e51154964f
commit 59ce047aea
23 changed files with 122 additions and 81 deletions

@@ -9,21 +9,24 @@ import json
 import uuid
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    content_has_media,
+    interleaved_content_as_str,
+)
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 MODEL_ALIASES = [
@@ -64,7 +67,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -449,7 +452,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
@@ -457,7 +460,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             assert not content_has_media(
                 content
             ), "Bedrock does not support media for embeddings"
-            input_text = interleaved_text_media_as_str(content)
+            input_text = interleaved_content_as_str(content)
             input_body = {"inputText": input_text}
             body = json.dumps(input_body)
             response = self.client.invoke_model(
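Note: the embeddings path above now flattens the renamed `InterleavedContent` type to plain text before invoking Bedrock. A minimal sketch of that flattening step, using a simplified stand-in type rather than the real `interleaved_content_as_str` from `llama_stack.providers.utils.inference.prompt_adapter`:

```python
from typing import List, Union

# Simplified stand-in: assume content is either a plain string or a list of
# text items. The real InterleavedContent can also carry non-text items.
SimpleContent = Union[str, List[str]]


def content_as_str(content: SimpleContent) -> str:
    """Flatten interleaved content into the single text string Bedrock expects."""
    if isinstance(content, str):
        return content
    return " ".join(content)


# Each embeddings input would then become {"inputText": content_as_str(c)}.
print(content_as_str(["hello", "world"]))  # -> "hello world"
```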

@@ -69,7 +69,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -166,11 +166,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             raise ValueError("`top_k` not supported by Cerebras")
         prompt = ""
-        if type(request) == ChatCompletionRequest:
+        if isinstance(request, ChatCompletionRequest):
             prompt = chat_completion_request_to_prompt(
                 request, self.get_llama_model(request.model), self.formatter
             )
-        elif type(request) == CompletionRequest:
+        elif isinstance(request, CompletionRequest):
             prompt = completion_request_to_prompt(request, self.formatter)
         else:
             raise ValueError(f"Unknown request type {type(request)}")
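The `type(request) == ...` comparisons above are replaced with `isinstance`, the idiomatic check: `isinstance` also matches subclasses, while an exact `type` comparison does not. A standalone illustration (the subclass here is hypothetical, not from this codebase):

```python
class CompletionRequest:
    pass


class ToolCompletionRequest(CompletionRequest):  # hypothetical subclass
    pass


request = ToolCompletionRequest()
print(type(request) == CompletionRequest)      # False: exact-type comparison
print(isinstance(request, CompletionRequest))  # True: honors inheritance
```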
@@ -185,6 +185,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -62,7 +62,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -135,6 +135,6 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -8,14 +8,7 @@ import warnings
 from typing import AsyncIterator, List, Optional, Union
 from llama_models.datatypes import SamplingParams
-from llama_models.llama3.api.datatypes import (
-    ImageMedia,
-    InterleavedTextMedia,
-    Message,
-    ToolChoice,
-    ToolDefinition,
-    ToolPromptFormat,
-)
+from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
 from llama_models.sku_list import CoreModelId
 from openai import APIConnectionError, AsyncOpenAI
@@ -28,13 +21,17 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
+    InterleavedContent,
     LogProbConfig,
+    Message,
     ResponseFormat,
+    ToolChoice,
 )
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 from . import NVIDIAConfig
 from .openai_utils import (
@@ -123,17 +120,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
-        if isinstance(content, ImageMedia) or (
-            isinstance(content, list)
-            and any(isinstance(c, ImageMedia) for c in content)
-        ):
-            raise NotImplementedError("ImageMedia is not supported")
+        if content_has_media(content):
+            raise NotImplementedError("Media is not supported")
         await check_health(self._config)  # this raises errors
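The hand-rolled `ImageMedia` scan above is replaced by the shared `content_has_media` helper from `prompt_adapter`. A rough stand-in for what such a predicate checks, assuming text-only content is a string or a list of strings (the real helper inspects the concrete content-item types):

```python
from typing import Any


def has_media(content: Any) -> bool:
    """Return True if any part of the content is not plain text."""
    if isinstance(content, str):
        return False
    if isinstance(content, list):
        return any(not isinstance(item, str) for item in content)
    return True  # a single non-string item, e.g. an image


print(has_media("just text"))           # False
print(has_media(["text", b"\x89PNG"]))  # True: non-text item present
```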
@@ -165,7 +159,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()