Skip to content

Commit 428dba2

Browse files
hkernbachaMahanna
andauthored
GA-168 GPU Test (#43)
* added gpu test using nx and cuda, first test commit for circleci - this is expected to fail * fix yml formatting * fix yml formatting again, define executr gpu * add test-gpu to matrix executor * fix resource class, added todo for later * flake8 * pot deps fix * gpu test enable * gpu test enable * fix syntax * fix test, should work now on ci as well * incr grid of graph * restructured test dirs, do not automatically run gpu tests. * isort * fmt, move test code * this is not allowed to be removed * fmt * test * 3.12 instead of 3.12.2 for gpu * new: `use_gpu` backend config * attempt: set `use_gpu` * force-set `use_gpu` * fix: lint * cleanup * fix: lint * fix imports * attempt: increase `digit` * new: `write_async` param * move assertions * fix lint ffs... * attempt: increase `digit` --------- Co-authored-by: Anthony Mahanna <[email protected]>
1 parent 93d1d24 commit 428dba2

File tree

10 files changed

+165
-16
lines changed

10 files changed

+165
-16
lines changed

.circleci/config.yml

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ executors:
1212
machine:
1313
image: ubuntu-2404:current
1414

15+
gpu-executor:
16+
machine:
17+
image: linux-cuda-12:default
18+
resource_class: gpu.nvidia.small.multi
19+
1520
jobs:
1621
lint:
1722
executor: python-executor
@@ -79,6 +84,45 @@ jobs:
7984
name: Run NetworkX tests
8085
command: ./run_nx_tests.sh
8186

87+
test-gpu:
88+
parameters:
89+
python_version:
90+
type: string
91+
executor: gpu-executor
92+
steps:
93+
- checkout
94+
95+
- run:
96+
name: Set up ArangoDB
97+
command: |
98+
chmod +x starter.sh
99+
./starter.sh
100+
101+
- run:
102+
name: Setup Python
103+
command: |
104+
pyenv --version
105+
pyenv install -f << parameters.python_version >>
106+
pyenv global << parameters.python_version >>
107+
108+
- run:
109+
name: Setup pip
110+
command: python -m pip install --upgrade pip setuptools wheel
111+
112+
- run:
113+
name: Install packages
114+
command: pip install .[dev]
115+
116+
- run:
117+
name: Install cuda related dependencies
118+
command: |
119+
pip install pylibcugraph-cu12 --extra-index-url https://pypi.nvidia.com
120+
pip install nx-cugraph-cu12 --extra-index-url https://pypi.nvidia.com
121+
122+
- run:
123+
name: Run local gpu tests
124+
command: pytest tests/test.py -k "test_gpu" --run-gpu-tests
125+
82126
workflows:
83127
version: 2
84128
build:
@@ -87,4 +131,11 @@ workflows:
87131
- test:
88132
matrix:
89133
parameters:
90-
python_version: ["3.10", "3.11", "3.12.2"]
134+
python_version: ["3.10", "3.11", "3.12.2"]
135+
- test-gpu:
136+
requires:
137+
- lint
138+
- test
139+
matrix:
140+
parameters:
141+
python_version: ["3.10", "3.11"] # "3.12" # TODO: Revisit 3.12

_nx_arangodb/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def get_info():
8282
"read_parallelism": None,
8383
"read_batch_size": None,
8484
"write_batch_size": None,
85+
"use_gpu": True,
8586
}
8687

8788
return d

nx_arangodb/classes/digraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
read_parallelism: int = 10,
3535
read_batch_size: int = 100000,
3636
write_batch_size: int = 50000,
37+
write_async: bool = True,
3738
symmetrize_edges: bool = False,
3839
use_experimental_views: bool = False,
3940
*args: Any,
@@ -50,6 +51,7 @@ def __init__(
5051
read_parallelism,
5152
read_batch_size,
5253
write_batch_size,
54+
write_async,
5355
symmetrize_edges,
5456
use_experimental_views,
5557
*args,

nx_arangodb/classes/graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from typing import Any, Callable, ClassVar
44

55
import networkx as nx
6-
import numpy as np
7-
import numpy.typing as npt
86
from adbnx_adapter import ADBNX_Adapter
97
from arango import ArangoClient
108
from arango.cursor import Cursor
@@ -57,6 +55,7 @@ def __init__(
5755
read_parallelism: int = 10,
5856
read_batch_size: int = 100000,
5957
write_batch_size: int = 50000,
58+
write_async: bool = True,
6059
symmetrize_edges: bool = False,
6160
use_experimental_views: bool = False,
6261
*args: Any,
@@ -168,7 +167,7 @@ def edge_type_func(u: str, v: str) -> str:
168167
incoming_graph_data,
169168
edge_definitions=edge_definitions,
170169
batch_size=self.write_batch_size,
171-
use_async=True,
170+
use_async=write_async,
172171
)
173172

174173
else:
@@ -211,6 +210,7 @@ def _set_arangodb_backend_config(self) -> None:
211210
config.read_parallelism = self.read_parallelism
212211
config.read_batch_size = self.read_batch_size
213212
config.write_batch_size = self.write_batch_size
213+
config.use_gpu = True # Only used by default if nx-cugraph is available
214214

215215
def _set_factory_methods(self) -> None:
216216
"""Set the factory methods for the graph, _node, and _adj dictionaries.

nx_arangodb/classes/multidigraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
read_parallelism: int = 10,
3434
read_batch_size: int = 100000,
3535
write_batch_size: int = 50000,
36+
write_async: bool = True,
3637
symmetrize_edges: bool = False,
3738
use_experimental_views: bool = False,
3839
*args: Any,
@@ -49,6 +50,7 @@ def __init__(
4950
read_parallelism,
5051
read_batch_size,
5152
write_batch_size,
53+
write_async,
5254
symmetrize_edges,
5355
use_experimental_views,
5456
*args,

nx_arangodb/classes/multigraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
read_parallelism: int = 10,
3535
read_batch_size: int = 100000,
3636
write_batch_size: int = 50000,
37+
write_async: bool = True,
3738
symmetrize_edges: bool = False,
3839
use_experimental_views: bool = False,
3940
*args: Any,
@@ -50,6 +51,7 @@ def __init__(
5051
read_parallelism,
5152
read_batch_size,
5253
write_batch_size,
54+
write_async,
5355
symmetrize_edges,
5456
use_experimental_views,
5557
*args,

nx_arangodb/convert.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
import numpy as np
1717
import nx_cugraph as nxcg
1818

19-
GPU_ENABLED = True
20-
logger.info("NetworkX-cuGraph is enabled.")
19+
GPU_AVAILABLE = True
20+
logger.info("NetworkX-cuGraph is available.")
2121
except Exception as e:
22-
GPU_ENABLED = False
23-
logger.info(f"NetworkX-cuGraph is disabled: {e}.")
22+
GPU_AVAILABLE = False
23+
logger.info(f"NetworkX-cuGraph is unavailable: {e}.")
2424

2525
__all__ = [
2626
"_to_nx_graph",
@@ -58,7 +58,7 @@ def _to_nxadb_graph(
5858
raise TypeError(f"Expected nxadb.Graph or nx.Graph; got {type(G)}")
5959

6060

61-
if GPU_ENABLED:
61+
if GPU_AVAILABLE:
6262

6363
def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph:
6464
logger.debug(f"_to_nxcg_graph for {G.__class__.__name__}")
@@ -161,7 +161,7 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
161161
return G_NX
162162

163163

164-
if GPU_ENABLED:
164+
if GPU_AVAILABLE:
165165

166166
def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
167167
if G.use_nxcg_cache and G.nxcg_graph is not None:

nx_arangodb/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _auto_func(func_name: str, /, *args: Any, **kwargs: Any) -> Any:
6464

6565
# TODO: Use `nx.config.backends.arangodb.backend_priority` instead
6666
backend_priority = []
67-
if nxadb.convert.GPU_ENABLED:
67+
if nxadb.convert.GPU_AVAILABLE and nx.config.backends.arangodb.use_gpu:
6868
backend_priority.append("cugraph")
6969

7070
for backend in backend_priority:

tests/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22
import os
3+
import sys
4+
from io import StringIO
35
from typing import Any
46

57
import networkx as nx
@@ -14,13 +16,17 @@
1416
logger.setLevel(logging.INFO)
1517

1618
db: StandardDatabase
19+
run_gpu_tests: bool
1720

1821

1922
def pytest_addoption(parser: Any) -> None:
2023
parser.addoption("--url", action="store", default="http://localhost:8529")
2124
parser.addoption("--dbName", action="store", default="_system")
2225
parser.addoption("--username", action="store", default="root")
2326
parser.addoption("--password", action="store", default="test")
27+
parser.addoption(
28+
"--run-gpu-tests", action="store_true", default=False, help="Run GPU tests"
29+
)
2430

2531

2632
def pytest_configure(config: Any) -> None:
@@ -48,6 +54,9 @@ def pytest_configure(config: Any) -> None:
4854
os.environ["DATABASE_PASSWORD"] = con["password"]
4955
os.environ["DATABASE_NAME"] = con["dbName"]
5056

57+
global run_gpu_tests
58+
run_gpu_tests = config.getoption("--run-gpu-tests")
59+
5160

5261
@pytest.fixture(scope="function")
5362
def load_karate_graph() -> None:
@@ -100,3 +109,28 @@ def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
100109
name="LineGraph",
101110
edge_collections_attributes=load_attributes,
102111
)
112+
113+
114+
def create_grid_graph(graph_cls: type[nxadb.Graph]) -> nxadb.Graph:
115+
global db
116+
if db.has_graph("GridGraph"):
117+
return graph_cls(name="GridGraph")
118+
119+
grid_graph = nx.grid_graph(dim=(500, 500))
120+
return graph_cls(
121+
incoming_graph_data=grid_graph, name="GridGraph", write_async=False
122+
)
123+
124+
125+
# Taken from:
126+
# https://stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
127+
class Capturing(list[str]):
128+
def __enter__(self):
129+
self._stdout = sys.stdout
130+
sys.stdout = self._stringio = StringIO()
131+
return self
132+
133+
def __exit__(self, *args):
134+
self.extend(self._stringio.getvalue().splitlines())
135+
del self._stringio # free up some memory
136+
sys.stdout = self._stdout

tests/test.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import time
12
from typing import Any, Callable, Dict, Union
23

34
import networkx as nx
4-
import phenolrs
55
import pytest
66
from arango import DocumentDeleteError
77
from phenolrs.networkx.typings import (
@@ -15,7 +15,7 @@
1515
from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict, EdgeKeyDict
1616
from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict
1717

18-
from .conftest import create_line_graph, db
18+
from .conftest import Capturing, create_grid_graph, create_line_graph, db, run_gpu_tests
1919

2020
G_NX = nx.karate_club_graph()
2121

@@ -38,7 +38,11 @@ def assert_same_dict_values(
3838
if type(next(iter(d2.keys()))) == int:
3939
d2 = {f"person/{k}": v for k, v in d2.items()}
4040

41-
assert d1.keys() == d2.keys(), "Dictionaries have different keys"
41+
d1_keys = set(d1.keys())
42+
d2_keys = set(d2.keys())
43+
difference = d1_keys ^ d2_keys
44+
assert difference == set(), "Dictionaries have different keys"
45+
4246
for key in d1:
4347
m = f"Values for key '{key}' are not equal up to digit {digit}"
4448
assert round(d1[key], digit) == round(d2[key], digit), m
@@ -50,10 +54,12 @@ def assert_bc(d1: dict[str | int, float], d2: dict[str | int, float]) -> None:
5054
assert_same_dict_values(d1, d2, 14)
5155

5256

53-
def assert_pagerank(d1: dict[str | int, float], d2: dict[str | int, float]) -> None:
57+
def assert_pagerank(
58+
d1: dict[str | int, float], d2: dict[str | int, float], digit: int = 15
59+
) -> None:
5460
assert d1
5561
assert d2
56-
assert_same_dict_values(d1, d2, 15)
62+
assert_same_dict_values(d1, d2, digit)
5763

5864

5965
def assert_louvain(l1: list[set[Any]], l2: list[set[Any]]) -> None:
@@ -315,6 +321,57 @@ def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None:
315321
assert r_3 != r_4
316322

317323

324+
@pytest.mark.parametrize(
325+
"graph_cls",
326+
[
327+
(nxadb.Graph),
328+
(nxadb.DiGraph),
329+
(nxadb.MultiGraph),
330+
(nxadb.MultiDiGraph),
331+
],
332+
)
333+
def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None:
334+
if not run_gpu_tests:
335+
pytest.skip("GPU tests are disabled")
336+
337+
graph = create_grid_graph(graph_cls)
338+
339+
assert nxadb.convert.GPU_AVAILABLE is True
340+
assert nx.config.backends.arangodb.use_gpu is True
341+
342+
res_gpu = None
343+
res_cpu = None
344+
345+
# Measure GPU execution time
346+
start_gpu = time.time()
347+
348+
# Note: While this works, we should use the logger or some alternative
349+
# approach testing this. Via stdout is not the best way to test this.
350+
with Capturing() as output_gpu:
351+
res_gpu = nx.pagerank(graph)
352+
353+
assert any(
354+
"NXCG Graph construction took" in line for line in output_gpu
355+
), "Expected output not found in GPU execution"
356+
357+
gpu_time = time.time() - start_gpu
358+
359+
# Disable GPU and measure CPU execution time
360+
nx.config.backends.arangodb.use_gpu = False
361+
start_cpu = time.time()
362+
with Capturing() as output_cpu:
363+
res_cpu = nx.pagerank(graph)
364+
365+
output_cpu_list = list(output_cpu)
366+
assert len(output_cpu_list) == 1
367+
assert "Graph 'GridGraph' load took" in output_cpu_list[0]
368+
369+
cpu_time = time.time() - start_cpu
370+
371+
assert gpu_time < cpu_time, "GPU execution should be faster than CPU execution"
372+
assert_pagerank(res_gpu, res_cpu, 10)
373+
374+
318375
@pytest.mark.parametrize(
319376
"graph_cls",
320377
[

0 commit comments

Comments
 (0)