Fixed imports for inference (#661)

# What does this PR do?

Fixes the NVIDIA remote inference provider's imports: `CompletionMessage` and `TokenLogProbs` are no longer importable from `llama_models.llama3.api.datatypes`, so they are now imported from `llama_stack.apis.inference` instead (see the diff at the bottom of this description).

- [x] Addresses issue (#issue):
```
    from .nvidia import NVIDIAInferenceAdapter
  File "/localhome/local-cdgamarose/llama-stack/llama_stack/providers/remote/inference/nvidia/nvidia.py", line 37, in <module>
    from .openai_utils import (
  File "/localhome/local-cdgamarose/llama-stack/llama_stack/providers/remote/inference/nvidia/openai_utils.py", line 11, in <module>
    from llama_models.llama3.api.datatypes import (
ImportError: cannot import name 'CompletionMessage' from 'llama_models.llama3.api.datatypes' (/localhome/local-cdgamarose/.local/lib/python3.10/site-packages/llama_models/llama3/api/datatypes.py)
++ error_handler 62
```
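
The fix moves `CompletionMessage` and `TokenLogProbs` to be imported from `llama_stack.apis.inference` rather than `llama_models.llama3.api.datatypes`. A minimal sketch of the corrected import layout, showing only the names touched by this change (the full diff is at the bottom of this description):

```python
# Names that stay in llama_models' datatypes module.
from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    StopReason,
    ToolCall,
    ToolDefinition,
)

# CompletionMessage and TokenLogProbs now come from llama_stack's own
# inference API types.
from llama_stack.apis.inference import (
    CompletionMessage,
    TokenLogProbs,
)
```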

## Test Plan
Deployed a NIM with Docker, following
https://build.nvidia.com/meta/llama-3_1-8b-instruct?snippet_tab=Docker.
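
Before running the tests, the endpoint can be sanity-checked with something like the sketch below; the OpenAI-compatible `/v1/completions` route and the `meta/llama-3.1-8b-instruct` model id are assumptions based on the NIM snippet linked above and may differ for other deployments.

```python
# Optional smoke test against the locally deployed NIM (sketch; the route and
# model id are assumptions and may need adjusting for your deployment).
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta/llama-3.1-8b-instruct",
        "prompt": "The capital of France is",
        "max_tokens": 16,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```

With the NIM responding, ran the provider completion tests: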
```
(lsmyenv) local-cdgamarose@a4u8g-0006:~/llama-stack$ python3 -m pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_BASE_URL=http://localhost:8000 -k test_completion --inference-model Llama3.1-8B-Instruct
======================================================================================== test session starts =========================================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /localhome/local-cdgamarose/anaconda3/envs/lsmyenv/bin/python3
cachedir: .pytest_cache
rootdir: /localhome/local-cdgamarose/llama-stack
configfile: pyproject.toml
plugins: anyio-4.7.0, asyncio-0.25.0
asyncio: mode=strict, asyncio_default_fixture_loop_scope=None
collected 24 items / 21 deselected / 3 selected

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-nvidia] Initializing NVIDIAInferenceAdapter(http://localhost:8000)...
Checking NVIDIA NIM health...
Checking NVIDIA NIM health...
PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-nvidia] SKIPPED (Other inference providers don't support completion() yet)
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-nvidia] SKIPPED (This test is not quite robust)

====================================================================== 1 passed, 2 skipped, 21 deselected, 2 warnings in 1.57s =======================================================================
```

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

Diff of the import changes:

```diff
@@ -10,9 +10,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
-    CompletionMessage,
     StopReason,
-    TokenLogProbs,
     ToolCall,
     ToolDefinition,
 )
@@ -42,12 +40,14 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
+    CompletionMessage,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
     JsonSchemaResponseFormat,
     Message,
     SystemMessage,
+    TokenLogProbs,
     ToolCallDelta,
     ToolCallParseStatus,
     ToolResponseMessage,
```