From 1cb42d306019676ad60fe7a0737b20842fd1642f Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 12 Nov 2024 21:47:39 -0800
Subject: [PATCH] rebase and fix some small breakage due to model -> model_id fix

---
 llama_stack/distribution/routers/routing_tables.py    | 10 +++++-----
 .../inline/safety/llama_guard/llama_guard.py          |  2 +-
 .../providers/remote/inference/ollama/ollama.py       |  1 -
 llama_stack/providers/tests/agents/conftest.py        | 12 +++++++++++-
 .../providers/tests/inference/test_text_inference.py  |  2 +-
 .../tests/inference/test_vision_inference.py          |  4 ++--
 6 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 7b7433862..5342728b1 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -38,15 +38,15 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
     if api == Api.inference:
         return await p.register_model(obj)
     elif api == Api.safety:
-        await p.register_shield(obj)
+        return await p.register_shield(obj)
     elif api == Api.memory:
-        await p.register_memory_bank(obj)
+        return await p.register_memory_bank(obj)
     elif api == Api.datasetio:
-        await p.register_dataset(obj)
+        return await p.register_dataset(obj)
     elif api == Api.scoring:
-        await p.register_scoring_function(obj)
+        return await p.register_scoring_function(obj)
     elif api == Api.eval:
-        await p.register_eval_task(obj)
+        return await p.register_eval_task(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
 
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 494c1b43e..9950064a4 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -234,7 +234,7 @@ class LlamaGuardShield:
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
         async for chunk in await self.inference_api.chat_completion(
-            model=self.model,
+            model_id=self.model,
             messages=[shield_input_message],
             stream=True,
         ):
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 99f74572e..3a32125b2 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
-        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py
index c4f766e26..6ce7913d7 100644
--- a/llama_stack/providers/tests/agents/conftest.py
+++ b/llama_stack/providers/tests/agents/conftest.py
@@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="together",
         marks=pytest.mark.together,
     ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "agents": "meta_reference",
+        },
+        id="fireworks",
+        marks=pytest.mark.fireworks,
+    ),
     pytest.param(
         {
             "inference": "remote",
@@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
 
 
 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote"]:
+    for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 70047a61f..7b7aca5bd 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -147,9 +147,9 @@ class TestInference:
 
         user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
         response = await inference_impl.completion(
+            model_id=inference_model,
             content=user_input,
             stream=False,
-            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py
index 3e785b757..c5db04cca 100644
--- a/llama_stack/providers/tests/inference/test_vision_inference.py
+++ b/llama_stack/providers/tests/inference/test_vision_inference.py
@@ -55,7 +55,7 @@
         )
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
             messages=[
                 UserMessage(content="You are a helpful assistant."),
                 UserMessage(content=[image, "Describe this image in two sentences."]),
@@ -102,7 +102,7 @@
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=inference_model,
                 messages=[
                     UserMessage(content="You are a helpful assistant."),
                     UserMessage(