From c31a533c470e4d421f9099d20e2556f014b97608 Mon Sep 17 00:00:00 2001 From: Yehuda Anikster Date: Tue, 24 Jan 2023 16:23:18 +0200 Subject: [PATCH 1/3] Added getfo and get_data functions to SCPClient, added tests, documentation in README --- README.rst | 28 +++++++++++++++ scp.py | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++++- test.py | 21 +++++++++++ 3 files changed, 150 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 936ac2a..8239ce4 100644 --- a/README.rst +++ b/README.rst @@ -91,6 +91,34 @@ The ``putfo`` method can be used to upload file-like objects: fl.close() +Downloading to file-like objects or buffers +--------------------------- + +The ``getfo`` method can be used to download a single file into a BytesIO (file-like) stream. +Likewise, the ``get_data`` method can be used to download a single file's data and get it as a bytes/str object. + +.. code-block:: python + + import io + from paramiko import SSHClient + from scp import SCPClient + + ssh = SSHClient() + ssh.load_system_host_keys() + ssh.connect('example.com') + + # SCPCLient takes a paramiko transport as an argument + scp = SCPClient(ssh.get_transport()) + + # Download a file's data into a BytesIO object + fl = scp.getfo('/tmp/test.txt') + print(fl.getvalue()) + + # Or get the data directly + data = scp.get_data('/tmp/test.txt', decode_utf8=True) + print(data) + + Tracking progress of your file uploads/downloads ------------------------------------------------ diff --git a/scp.py b/scp.py index 6431c45..cecc7fc 100644 --- a/scp.py +++ b/scp.py @@ -11,6 +11,8 @@ import os import re from socket import timeout as SocketTimeout +import ntpath +from io import BytesIO SCP_COMMAND = b'scp' @@ -29,7 +31,7 @@ pass try: - from typing import IO, TYPE_CHECKING, AnyStr, Callable, Iterable, Optional, Tuple, Union + from typing import IO, TYPE_CHECKING, AnyStr, Callable, Iterable, Optional, Tuple, Union, Dict if TYPE_CHECKING: import paramiko.transport @@ -159,6 +161,7 @@ def __init__(self, transport, buff_size=16384, socket_timeout=10.0, self._dirtimes = {} self.peername = self.transport.getpeername() self.scp_command = SCP_COMMAND + self.remote_file_name = None def __enter__(self): self.channel = self._open() @@ -281,6 +284,45 @@ def get(self, remote_path, local_path='', self._recv_all() self.close() + def getfo(self, remote_path): + # type: (PathTypes) -> BytesIO + """ + Transfer a file from remote host to localhost to a file-like object. + + @param remote_path: path to retrieve from remote host. Note that + wildcards will be escaped unless you changed the `sanitize` + function. + @type remote_path: str + @returns: file-like object with the file's data + """ + remote_path = self.sanitize(asbytes(remote_path)) + self.channel = self._open() + self.remote_file_name = ntpath.basename(asunicode_win(remote_path)) + self.channel.settimeout(self.socket_timeout) + self.channel.exec_command(self.scp_command + b" -f " + remote_path) + bytes_io = BytesIO() + self._recv_all_fo(bytes_io) + self.close() + return bytes_io + + def get_data(self, remote_path, decode_utf8=False): + # type: (PathTypes, bool) -> Union[bytes, str] + """ + Transfer a file from remote host to localhost, return the data of the remote file. + + @param remote_path: path to retrieve from remote host. Note that + wildcards will be escaped unless you changed the `sanitize` + function. + @type remote_path: str + @param decode_utf8: should decode result as utf-8 + @type decode_utf8: bool + @returns: data with the remote file's data + """ + data = self.getfo(remote_path).getvalue() + if decode_utf8: + return data.decode("utf-8") + return data + def _open(self): """open a scp channel""" if self.channel is None or self.channel.closed: @@ -553,6 +595,64 @@ def _recv_popd(self, *cmd): self._depth -= 1 self._recv_dir = os.path.split(self._recv_dir)[0] + def _recv_all_fo(self, fh): + # type: (BytesIO) -> None + # loop over scp commands, and receive as necessary + commands = (b'C', ) + while not self.channel.closed: + # wait for command as long as we're open + self.channel.sendall('\x00') + msg = self.channel.recv(1024) + if not msg: # chan closed while receiving + break + assert msg[-1:] == b'\n' + msg = msg[:-1] + code = msg[0:1] + if code not in commands: + raise SCPException(asunicode(msg[1:])) + self._recv_file_fo(msg[1:], fh) + + def _recv_file_fo(self, cmd, fh): + # type: (bytes, BytesIO) -> None + chan = self.channel + parts = cmd.strip().split(b' ', 2) + + try: + mode = int(parts[0], 8) + size = int(parts[1]) + except: + chan.send('\x01') + chan.close() + raise SCPException('Bad file format') + + if self._progress: + if size == 0: + # avoid divide-by-zero + self._progress(self.remote_file_name, 1, 1, self.peername) + else: + self._progress(self.remote_file_name, size, 0, self.peername) + buff_size = self.buff_size + pos = 0 + chan.send(b'\x00') + try: + while pos < size: + # we have to make sure we don't read the final byte + if size - pos <= buff_size: + buff_size = size - pos + data = chan.recv(buff_size) + if not data: + raise SCPException("Underlying channel was closed") + fh.write(data) + pos = fh.tell() + if self._progress: + self._progress(self.remote_file_name, size, pos, self.peername) + msg = chan.recv(512) + if msg and msg[0:1] != b'\x00': + raise SCPException(asunicode(msg[1:])) + except SocketTimeout: + chan.close() + raise SCPException('Error receiving, socket.timeout') + def _set_dirtimes(self): try: for d in self._dirtimes: diff --git a/test.py b/test.py index 09ea6b6..ced68aa 100644 --- a/test.py +++ b/test.py @@ -131,6 +131,22 @@ def listdir(path, fpath): os.chdir(previous) shutil.rmtree(temp) + def download_test_fo(self, filename, result=b''): + cb3 = lambda filename, size, sent: None + with SCPClient(self.ssh.get_transport(), progress=cb3) as scp: + bytes_io = scp.getfo(filename) + value = bytes_io.getvalue() + + self.assertEqual(value, result) + + def download_test_data(self, filename, decode_utf8=False, result=b''): + cb3 = lambda filename, size, sent: None + with SCPClient(self.ssh.get_transport(), progress=cb3) as scp: + value = scp.get_data(filename, decode_utf8=decode_utf8) + if not decode_utf8: + self.assertEqual(value, result) + self.assertEqual(value, result.decode('utf-8')) + def test_get_bytes(self): self.download_test(b'/tmp/r\xC3\xA9mi', False, b'target', [u'target'], [b'target']) @@ -167,6 +183,11 @@ def test_get_folder(self): [b'target', b'target/file', b'target/b\xC3\xA8te']) + def test_get_fo(self): + self.download_test_fo(b'/tmp/r\xC3\xA9mi', result=b'') + self.download_test_data(b'/tmp/r\xC3\xA9mi', decode_utf8=False, result=b'') + self.download_test_data(b'/tmp/r\xC3\xA9mi', decode_utf8=True, result=b'') + def test_get_invalid_unicode(self): self.download_test(b'/tmp/p\xE9t\xE9', False, u'target', [u'target'], [b'target']) From ca428f9bd6f60e275572af97bd5b201b4e48313a Mon Sep 17 00:00:00 2001 From: Yehuda Anikster Date: Tue, 24 Jan 2023 16:26:02 +0200 Subject: [PATCH 2/3] Removed unused import --- scp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scp.py b/scp.py index cecc7fc..a9dfcbb 100644 --- a/scp.py +++ b/scp.py @@ -31,7 +31,7 @@ pass try: - from typing import IO, TYPE_CHECKING, AnyStr, Callable, Iterable, Optional, Tuple, Union, Dict + from typing import IO, TYPE_CHECKING, AnyStr, Callable, Iterable, Optional, Tuple, Union if TYPE_CHECKING: import paramiko.transport From d38b98ca611071ff121c09a3040a73ebf141e554 Mon Sep 17 00:00:00 2001 From: Yehuda Anikster Date: Sun, 21 May 2023 16:08:56 +0300 Subject: [PATCH 3/3] Added optional BytesIO argument to getfo and get_data --- scp.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/scp.py b/scp.py index a9dfcbb..e644052 100644 --- a/scp.py +++ b/scp.py @@ -284,8 +284,8 @@ def get(self, remote_path, local_path='', self._recv_all() self.close() - def getfo(self, remote_path): - # type: (PathTypes) -> BytesIO + def getfo(self, remote_path, bytes_io=None): + # type: (PathTypes, Optional[BytesIO]) -> BytesIO """ Transfer a file from remote host to localhost to a file-like object. @@ -293,6 +293,8 @@ def getfo(self, remote_path): wildcards will be escaped unless you changed the `sanitize` function. @type remote_path: str + @param bytes_io: Optional - a BytesIO object to write into, instead of creating one and writing into it. + @type bytes_io: BytesIO @returns: file-like object with the file's data """ remote_path = self.sanitize(asbytes(remote_path)) @@ -300,13 +302,14 @@ def getfo(self, remote_path): self.remote_file_name = ntpath.basename(asunicode_win(remote_path)) self.channel.settimeout(self.socket_timeout) self.channel.exec_command(self.scp_command + b" -f " + remote_path) - bytes_io = BytesIO() + if bytes_io is None: + bytes_io = BytesIO() self._recv_all_fo(bytes_io) self.close() return bytes_io - def get_data(self, remote_path, decode_utf8=False): - # type: (PathTypes, bool) -> Union[bytes, str] + def get_data(self, remote_path, decode_utf8=False, bytes_io=None): + # type: (PathTypes, bool, Optional[BytesIO]) -> Union[bytes, str] """ Transfer a file from remote host to localhost, return the data of the remote file. @@ -316,9 +319,11 @@ def get_data(self, remote_path, decode_utf8=False): @type remote_path: str @param decode_utf8: should decode result as utf-8 @type decode_utf8: bool + @param bytes_io: Optional - a BytesIO object to write into, instead of creating one and writing into it. + @type bytes_io: BytesIO @returns: data with the remote file's data """ - data = self.getfo(remote_path).getvalue() + data = self.getfo(remote_path, bytes_io=bytes_io).getvalue() if decode_utf8: return data.decode("utf-8") return data