mirror of https://github.com/meta-llama/llama-stack.git
Tests pass with Ollama now
commit e51154964f (parent a9a041a1de)
27 changed files with 83 additions and 65 deletions
@@ -7,9 +7,11 @@
 import logging
 from typing import List
 
-from llama_models.llama3.api.datatypes import InterleavedTextMedia
-
-from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
+from llama_stack.apis.inference import (
+    EmbeddingsResponse,
+    InterleavedContent,
+    ModelStore,
+)
 
 EMBEDDING_MODELS = {}
 
@@ -23,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(
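The two hunks above migrate the sentence-transformer embedding mixin from the old InterleavedTextMedia type to InterleavedContent and cut off right after the model is loaded. A minimal, self-contained sketch of the same embedding flow follows; the cache dict and the encode call are illustrative assumptions, not the file's remaining lines.

    # Hedged sketch of a sentence-transformers embedding helper; not the
    # repository's verbatim mixin. Assumes sentence-transformers is installed
    # and that contents have already been flattened to plain strings.
    from typing import Dict, List

    from sentence_transformers import SentenceTransformer

    _MODEL_CACHE: Dict[str, SentenceTransformer] = {}


    def load_sentence_transformer(model_name: str) -> SentenceTransformer:
        # Load each model once and keep it in memory for later calls.
        if model_name not in _MODEL_CACHE:
            _MODEL_CACHE[model_name] = SentenceTransformer(model_name)
        return _MODEL_CACHE[model_name]


    def embed_texts(model_name: str, texts: List[str]) -> List[List[float]]:
        model = load_sentence_transformer(model_name)
        # encode() returns one vector per input; convert to plain lists so the
        # result can go straight into an EmbeddingsResponse-style payload.
        return [vector.tolist() for vector in model.encode(texts)]

For example, embed_texts("all-MiniLM-L6-v2", ["hello"]) returns a single 384-dimensional vector.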
@@ -93,11 +93,15 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
-    completion_message = formatter.decode_assistant_message_from_content(
+    raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
     return ChatCompletionResponse(
-        completion_message=completion_message,
+        completion_message=CompletionMessage(
+            content=raw_message.content,
+            stop_reason=raw_message.stop_reason,
+            tool_calls=raw_message.tool_calls,
+        ),
         logprobs=None,
     )
 
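The hunk above stops handing the decoder's message object straight to ChatCompletionResponse and instead copies its fields into a CompletionMessage. The sketch below shows that adapter step in isolation, with stand-in dataclasses so it runs on its own; the real types come from llama_models and llama_stack and carry more fields.

    # Hedged sketch of the adapter pattern in the hunk above. The dataclasses
    # are stand-ins so the snippet is self-contained; only the three copied
    # fields (content, stop_reason, tool_calls) are taken from the diff.
    from dataclasses import dataclass, field
    from typing import Any, List


    @dataclass
    class RawAssistantMessage:  # stand-in for the decoder's return type
        content: str
        stop_reason: str
        tool_calls: List[Any] = field(default_factory=list)


    @dataclass
    class CompletionMessage:  # stand-in for the API-level message type
        content: str
        stop_reason: str
        tool_calls: List[Any]


    def to_completion_message(raw: RawAssistantMessage) -> CompletionMessage:
        # Copy only the fields the response needs instead of leaking the
        # decoder's own message type into ChatCompletionResponse.
        return CompletionMessage(
            content=raw.content,
            stop_reason=raw.stop_reason,
            tool_calls=raw.tool_calls,
        )

Presumably the decoder's return type and the API's CompletionMessage had diverged, so copying fields explicitly keeps the response schema stable regardless of what the decoder returns.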
@@ -6,6 +6,7 @@
 
 import asyncio
 import base64
+import io
 import json
 import logging
 import re
@@ -21,7 +22,6 @@ from llama_models.llama3.api.datatypes import (
     RawMediaItem,
     RawTextItem,
     Role,
-    ToolChoice,
     ToolPromptFormat,
 )
 from llama_models.llama3.prompt_templates import (
@@ -47,6 +47,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SystemMessage,
     TextContentItem,
+    ToolChoice,
     UserMessage,
 )
 
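The two import hunks above move ToolChoice from the llama_models datatypes import over to the llama_stack.apis.inference import, so provider code depends on the stack-level API types. A small hedged sketch of code written against the relocated enum; the helper function is illustrative, and only the import path (taken from the diff) plus the usual auto/required members are assumed.

    # Hedged sketch: consume ToolChoice from its new home in llama_stack.apis.inference.
    from llama_stack.apis.inference import ToolChoice


    def must_call_tool(tool_choice: ToolChoice) -> bool:
        # ToolChoice.required forces a tool call; ToolChoice.auto leaves it to the model.
        return tool_choice == ToolChoice.required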
@@ -136,7 +137,7 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
 async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
     if isinstance(media.data, URL) and media.data.uri.startswith("http"):
         async with httpx.AsyncClient() as client:
-            r = await client.get(media.image.uri)
+            r = await client.get(media.data.uri)
             content = r.content
             content_type = r.headers.get("content-type")
             if content_type:
@@ -145,7 +146,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
             format = "png"
         return content, format
     else:
-        image = PIL_Image.open(media.data)
+        image = PIL_Image.open(io.BytesIO(media.data))
         return media.data, image.format
 
 
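Both fixes above concern how localize_image_content gets at the underlying bytes: a remote image is fetched from media.data.uri (the old media.image reference no longer matched the content item's field), and inline bytes are wrapped in io.BytesIO before PIL opens them, since PIL.Image.open expects a path or file-like object rather than a bytes value. A standalone sketch of the same download-or-decode flow; it takes a plain str-or-bytes value instead of an ImageContentItem, which is an assumption for brevity.

    # Hedged sketch of the download-or-decode pattern above; not the exact helper.
    import io
    from typing import Tuple, Union

    import httpx
    from PIL import Image as PIL_Image


    async def load_image_bytes(data: Union[str, bytes]) -> Tuple[bytes, str]:
        if isinstance(data, str) and data.startswith("http"):
            # Remote image: download it and derive the format from the
            # content-type header when present, falling back to png as the diff does.
            async with httpx.AsyncClient() as client:
                r = await client.get(data)
                content_type = r.headers.get("content-type")
                fmt = content_type.split("/")[-1] if content_type else "png"
                return r.content, fmt
        # Inline image: wrap the raw bytes in BytesIO so PIL can sniff the format.
        image = PIL_Image.open(io.BytesIO(data))
        return data, image.format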
@@ -153,7 +154,7 @@ async def convert_image_content_to_url(
     media: ImageContentItem, download: bool = False, include_format: bool = True
 ) -> str:
     if isinstance(media.data, URL) and not download:
-        return media.image.uri
+        return media.data.uri
 
     content, format = await localize_image_content(media)
     if include_format:
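The last hunk ends right at the include_format branch of convert_image_content_to_url, after the same media.data.uri fix. A sketch of how the localized bytes can be turned into a string payload; the data-URL layout is an assumption suggested by the base64 import at the top of the file, not a quote of the lines the hunk omits.

    # Hedged sketch: encode localized image bytes for an OpenAI-compatible backend.
    import base64


    def encode_image_content(content: bytes, fmt: str, include_format: bool = True) -> str:
        encoded = base64.b64encode(content).decode("utf-8")
        if include_format:
            # Full data URL, e.g. "data:image/png;base64,....".
            return f"data:image/{fmt};base64,{encoded}"
        # Bare base64 string for callers that track the format separately.
        return encoded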