diff --git a/scp.py b/scp.py index c4f2499..995c606 100644 --- a/scp.py +++ b/scp.py @@ -120,7 +120,7 @@ def __exit__(self, type, value, traceback): self.close() def put(self, files, remote_path=b'.', - recursive=False, preserve_times=False): + recursive=False, preserve_times=False, limit_bandwidth=0): """ Transfer files to remote host. @@ -135,12 +135,17 @@ def put(self, files, remote_path=b'.', @param preserve_times: preserve mtime and atime of transfered files and directories. @type preserve_times: bool + @param limit_bandwidth: Limits the used bandwidth, specified in Kbit/s. + @type limit_bandwidth: int """ self.preserve_times = preserve_times self.channel = self._open() self._pushed = 0 self.channel.settimeout(self.socket_timeout) - scp_command = (b'scp -t ', b'scp -r -t ')[recursive] + scp_command = b'scp -t ' + scp_command += (b'', b' -r ')[recursive] + if limit_bandwidth > 0: + scp_command += b' -l ' + str(limit_bandwidth).encode() + b' ' self.channel.exec_command(scp_command + self.sanitize(asbytes(remote_path))) self._recv_confirm() @@ -156,7 +161,7 @@ def put(self, files, remote_path=b'.', self.close() def get(self, remote_path, local_path='', - recursive=False, preserve_times=False): + recursive=False, preserve_times=False, limit_bandwidth=0): """ Transfer files from remote host to localhost @@ -171,6 +176,8 @@ def get(self, remote_path, local_path='', @param preserve_times: preserve mtime and atime of transfered files and directories. @type preserve_times: bool + @param limit_bandwidth: Limits the used bandwidth, specified in Kbit/s. + @type limit_bandwidth: int """ if not isinstance(remote_path, (list, tuple)): remote_path = [remote_path] @@ -187,12 +194,14 @@ def get(self, remote_path, local_path='', asunicode(self._recv_dir)) rcsv = (b'', b' -r')[recursive] prsv = (b'', b' -p')[preserve_times] + lmbw = (b'', b' -l '+ str(limit_bandwidth).encode())[limit_bandwidth > 0] self.channel = self._open() self._pushed = 0 self.channel.settimeout(self.socket_timeout) self.channel.exec_command(b"scp" + rcsv + prsv + + lmbw + b" -f " + b' '.join(remote_path)) self._recv_all() diff --git a/test.py b/test.py index a9787f5..10e0720 100644 --- a/test.py +++ b/test.py @@ -299,6 +299,24 @@ def test_up_and_down(self): finally: os.chdir(previous) + def test_up_and_down_with_limit(self): + '''send and receive files with the same client''' + previous = os.getcwd() + testfile = os.path.join(self._temp, 'testfile') + testfile_sent = os.path.join(self._temp, 'testfile_sent') + testfile_rcvd = os.path.join(self._temp, 'testfile_rcvd') + try: + os.chdir(self._temp) + with open(testfile, 'w') as f: + f.write("TESTING\n") + with SCPClient(self.ssh.get_transport(), socket_timeout=60.0) as scp: + scp.put(testfile, testfile_sent, limit_bandwidth=10) + scp.get(testfile_sent, testfile_rcvd, limit_bandwidth=10) + + assert open(testfile_rcvd).read() == 'TESTING\n' + finally: + os.chdir(previous) + if __name__ == '__main__': unittest.main()