Tests pass with Ollama now

This commit is contained in:
Ashwin Bharambe 2024-12-15 17:31:21 -08:00
parent a9a041a1de
commit e51154964f
27 changed files with 83 additions and 65 deletions

View file

@ -59,7 +59,7 @@ class MemoryRouter(Memory):
async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
return await self.routing_table.get_provider_impl(bank_id).query_documents(
@ -133,7 +133,7 @@ class InferenceRouter(Inference):
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
@ -163,7 +163,7 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:

View file

@ -16,8 +16,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.deployment_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.distribution.store import DistributionRegistry
@ -30,7 +29,6 @@ def get_impl_api(p: Any) -> Api:
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
@ -76,7 +74,6 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None: