Tests pass with Ollama now

Ashwin Bharambe 2024-12-15 17:31:21 -08:00
parent a9a041a1de
commit e51154964f
27 changed files with 83 additions and 65 deletions

View file

@@ -9,8 +9,6 @@ import logging
from typing import List
-from llama_models.llama3.api.datatypes import Message
from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)

View file

@@ -7,13 +7,17 @@
import logging
from typing import Any, Dict, List
-from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
-from llama_stack.apis.safety import * # noqa: F403
+from llama_stack.apis.inference import Message
+from llama_stack.providers.utils.inference.prompt_adapter import (
+interleaved_content_as_str,
+)
from .config import CodeScannerConfig
+from llama_stack.apis.safety import * # noqa: F403
log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner",
"CodeShield",
@@ -48,7 +52,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
from codeshield.cs import CodeShield
-text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
+text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
log.info(f"Running CodeScannerShield on {text[50:]}")
result = await CodeShield.scan_code(text)

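The code-scanner hunks above swap the llama_models helper for the llama_stack one while keeping the flatten-then-scan flow intact. A minimal sketch of that flow, assuming `interleaved_content_as_str` accepts plain-string content (the sample message is illustrative):

```python
# Sketch: flatten message content to plain text before handing it to
# CodeShield.scan_code. The sample message is illustrative only.
from llama_stack.apis.inference import UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)

messages = [UserMessage(content="import os\nos.system(input())")]
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
# `text` is what the shield scans: await CodeShield.scan_code(text)
```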
View file

@@ -10,7 +10,6 @@ from cerebras.cloud.sdk import AsyncCerebras
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403

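The same one-line cleanup repeats across the Cerebras, Databricks, Fireworks, Ollama, Together, and vLLM adapters below: `Message` appears to resolve through the `llama_stack.apis.inference` surface now, so the direct `llama_models` import is dropped. A hedged before/after sketch of the pattern:

```python
# Before: each adapter imported Message from the model package directly.
# from llama_models.llama3.api.datatypes import Message

# After: the datatypes come in through the llama_stack inference API,
# which these adapters already star-import.
from llama_stack.apis.inference import *  # noqa: F403

msg = UserMessage(content="hello")  # noqa: F405 - resolves via the star import
```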
View file

@@ -10,7 +10,6 @@ from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI

View file

@@ -10,7 +10,6 @@ from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData

View file

@@ -11,7 +11,6 @@ import httpx
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
@@ -90,7 +89,7 @@ model_aliases = [
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias_with_just_provider_model_id(
-"llama3.2-vision",
+"llama3.2-vision:latest",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(

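The alias fix reflects how Ollama names tags: a pulled `llama3.2-vision` is listed as `llama3.2-vision:latest`, and the alias table matches on the exact provider identifier. A small sketch of that exact-match behavior (the table and resolver are illustrative, not the provider's real code):

```python
# Illustrative alias table keyed by the provider's exact model id; Ollama
# reports pulled models with an explicit ":latest" tag, so the key must
# include it for lookups to succeed.
ALIASES = {
    "llama3.2-vision:latest": "llama3_2_11b_vision_instruct",
}

def resolve(provider_model_id: str) -> str:
    alias = ALIASES.get(provider_model_id)
    if alias is None:
        raise ValueError(f"unregistered provider model: {provider_model_id}")
    return alias

print(resolve("llama3.2-vision:latest"))  # -> llama3_2_11b_vision_instruct
# resolve("llama3.2-vision") would raise: no untagged key is registered
```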
View file

@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def completion(
self,
model_id: str,
-content: InterleavedTextMedia,
+content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@@ -267,7 +267,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedTextMedia],
+contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()

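Both `_HfAdapter` hunks are pure signature renames: `InterleavedTextMedia` becomes `InterleavedContent`. A signature-only sketch of the resulting method shapes, assuming these names are exported by `llama_stack.apis.inference` at this commit:

```python
from typing import List, Optional

from llama_stack.apis.inference import (  # assumed export surface
    EmbeddingsResponse,
    InterleavedContent,
    ResponseFormat,
    SamplingParams,
)

class SketchAdapter:
    """Signature-only sketch; bodies are elided like the adapter's own."""

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,  # was: InterleavedTextMedia
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
    ):
        raise NotImplementedError()

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],  # was: List[InterleavedTextMedia]
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
```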
View file

@@ -10,7 +10,6 @@ from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together

View file

@@ -8,7 +8,6 @@ import logging
from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models

View file

@@ -7,7 +7,6 @@
from pathlib import Path
import pytest
from PIL import Image as PIL_Image
-from llama_models.llama3.api.datatypes import * # noqa: F403
@@ -17,6 +16,9 @@ from .utils import group_chunks
THIS_DIR = Path(__file__).parent
+with open(THIS_DIR / "pasta.jpeg", "rb") as f:
+PASTA_IMAGE = f.read()
class TestVisionModelInference:
@pytest.mark.asyncio
@@ -24,12 +26,12 @@
"image, expected_strings",
[
(
-ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
+ImageContentItem(data=PASTA_IMAGE),
["spaghetti"],
),
(
-ImageMedia(
-image=URL(
+ImageContentItem(
+data=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@@ -58,7 +60,12 @@ class TestVisionModelInference:
model_id=inference_model,
messages=[
UserMessage(content="You are a helpful assistant."),
-UserMessage(content=[image, "Describe this image in two sentences."]),
+UserMessage(
+content=[
+image,
+TextContentItem(text="Describe this image in two sentences."),
+]
+),
],
stream=False,
sampling_params=SamplingParams(max_tokens=100),
@@ -89,8 +96,8 @@ class TestVisionModelInference:
)
images = [
-ImageMedia(
-image=URL(
+ImageContentItem(
+data=URL(
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
)
),
@@ -106,7 +113,12 @@ class TestVisionModelInference:
messages=[
UserMessage(content="You are a helpful assistant."),
UserMessage(
-content=[image, "Describe this image in two sentences."]
+content=[
+image,
+TextContentItem(
+text="Describe this image in two sentences."
+),
+]
),
],
stream=True,

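The test rewrite replaces `ImageMedia(image=...)` with `ImageContentItem(data=...)`, whose single `data` field carries either raw bytes or a `URL`, and turns bare strings inside mixed content into explicit `TextContentItem`s. A condensed sketch of the new message construction (import locations assumed from the prompt_adapter hunk further down; the URL is illustrative):

```python
from llama_stack.apis.common.deployment_types import URL
from llama_stack.apis.inference import (  # assumed exports at this commit
    ImageContentItem,
    TextContentItem,
    UserMessage,
)

# From raw bytes read off disk...
with open("pasta.jpeg", "rb") as f:
    image_from_bytes = ImageContentItem(data=f.read())

# ...or from a remote URL; the same field carries both.
image_from_url = ImageContentItem(data=URL(uri="https://example.com/dog.jpg"))

# Bare strings in mixed content become explicit TextContentItems.
message = UserMessage(
    content=[
        image_from_url,
        TextContentItem(text="Describe this image in two sentences."),
    ]
)
```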
View file

@@ -7,8 +7,8 @@
import pytest
import pytest_asyncio
-from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.type_system import * # noqa: F403
+from llama_stack.apis.common.deployment_types import URL
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput

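Here and in the two URL hunks below, `URL` moves from `llama_models.llama3.api.datatypes` to `llama_stack.apis.common.deployment_types`; nothing about the type's use changes. A one-liner sketch (the URI is illustrative):

```python
from llama_stack.apis.common.deployment_types import URL  # new home

dataset_url = URL(uri="https://example.com/datasets/eval.csv")
```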
View file

@@ -10,7 +10,7 @@ from urllib.parse import unquote
import pandas
-from llama_models.llama3.api.datatypes import URL
+from llama_stack.apis.common.deployment_types import URL
from llama_stack.providers.utils.memory.vector_store import parse_data_url

View file

@@ -7,9 +7,11 @@
import logging
from typing import List
-from llama_models.llama3.api.datatypes import InterleavedTextMedia
-from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
+from llama_stack.apis.inference import (
+EmbeddingsResponse,
+InterleavedContent,
+ModelStore,
+)
EMBEDDING_MODELS = {}
@@ -23,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
async def embeddings(
self,
model_id: str,
-contents: List[InterleavedTextMedia],
+contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(

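With the signature now `List[InterleavedContent]`, the mixin needs each entry flattened to text before encoding. A plausible flatten-then-encode pass; the helper call and the model name are assumptions, not shown in this hunk:

```python
# Hedged sketch of a flatten-then-encode pass; the helper call and the
# model name are assumptions, not part of the hunk above.
from sentence_transformers import SentenceTransformer

from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model
contents = ["first document", "second document"]  # plain strings are valid content
vectors = model.encode([interleaved_content_as_str(c) for c in contents])
print(vectors.shape)  # (2, embedding_dim)
```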
View file

@@ -93,11 +93,15 @@ def process_chat_completion_response(
) -> ChatCompletionResponse:
choice = response.choices[0]
-completion_message = formatter.decode_assistant_message_from_content(
+raw_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)
return ChatCompletionResponse(
-completion_message=completion_message,
+completion_message=CompletionMessage(
+content=raw_message.content,
+stop_reason=raw_message.stop_reason,
+tool_calls=raw_message.tool_calls,
+),
logprobs=None,
)

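The OpenAI-compat path stops passing the decoded formatter message straight through and instead copies its fields into a fresh `CompletionMessage`, so the llama_stack response type no longer aliases the `llama_models` decoder output. A sketch of just that conversion (`raw_message` stands in for whatever the formatter returns):

```python
from llama_stack.apis.inference import CompletionMessage

def rebuild(raw_message) -> CompletionMessage:
    # field-by-field copy so the API response owns its own type rather
    # than aliasing the formatter's message class
    return CompletionMessage(
        content=raw_message.content,
        stop_reason=raw_message.stop_reason,
        tool_calls=raw_message.tool_calls,
    )
```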
View file

@@ -6,6 +6,7 @@
import asyncio
import base64
+import io
import json
import logging
import re
@@ -21,7 +22,6 @@ from llama_models.llama3.api.datatypes import (
RawMediaItem,
RawTextItem,
Role,
-ToolChoice,
ToolPromptFormat,
)
from llama_models.llama3.prompt_templates import (
@@ -47,6 +47,7 @@ from llama_stack.apis.inference import (
ResponseFormatType,
SystemMessage,
TextContentItem,
+ToolChoice,
UserMessage,
)
@@ -136,7 +137,7 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
if isinstance(media.data, URL) and media.data.uri.startswith("http"):
async with httpx.AsyncClient() as client:
-r = await client.get(media.image.uri)
+r = await client.get(media.data.uri)
content = r.content
content_type = r.headers.get("content-type")
if content_type:
@@ -145,7 +146,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
format = "png"
return content, format
else:
-image = PIL_Image.open(media.data)
+image = PIL_Image.open(io.BytesIO(media.data))
return media.data, image.format
@@ -153,7 +154,7 @@ async def convert_image_content_to_url(
media: ImageContentItem, download: bool = False, include_format: bool = True
) -> str:
if isinstance(media.data, URL) and not download:
-return media.image.uri
+return media.data.uri
content, format = await localize_image_content(media)
if include_format:

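Two genuine bug fixes hide in this file alongside the import shuffle: `media.image` never existed on `ImageContentItem` (the field is `data`), and `PIL_Image.open` needs a path or file-like object, so raw bytes must be wrapped in `io.BytesIO`. A sketch of the corrected local branch, assuming the bytes case:

```python
import io

from PIL import Image as PIL_Image

def format_of(data: bytes) -> str:
    # PIL_Image.open(data) would fail on raw bytes; BytesIO wraps them
    # in the file-like interface PIL expects.
    image = PIL_Image.open(io.BytesIO(data))
    return image.format or "png"  # e.g. "JPEG" for a .jpeg payload
```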
View file

@@ -8,7 +8,7 @@ import base64
import mimetypes
import os
-from llama_models.llama3.api.datatypes import URL
+from llama_stack.apis.common.deployment_types import URL
def data_url_from_file(file_path: str) -> URL:
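The final hunk cuts off after the signature; given the `base64` and `mimetypes` imports above it, a minimal sketch of what such a helper conventionally does (not the verbatim body):

```python
import base64
import mimetypes

from llama_stack.apis.common.deployment_types import URL

def data_url_from_file(file_path: str) -> URL:
    with open(file_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)  # may be None for unknown types
    return URL(uri=f"data:{mime_type};base64,{encoded}")
```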