Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,34 @@ The ``putfo`` method can be used to upload file-like objects:
fl.close()


Downloading to file-like objects or buffers
---------------------------

The ``getfo`` method can be used to download a single file into a BytesIO (file-like) stream.
Likewise, the ``get_data`` method can be used to download a single file's data and get it as a bytes/str object.

.. code-block:: python

import io
from paramiko import SSHClient
from scp import SCPClient

ssh = SSHClient()
ssh.load_system_host_keys()
ssh.connect('example.com')

# SCPCLient takes a paramiko transport as an argument
scp = SCPClient(ssh.get_transport())

# Download a file's data into a BytesIO object
fl = scp.getfo('/tmp/test.txt')
print(fl.getvalue())

# Or get the data directly
data = scp.get_data('/tmp/test.txt', decode_utf8=True)
print(data)


Tracking progress of your file uploads/downloads
------------------------------------------------

Expand Down
105 changes: 105 additions & 0 deletions scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import os
import re
from socket import timeout as SocketTimeout
import ntpath
from io import BytesIO


SCP_COMMAND = b'scp'
Expand Down Expand Up @@ -159,6 +161,7 @@ def __init__(self, transport, buff_size=16384, socket_timeout=10.0,
self._dirtimes = {}
self.peername = self.transport.getpeername()
self.scp_command = SCP_COMMAND
self.remote_file_name = None

def __enter__(self):
self.channel = self._open()
Expand Down Expand Up @@ -281,6 +284,50 @@ def get(self, remote_path, local_path='',
self._recv_all()
self.close()

def getfo(self, remote_path, bytes_io=None):
# type: (PathTypes, Optional[BytesIO]) -> BytesIO
"""
Transfer a file from remote host to localhost to a file-like object.

@param remote_path: path to retrieve from remote host. Note that
wildcards will be escaped unless you changed the `sanitize`
function.
@type remote_path: str
@param bytes_io: Optional - a BytesIO object to write into, instead of creating one and writing into it.
@type bytes_io: BytesIO
@returns: file-like object with the file's data
"""
remote_path = self.sanitize(asbytes(remote_path))
self.channel = self._open()
self.remote_file_name = ntpath.basename(asunicode_win(remote_path))
self.channel.settimeout(self.socket_timeout)
self.channel.exec_command(self.scp_command + b" -f " + remote_path)
if bytes_io is None:
bytes_io = BytesIO()
self._recv_all_fo(bytes_io)
self.close()
return bytes_io

def get_data(self, remote_path, decode_utf8=False, bytes_io=None):
# type: (PathTypes, bool, Optional[BytesIO]) -> Union[bytes, str]
"""
Transfer a file from remote host to localhost, return the data of the remote file.

@param remote_path: path to retrieve from remote host. Note that
wildcards will be escaped unless you changed the `sanitize`
function.
@type remote_path: str
@param decode_utf8: should decode result as utf-8
@type decode_utf8: bool
@param bytes_io: Optional - a BytesIO object to write into, instead of creating one and writing into it.
@type bytes_io: BytesIO
@returns: data with the remote file's data
"""
data = self.getfo(remote_path, bytes_io=bytes_io).getvalue()
if decode_utf8:
return data.decode("utf-8")
return data

def _open(self):
"""open a scp channel"""
if self.channel is None or self.channel.closed:
Expand Down Expand Up @@ -553,6 +600,64 @@ def _recv_popd(self, *cmd):
self._depth -= 1
self._recv_dir = os.path.split(self._recv_dir)[0]

def _recv_all_fo(self, fh):
# type: (BytesIO) -> None
# loop over scp commands, and receive as necessary
commands = (b'C', )
while not self.channel.closed:
# wait for command as long as we're open
self.channel.sendall('\x00')
msg = self.channel.recv(1024)
if not msg: # chan closed while receiving
break
assert msg[-1:] == b'\n'
msg = msg[:-1]
code = msg[0:1]
if code not in commands:
raise SCPException(asunicode(msg[1:]))
self._recv_file_fo(msg[1:], fh)

def _recv_file_fo(self, cmd, fh):
# type: (bytes, BytesIO) -> None
chan = self.channel
parts = cmd.strip().split(b' ', 2)

try:
mode = int(parts[0], 8)
size = int(parts[1])
except:
chan.send('\x01')
chan.close()
raise SCPException('Bad file format')

if self._progress:
if size == 0:
# avoid divide-by-zero
self._progress(self.remote_file_name, 1, 1, self.peername)
else:
self._progress(self.remote_file_name, size, 0, self.peername)
buff_size = self.buff_size
pos = 0
chan.send(b'\x00')
try:
while pos < size:
# we have to make sure we don't read the final byte
if size - pos <= buff_size:
buff_size = size - pos
data = chan.recv(buff_size)
if not data:
raise SCPException("Underlying channel was closed")
fh.write(data)
pos = fh.tell()
if self._progress:
self._progress(self.remote_file_name, size, pos, self.peername)
msg = chan.recv(512)
if msg and msg[0:1] != b'\x00':
raise SCPException(asunicode(msg[1:]))
except SocketTimeout:
chan.close()
raise SCPException('Error receiving, socket.timeout')

def _set_dirtimes(self):
try:
for d in self._dirtimes:
Expand Down
21 changes: 21 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,22 @@ def listdir(path, fpath):
os.chdir(previous)
shutil.rmtree(temp)

def download_test_fo(self, filename, result=b''):
cb3 = lambda filename, size, sent: None
with SCPClient(self.ssh.get_transport(), progress=cb3) as scp:
bytes_io = scp.getfo(filename)
value = bytes_io.getvalue()

self.assertEqual(value, result)

def download_test_data(self, filename, decode_utf8=False, result=b''):
cb3 = lambda filename, size, sent: None
with SCPClient(self.ssh.get_transport(), progress=cb3) as scp:
value = scp.get_data(filename, decode_utf8=decode_utf8)
if not decode_utf8:
self.assertEqual(value, result)
self.assertEqual(value, result.decode('utf-8'))

def test_get_bytes(self):
self.download_test(b'/tmp/r\xC3\xA9mi', False, b'target',
[u'target'], [b'target'])
Expand Down Expand Up @@ -167,6 +183,11 @@ def test_get_folder(self):
[b'target', b'target/file',
b'target/b\xC3\xA8te'])

def test_get_fo(self):
self.download_test_fo(b'/tmp/r\xC3\xA9mi', result=b'')
self.download_test_data(b'/tmp/r\xC3\xA9mi', decode_utf8=False, result=b'')
self.download_test_data(b'/tmp/r\xC3\xA9mi', decode_utf8=True, result=b'')

def test_get_invalid_unicode(self):
self.download_test(b'/tmp/p\xE9t\xE9', False, u'target',
[u'target'], [b'target'])
Expand Down