Skip to content

Commit c31bc95

Browse files
committed
TST: Add a test for SSH function
1 parent 1f90a17 commit c31bc95

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

nipype/interfaces/tests/test_io.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from builtins import str, zip, range, open
66
from future import standard_library
77
import os
8+
import copy
89
import simplejson
910
import glob
1011
import shutil
@@ -33,6 +34,12 @@
3334
except ImportError:
3435
noboto3 = True
3536

37+
try:
38+
import paramiko
39+
no_paramiko = False
40+
except ImportError:
41+
no_paramiko = True
42+
3643
# Check for fakes3
3744
standard_library.install_aliases()
3845
from subprocess import check_call, CalledProcessError
@@ -432,5 +439,43 @@ def test_jsonsink(tmpdir, inputs_attributes):
432439

433440
assert data == expected_data
434441

435-
436-
442+
@pytest.mark.skipif(no_paramiko, reason="paramiko library is not available")
443+
def test_SSHDataGrabber(tmpdir):
444+
"""Test SSHDataGrabber by connecting to localhost and finding this test
445+
file.
446+
"""
447+
old_cwd = tmpdir.chdir()
448+
449+
# ssh client that connects to localhost, current user, regardless of
450+
# ~/.ssh/config
451+
def _mock_get_ssh_client(self):
452+
proxy = None
453+
client = paramiko.SSHClient()
454+
client.load_system_host_keys()
455+
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
456+
client.connect('localhost', username=os.getenv('USER'), sock=proxy)
457+
return client
458+
MockSSHDataGrabber = copy.copy(nio.SSHDataGrabber)
459+
MockSSHDataGrabber._get_ssh_client = _mock_get_ssh_client
460+
461+
this_dir = os.path.dirname(__file__)
462+
this_file = os.path.basename(__file__)
463+
this_test = this_file[:-3] # without .py
464+
465+
ssh_grabber = MockSSHDataGrabber(infields=['test'],
466+
outfields=['test_file'])
467+
# ssh_grabber.base_dir = str(tmpdir)
468+
ssh_grabber.inputs.base_directory = this_dir
469+
ssh_grabber.inputs.hostname = 'localhost'
470+
ssh_grabber.inputs.field_template = dict(test_file='%s.py')
471+
ssh_grabber.inputs.template = ''
472+
ssh_grabber.inputs.template_args = dict(test_file=[['test']])
473+
ssh_grabber.inputs.test = this_test
474+
ssh_grabber.inputs.sort_filelist = True
475+
476+
runtime = ssh_grabber.run()
477+
478+
# did we successfully get this file?
479+
assert runtime.outputs.test_file == str(tmpdir.join(this_file))
480+
481+
old_cwd.chdir()

0 commit comments

Comments
 (0)