Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Commit e51154964f ("Tests pass with Ollama now"), parent a9a041a1de
27 changed files with 83 additions and 65 deletions
@@ -29,11 +29,13 @@ from llama_stack.apis.common.deployment_types import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.inference import InterleavedContent


 @json_schema_type
 class Attachment(BaseModel):
-    content: InterleavedTextMedia | URL
+    content: InterleavedContent | URL
     mime_type: str

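Note: InterleavedTextMedia is replaced by InterleavedContent throughout these APIs. A minimal sketch of what the updated Attachment accepts, assuming (as the rest of the diff suggests) that a plain string is still a valid InterleavedContent:

    # Hypothetical usage, not part of the commit; Attachment and URL as defined above.
    text_attachment = Attachment(content="plain-text note", mime_type="text/plain")
    url_attachment = Attachment(
        content=URL(uri="https://example.com/notes.txt"),
        mime_type="text/plain",
    )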
@@ -102,20 +104,20 @@ class _MemoryBankConfigCommon(BaseModel):


 class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
+    type: Literal["vector"] = "vector"


 class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
+    type: Literal["keyvalue"] = "keyvalue"
     keys: List[str]  # what keys to focus on


 class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
+    type: Literal["keyword"] = "keyword"


 class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
+    type: Literal["graph"] = "graph"
     entities: List[str]  # what entities to focus on

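Replacing `Literal[MemoryBankType.vector.value]` with `Literal["vector"]` keeps the same runtime value but gives pydantic and the JSON-schema generator a plain string literal to use as a discriminator. A sketch of the discriminated union these `type` tags typically feed (the union name and its exact location in agents.py are assumptions):

    from typing import Annotated, Union
    from pydantic import Field

    # Sketch only: each config class declares type: Literal["..."], so pydantic
    # can select the right member from the "type" field of incoming JSON.
    MemoryBankConfig = Annotated[
        Union[
            AgentVectorMemoryBankConfig,
            AgentKeyValueMemoryBankConfig,
            AgentKeywordMemoryBankConfig,
            AgentGraphMemoryBankConfig,
        ],
        Field(discriminator="type"),
    ]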
@@ -230,7 +232,7 @@ class MemoryRetrievalStep(StepCommon):
         StepType.memory_retrieval.value
     )
     memory_bank_ids: List[str]
-    inserted_context: InterleavedTextMedia
+    inserted_context: InterleavedContent


 Step = Annotated[

@@ -17,7 +17,7 @@ from llama_stack.apis.inference import *  # noqa: F403
 @json_schema_type
 class BatchCompletionRequest(BaseModel):
     model: str
-    content_batch: List[InterleavedTextMedia]
+    content_batch: List[InterleavedContent]
     sampling_params: Optional[SamplingParams] = SamplingParams()
     logprobs: Optional[LogProbConfig] = None

@@ -53,7 +53,7 @@ class BatchInference(Protocol):
     async def batch_completion(
         self,
         model: str,
-        content_batch: List[InterleavedTextMedia],
+        content_batch: List[InterleavedContent],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...

@@ -10,6 +10,8 @@ from typing import Optional
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel

+from llama_stack.apis.common.deployment_types import URL
+

 @json_schema_type
 class PostTrainingMetric(BaseModel):

@@ -6,12 +6,12 @@

 from typing import Any, Dict, List, Literal, Optional, Protocol

-from llama_models.llama3.api.datatypes import URL
-
 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel, Field

+from llama_stack.apis.common.deployment_types import URL
+
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType

@@ -15,6 +15,7 @@ from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.scoring import *  # noqa: F403
 from llama_stack.apis.eval_tasks import *  # noqa: F403
+from llama_stack.apis.inference import SamplingParams, SystemMessage


 @json_schema_type

@@ -247,7 +247,7 @@ class CompletionResponseStreamChunk(BaseModel):
 @json_schema_type
 class BatchCompletionRequest(BaseModel):
     model: str
-    content_batch: List[InterleavedTextMedia]
+    content_batch: List[InterleavedContent]
     sampling_params: Optional[SamplingParams] = SamplingParams()
     response_format: Optional[ResponseFormat] = None
     logprobs: Optional[LogProbConfig] = None

@@ -8,27 +8,27 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel, Field

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.memory_banks import *  # noqa: F403
+from llama_stack.apis.common.deployment_types import URL
+from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.memory_banks import MemoryBank
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


 @json_schema_type
 class MemoryBankDocument(BaseModel):
     document_id: str
-    content: InterleavedTextMedia | URL
+    content: InterleavedContent | URL
     mime_type: str | None = None
     metadata: Dict[str, Any] = Field(default_factory=dict)


 class Chunk(BaseModel):
-    content: InterleavedTextMedia
+    content: InterleavedContent
     token_count: int
     document_id: str

@@ -62,6 +62,6 @@ class Memory(Protocol):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse: ...

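With `query` typed as InterleavedContent, a plain string remains a valid query. A hedged sketch of a caller of the updated protocol method (the bank id, params, and any response fields beyond the Chunk model shown above are illustrative):

    # Runs inside an async context; "memory" is any Memory implementation.
    response = await memory.query_documents(
        bank_id="docs-bank",
        query="What does the runbook say about restarts?",  # str is valid InterleavedContent
        params={"max_chunks": 5},
    )
    for chunk in response.chunks:
        print(chunk.document_id, chunk.token_count)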
@@ -5,16 +5,16 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Any, Dict, List, Protocol, runtime_checkable
+from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

+from llama_stack.apis.inference import Message
+from llama_stack.apis.shields import Shield

 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.shields import *  # noqa: F403


 @json_schema_type
 class ViolationLevel(Enum):

@@ -13,6 +13,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel

 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import Message


 class FilteringFunction(Enum):

@@ -59,7 +59,7 @@ class MemoryRouter(Memory):
     async def query_documents(
         self,
         bank_id: str,
-        query: InterleavedTextMedia,
+        query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryDocumentsResponse:
         return await self.routing_table.get_provider_impl(bank_id).query_documents(

@@ -133,7 +133,7 @@ class InferenceRouter(Inference):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,

@@ -163,7 +163,7 @@ class InferenceRouter(Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.routing_table.get_model(model_id)
         if model is None:

@@ -16,8 +16,7 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.eval_tasks import *  # noqa: F403

-from llama_models.llama3.api.datatypes import URL
-
+from llama_stack.apis.common.deployment_types import URL
 from llama_stack.apis.common.type_system import ParamType
 from llama_stack.distribution.store import DistributionRegistry

@@ -30,7 +29,6 @@ def get_impl_api(p: Any) -> Api:

 # TODO: this should return the registered object for all APIs
 async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
-
     api = get_impl_api(p)

     assert obj.provider_id != "remote", "Remote provider should not be registered"

@@ -76,7 +74,6 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry

     async def initialize(self) -> None:
-
         async def add_objects(
             objs: List[RoutableObjectWithProvider], provider_id: str, cls
         ) -> None:

@@ -9,8 +9,6 @@ import logging

 from typing import List

-from llama_models.llama3.api.datatypes import Message
-
 from llama_stack.apis.safety import *  # noqa: F403

 log = logging.getLogger(__name__)

@@ -7,13 +7,17 @@
 import logging
 from typing import Any, Dict, List

-from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
+from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.inference import Message
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)

 from .config import CodeScannerConfig

-from llama_stack.apis.safety import *  # noqa: F403
-
 log = logging.getLogger(__name__)

 ALLOWED_CODE_SCANNER_MODEL_IDS = [
     "CodeScanner",
     "CodeShield",

|
@ -48,7 +52,7 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
|
|
||||||
from codeshield.cs import CodeShield
|
from codeshield.cs import CodeShield
|
||||||
|
|
||||||
text = "\n".join([interleaved_text_media_as_str(m.content) for m in messages])
|
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
||||||
log.info(f"Running CodeScannerShield on {text[50:]}")
|
log.info(f"Running CodeScannerShield on {text[50:]}")
|
||||||
result = await CodeShield.scan_code(text)
|
result = await CodeShield.scan_code(text)
|
||||||
|
|
||||||
|
|
|
@@ -10,7 +10,6 @@ from cerebras.cloud.sdk import AsyncCerebras

 from llama_models.llama3.api.chat_format import ChatFormat

-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.inference import *  # noqa: F403

@@ -10,7 +10,6 @@ from llama_models.datatypes import CoreModelId

 from llama_models.llama3.api.chat_format import ChatFormat

-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer

 from openai import OpenAI

@@ -10,7 +10,6 @@ from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData

@@ -11,7 +11,6 @@ import httpx
 from llama_models.datatypes import CoreModelId

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

@@ -90,7 +89,7 @@ model_aliases = [
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias_with_just_provider_model_id(
-        "llama3.2-vision",
+        "llama3.2-vision:latest",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(

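This is the only Ollama-specific change in the commit: the provider model id for the 11B vision model now carries its tag. Presumably Ollama reports tag-qualified names such as `llama3.2-vision:latest`, so an alias without the tag fails to resolve at model registration time. A trivial illustration of that exact-match behavior (the names below are examples, not captured output):

    # Illustration only: alias resolution is an exact string match against the
    # names the Ollama server reports, so the ":latest" tag matters.
    reported_by_ollama = {"llama3.2-vision:latest", "llama3.1:8b-instruct-fp16"}
    print("llama3.2-vision" in reported_by_ollama)         # False
    print("llama3.2-vision:latest" in reported_by_ollama)  # True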
@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def completion(
         self,
         model_id: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,

@@ -267,7 +267,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -10,7 +10,6 @@ from llama_models.datatypes import CoreModelId

 from llama_models.llama3.api.chat_format import ChatFormat

-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer

 from together import Together

@@ -8,7 +8,6 @@ import logging
 from typing import AsyncGenerator

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models

@@ -7,7 +7,6 @@
 from pathlib import Path

 import pytest
-from PIL import Image as PIL_Image


 from llama_models.llama3.api.datatypes import *  # noqa: F403

@@ -17,6 +16,9 @@ from .utils import group_chunks

 THIS_DIR = Path(__file__).parent

+with open(THIS_DIR / "pasta.jpeg", "rb") as f:
+    PASTA_IMAGE = f.read()
+

 class TestVisionModelInference:
     @pytest.mark.asyncio

@@ -24,12 +26,12 @@ class TestVisionModelInference:
         "image, expected_strings",
         [
             (
-                ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
+                ImageContentItem(data=PASTA_IMAGE),
                 ["spaghetti"],
             ),
             (
-                ImageMedia(
-                    image=URL(
+                ImageContentItem(
+                    data=URL(
                         uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                     )
                 ),

@@ -58,7 +60,12 @@ class TestVisionModelInference:
                 model_id=inference_model,
                 messages=[
                     UserMessage(content="You are a helpful assistant."),
-                    UserMessage(content=[image, "Describe this image in two sentences."]),
+                    UserMessage(
+                        content=[
+                            image,
+                            TextContentItem(text="Describe this image in two sentences."),
+                        ]
+                    ),
                 ],
                 stream=False,
                 sampling_params=SamplingParams(max_tokens=100),

@@ -89,8 +96,8 @@ class TestVisionModelInference:
         )

         images = [
-            ImageMedia(
-                image=URL(
+            ImageContentItem(
+                data=URL(
                     uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                 )
             ),

@@ -106,7 +113,12 @@ class TestVisionModelInference:
                 messages=[
                     UserMessage(content="You are a helpful assistant."),
                     UserMessage(
-                        content=[image, "Describe this image in two sentences."]
+                        content=[
+                            image,
+                            TextContentItem(
+                                text="Describe this image in two sentences."
+                            ),
+                        ]
                     ),
                 ],
                 stream=True,

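Taken together, the test changes replace `ImageMedia` plus bare strings with explicit content items. A minimal sketch of the resulting message shape (assuming `ImageContentItem` and `TextContentItem` come from `llama_stack.apis.inference`, as other hunks in this commit import them):

    # Sketch of the content-item message shape the updated tests construct.
    message = UserMessage(
        content=[
            ImageContentItem(data=PASTA_IMAGE),  # raw image bytes read at module import
            TextContentItem(text="Describe this image in two sentences."),
        ]
    )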
@@ -7,8 +7,8 @@
 import pytest
 import pytest_asyncio

-from llama_models.llama3.api.datatypes import URL
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.apis.common.deployment_types import URL
 from llama_stack.apis.datasets import DatasetInput
 from llama_stack.apis.models import ModelInput

@@ -10,7 +10,7 @@ from urllib.parse import unquote

 import pandas

-from llama_models.llama3.api.datatypes import URL
+from llama_stack.apis.common.deployment_types import URL

 from llama_stack.providers.utils.memory.vector_store import parse_data_url

@@ -7,9 +7,11 @@
 import logging
 from typing import List

-from llama_models.llama3.api.datatypes import InterleavedTextMedia
-from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
+from llama_stack.apis.inference import (
+    EmbeddingsResponse,
+    InterleavedContent,
+    ModelStore,
+)

 EMBEDDING_MODELS = {}

@@ -23,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(

@@ -93,11 +93,15 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]

-    completion_message = formatter.decode_assistant_message_from_content(
+    raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
     return ChatCompletionResponse(
-        completion_message=completion_message,
+        completion_message=CompletionMessage(
+            content=raw_message.content,
+            stop_reason=raw_message.stop_reason,
+            tool_calls=raw_message.tool_calls,
+        ),
         logprobs=None,
     )

@@ -6,6 +6,7 @@

 import asyncio
 import base64
+import io
 import json
 import logging
 import re

|
@ -21,7 +22,6 @@ from llama_models.llama3.api.datatypes import (
|
||||||
RawMediaItem,
|
RawMediaItem,
|
||||||
RawTextItem,
|
RawTextItem,
|
||||||
Role,
|
Role,
|
||||||
ToolChoice,
|
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_models.llama3.prompt_templates import (
|
from llama_models.llama3.prompt_templates import (
|
||||||
|
@@ -47,6 +47,7 @@ from llama_stack.apis.inference import (
     ResponseFormatType,
     SystemMessage,
     TextContentItem,
+    ToolChoice,
     UserMessage,
 )

@@ -136,7 +137,7 @@ def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
 async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
     if isinstance(media.data, URL) and media.data.uri.startswith("http"):
         async with httpx.AsyncClient() as client:
-            r = await client.get(media.image.uri)
+            r = await client.get(media.data.uri)
             content = r.content
             content_type = r.headers.get("content-type")
             if content_type:

@@ -145,7 +146,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
                 format = "png"
             return content, format
     else:
-        image = PIL_Image.open(media.data)
+        image = PIL_Image.open(io.BytesIO(media.data))
         return media.data, image.format

@@ -153,7 +154,7 @@ async def convert_image_content_to_url(
     media: ImageContentItem, download: bool = False, include_format: bool = True
 ) -> str:
     if isinstance(media.data, URL) and not download:
-        return media.image.uri
+        return media.data.uri

     content, format = await localize_image_content(media)
     if include_format:

@@ -8,7 +8,7 @@ import base64
 import mimetypes
 import os

-from llama_models.llama3.api.datatypes import URL
+from llama_stack.apis.common.deployment_types import URL


 def data_url_from_file(file_path: str) -> URL: