Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions testgres/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def spawn_primary(self, name=None, destroy=True):

return node

def spawn_replica(self, name=None, destroy=True):
def spawn_replica(self, name=None, destroy=True, slot_name=None):
"""
Create a replica of the original node from a backup.

Expand All @@ -171,7 +171,7 @@ def spawn_replica(self, name=None, destroy=True):

# Assign it a master and a recovery file (private magic)
node._assign_master(self.original_node)
node._create_recovery_conf(username=self.username)
node._create_recovery_conf(username=self.username, slot_name=slot_name)

return node

Expand Down
3 changes: 3 additions & 0 deletions testgres/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@
PG_LOG_FILE = "postgresql.log"
UTILS_LOG_FILE = "utils.log"
BACKUP_LOG_FILE = "backup.log"

# default replication slots number
REPLICATION_SLOTS = 10
34 changes: 30 additions & 4 deletions testgres/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
RECOVERY_CONF_FILE, \
PG_LOG_FILE, \
UTILS_LOG_FILE, \
PG_PID_FILE
PG_PID_FILE, \
REPLICATION_SLOTS

from .decorators import \
method_decorator, \
Expand Down Expand Up @@ -179,7 +180,7 @@ def _assign_master(self, master):
# now this node has a master
self._master = master

def _create_recovery_conf(self, username):
def _create_recovery_conf(self, username, slot_name=None):
"""NOTE: this is a private method!"""

# fetch master of this node
Expand Down Expand Up @@ -207,6 +208,9 @@ def _create_recovery_conf(self, username):
"standby_mode=on\n"
).format(conninfo)

if slot_name:
line += "primary_slot_name={}\n".format(slot_name)

self.append_conf(RECOVERY_CONF_FILE, line)

def _maybe_start_logger(self):
Expand Down Expand Up @@ -250,6 +254,21 @@ def _collect_special_files(self):

return result

def _create_replication_slot(self, slot_name, dbname=None, username=None):
"""
Create a physical replication slot.

Args:
slot_name: slot name
dbname: database name
username: database user name
"""
query = (
"select pg_create_physical_replication_slot('{}')"
).format(slot_name)

self.execute(query=query, dbname=dbname, username=username)

def init(self, initdb_params=None, **kwargs):
"""
Perform initdb for this node.
Expand Down Expand Up @@ -360,8 +379,10 @@ def get_auth_method(t):
wal_keep_segments = 20 # for convenience
conf.write(u"hot_standby = on\n"
u"max_wal_senders = {}\n"
u"max_replication_slots = {}\n"
u"wal_keep_segments = {}\n"
u"wal_level = {}\n".format(max_wal_senders,
REPLICATION_SLOTS,
wal_keep_segments,
wal_level))

Expand Down Expand Up @@ -856,7 +877,7 @@ def backup(self, **kwargs):

return NodeBackup(node=self, **kwargs)

def replicate(self, name=None, **kwargs):
def replicate(self, name=None, slot_name=None, **kwargs):
"""
Create a binary replica of this node.

Expand All @@ -867,10 +888,15 @@ def replicate(self, name=None, **kwargs):
base_dir: the base directory for data files and logs
"""

if slot_name:
self._create_replication_slot(slot_name, **kwargs)

backup = self.backup(**kwargs)

# transform backup into a replica
return backup.spawn_replica(name=name, destroy=True)
return backup.spawn_replica(name=name,
destroy=True,
slot_name=slot_name)

def catchup(self, dbname=None, username=None):
"""
Expand Down
11 changes: 11 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,17 @@ def test_replicate(self):
res = node.execute('select * from test')
self.assertListEqual(res, [])

def test_replication_slots(self):
query_create = 'create table test as select generate_series(1, 2) as val'

with get_new_node() as node:
node.init(allow_streaming=True).start()
node.execute(query_create)

with node.replicate(slot_name='slot1').start() as replica:
res = replica.execute('select * from test')
self.assertListEqual(res, [(1, ), (2, )])

def test_incorrect_catchup(self):
with get_new_node() as node:
node.init(allow_streaming=True).start()
Expand Down