From 81ad9723e21825985724602e6a880cf2b511367f Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 17 Feb 2025 22:22:17 -0600 Subject: [PATCH 01/13] improve logging display at connection --- datajoint/connection.py | 13 +++++++++---- datajoint/schemas.py | 4 ++-- datajoint/settings.py | 17 ++++++++++++----- datajoint/version.py | 2 +- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 7536e7af2..26ccb540b 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -12,7 +12,7 @@ import pathlib from .settings import config -from . import errors +from . import errors, __version__ from .dependencies import Dependencies from .blob import pack, unpack from .hash import uuid_from_buffer @@ -190,15 +190,20 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None) self.conn_info["ssl_input"] = use_tls self.conn_info["host_input"] = host_input self.init_fun = init_fun - logger.info("Connecting {user}@{host}:{port}".format(**self.conn_info)) self._conn = None self._query_cache = None connect_host_hook(self) if self.is_connected: - logger.info("Connected {user}@{host}:{port}".format(**self.conn_info)) + logger.info( + "DataJoint {version} connected {user}@{host}:{port}".format( + version=__version__, **self.conn_info + ) + ) self.connection_id = self.query("SELECT connection_id()").fetchone()[0] else: - raise errors.LostConnectionError("Connection failed.") + raise errors.LostConnectionError( + "Connection failed {user}@{host}:{port}".format(**self.conn_info) + ) self._in_transaction = False self.schemas = dict() self.dependencies = Dependencies(self) diff --git a/datajoint/schemas.py b/datajoint/schemas.py index c3894ba29..7ea40724f 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -482,8 +482,8 @@ def list_tables(self): return [ t for d, t in ( - full_t.replace("`", "").split(".") - for full_t in self.connection.dependencies.topo_sort() + table_name.replace("`", "").split(".") + for table_name in self.connection.dependencies.topo_sort() ) if d == self.database ] diff --git a/datajoint/settings.py b/datajoint/settings.py index cdf27891d..0b7bcad90 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -1,5 +1,5 @@ """ -Settings for DataJoint. +Settings for DataJoint """ from contextlib import contextmanager @@ -48,7 +48,8 @@ "database.use_tls": None, "enable_python_native_blobs": True, # python-native/dj0 encoding support "add_hidden_timestamp": False, - "filepath_checksum_size_limit": None, # file size limit for when to disable checksums + # file size limit for when to disable checksums + "filepath_checksum_size_limit": None, } ) @@ -117,6 +118,7 @@ def load(self, filename): if filename is None: filename = LOCALCONFIG with open(filename, "r") as fid: + logger.info(f"Reading dj.config from {filename}") self._conf.update(json.load(fid)) def save_local(self, verbose=False): @@ -236,7 +238,8 @@ class __Config: def __init__(self, *args, **kwargs): self._conf = dict(default) - self._conf.update(dict(*args, **kwargs)) # use the free update to set keys + # use the free update to set keys + self._conf.update(dict(*args, **kwargs)) def __getitem__(self, key): return self._conf[key] @@ -250,7 +253,9 @@ def __setitem__(self, key, value): valid_logging_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} if key == "loglevel": if value not in valid_logging_levels: - raise ValueError(f"{'value'} is not a valid logging value") + raise ValueError( + f"{'value'} is not a valid logging value {tuple(valid_logging_levels)}" + ) logger.setLevel(value) @@ -292,6 +297,8 @@ def __setitem__(self, key, value): ) if v is not None } -config.update(mapping) +if mapping: + logger.info(f"Loaded settings {tuple(mapping)} from environment variables.") + config.update(mapping) logger.setLevel(log_levels[config["loglevel"]]) diff --git a/datajoint/version.py b/datajoint/version.py index 6bcf0e20a..cc1d88710 100644 --- a/datajoint/version.py +++ b/datajoint/version.py @@ -1,3 +1,3 @@ -__version__ = "0.14.3" +__version__ = "0.14.4" assert len(__version__) <= 10 # The log table limits version to the 10 characters From 840585359e3bdf7b0da55fc88cb42eb7d01b76f8 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Mon, 17 Feb 2025 22:52:12 -0600 Subject: [PATCH 02/13] improve connection log display --- datajoint/connection.py | 2 +- datajoint/settings.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/datajoint/connection.py b/datajoint/connection.py index 26ccb540b..ba51d1679 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -195,7 +195,7 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None) connect_host_hook(self) if self.is_connected: logger.info( - "DataJoint {version} connected {user}@{host}:{port}".format( + "DataJoint {version} connected to {user}@{host}:{port}".format( version=__version__, **self.conn_info ) ) diff --git a/datajoint/settings.py b/datajoint/settings.py index 0b7bcad90..f1c300029 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -118,7 +118,7 @@ def load(self, filename): if filename is None: filename = LOCALCONFIG with open(filename, "r") as fid: - logger.info(f"Reading dj.config from {filename}") + logger.info(f"DataJoint is configured from {os.path.abspath(filename)}") self._conf.update(json.load(fid)) def save_local(self, verbose=False): @@ -254,7 +254,7 @@ def __setitem__(self, key, value): if key == "loglevel": if value not in valid_logging_levels: raise ValueError( - f"{'value'} is not a valid logging value {tuple(valid_logging_levels)}" + f"'{value}' is not a valid logging value {tuple(valid_logging_levels)}" ) logger.setLevel(value) @@ -265,11 +265,9 @@ def __setitem__(self, key, value): os.path.expanduser(n) for n in (LOCALCONFIG, os.path.join("~", GLOBALCONFIG)) ) try: - config_file = next(n for n in config_files if os.path.exists(n)) + config.load(next(n for n in config_files if os.path.exists(n))) except StopIteration: - pass -else: - config.load(config_file) + logger.info("No config file was found.") # override login credentials with environment variables mapping = { @@ -298,7 +296,7 @@ def __setitem__(self, key, value): if v is not None } if mapping: - logger.info(f"Loaded settings {tuple(mapping)} from environment variables.") + logger.info(f"Overloaded settings {tuple(mapping)} from environment variables.") config.update(mapping) logger.setLevel(log_levels[config["loglevel"]]) From 7846f9f1b6126a05077244ffaa648d4b01f6d4ea Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 18 Feb 2025 05:10:28 -0600 Subject: [PATCH 03/13] implement tri-partite make (fix #1170) --- datajoint/autopopulate.py | 36 +++++++++++++++++++++++++++++++++--- pyproject.toml | 1 + 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 0e16ee29b..a2516dda5 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -12,6 +12,7 @@ import signal import multiprocessing as mp import contextlib +import deepdiff # noinspection PyExceptionInherit,PyCallingNonCallable @@ -309,17 +310,46 @@ def _populate1( ): return False - self.connection.start_transaction() + # if make is a generator, it transaction can be delayed until the final stage + is_generator = inspect.isgeneratorfunction(make) + if not is_generator: + self.connection.start_transaction() + if key in self.target: # already populated - self.connection.cancel_transaction() + if not is_generator: + self.connection.cancel_transaction() if jobs is not None: jobs.complete(self.target.table_name, self._job_key(key)) return False logger.debug(f"Making {key} -> {self.target.full_table_name}") self.__class__._allow_insert = True + try: - make(dict(key), **(make_kwargs or {})) + if not is_generator: + make(dict(key), **(make_kwargs or {})) + else: + # tripartite make - transaction is delayed until the final stage + gen = make(dict(key), **(make_kwargs or {})) + fetched_data = next(gen) + fetch_hash = deepdiff.DeepHash( + fetched_data, ignore_iterable_order=False + )[fetched_data] + computed_result = next(gen) # perform the computation + gen = make(dict(key), **(make_kwargs or {})) # restart make + # fetch and insert inside a transaction + self.connnection.start_transaction() + fetched_data = next(gen) + if ( + fetch_hash + != deepdiff.DeepHash(fetched_data, ignore_iterable_order=False)[ + fetched_data + ] + ): # rollback due to referential integrity fail + self.connection.cancel_transaction() + return False + gen.send(computed_result) # insert + except (KeyboardInterrupt, SystemExit, Exception) as error: try: self.connection.cancel_transaction() diff --git a/pyproject.toml b/pyproject.toml index 097d168e1..1eb8c723d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ version = "0.14.3" dependencies = [ "numpy", "pymysql>=0.7.2", + "deepdiff", "pyparsing", "ipython", "pandas", From bff849982b266785894dceed224257e1fc7f17c0 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Tue, 18 Feb 2025 05:34:15 -0600 Subject: [PATCH 04/13] typo --- datajoint/autopopulate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index a2516dda5..21fa4a8db 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -338,7 +338,7 @@ def _populate1( computed_result = next(gen) # perform the computation gen = make(dict(key), **(make_kwargs or {})) # restart make # fetch and insert inside a transaction - self.connnection.start_transaction() + self.connection.start_transaction() fetched_data = next(gen) if ( fetch_hash From 6d6246b4b6cd29467b15dcbc2f59f5cef91bf8d1 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Thu, 20 Feb 2025 10:56:21 -0600 Subject: [PATCH 05/13] fix transaction timing in generative make. --- datajoint/autopopulate.py | 2 +- datajoint/connection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 21fa4a8db..6d72b7aa7 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -336,9 +336,9 @@ def _populate1( fetched_data, ignore_iterable_order=False )[fetched_data] computed_result = next(gen) # perform the computation - gen = make(dict(key), **(make_kwargs or {})) # restart make # fetch and insert inside a transaction self.connection.start_transaction() + gen = make(dict(key), **(make_kwargs or {})) # restart make fetched_data = next(gen) if ( fetch_hash diff --git a/datajoint/connection.py b/datajoint/connection.py index ba51d1679..5d2fbc27e 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -349,7 +349,7 @@ def query( except errors.LostConnectionError: if not reconnect: raise - logger.warning("MySQL server has gone away. Reconnecting to the server.") + logger.warning("Reconnecting to MySQL server.") connect_host_hook(self) if self._in_transaction: self.cancel_transaction() From 3199117c79368e041525f5a30da8b58ef1a9dfbe Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 21 Feb 2025 17:43:00 -0600 Subject: [PATCH 06/13] replace collections.abc.ByteString with collections.abc.Buffer --- datajoint/blob.py | 2 +- datajoint/declare.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index 891522fd2..43cf447fc 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -204,7 +204,7 @@ def pack_blob(self, obj): return self.pack_dict(obj) if isinstance(obj, str): return self.pack_string(obj) - if isinstance(obj, collections.abc.ByteString): + if isinstance(obj, collections.abc.Buffer): return self.pack_bytes(obj) if isinstance(obj, collections.abc.MutableSequence): return self.pack_list(obj) diff --git a/datajoint/declare.py b/datajoint/declare.py index b1194880f..d65bab5a3 100644 --- a/datajoint/declare.py +++ b/datajoint/declare.py @@ -398,7 +398,9 @@ def _make_attribute_alter(new, old, primary_key): command=( "ADD" if (old_name or new_name) not in old_names - else "MODIFY" if not old_name else "CHANGE `%s`" % old_name + else "MODIFY" + if not old_name + else "CHANGE `%s`" % old_name ), new_def=new_def, after="" if after is None else "AFTER `%s`" % after, From bd1999ed4cd9e558fc398e5b91c01712ff797dbe Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 21 Feb 2025 17:59:22 -0600 Subject: [PATCH 07/13] fix #1201 --- datajoint/external.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datajoint/external.py b/datajoint/external.py index a3a546e22..57b75d46d 100644 --- a/datajoint/external.py +++ b/datajoint/external.py @@ -278,7 +278,7 @@ def upload_filepath(self, local_filepath): # check if the remote file already exists and verify that it matches check_hash = (self & {"hash": uuid}).fetch("contents_hash") - if check_hash: + if check_hash.size: # the tracking entry exists, check that it's the same file as before if contents_hash != check_hash[0]: raise DataJointError( From 5c5a6e068aadc4d104035132699f52954e59075c Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 21 Feb 2025 18:13:15 -0600 Subject: [PATCH 08/13] replace collections.abc.ByteString with (bytes, bytearray). --- datajoint/blob.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index 43cf447fc..a7a211210 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -8,6 +8,7 @@ import collections from decimal import Decimal import datetime +import typing as tp import uuid import numpy as np from .errors import DataJointError @@ -204,7 +205,7 @@ def pack_blob(self, obj): return self.pack_dict(obj) if isinstance(obj, str): return self.pack_string(obj) - if isinstance(obj, collections.abc.Buffer): + if isinstance(obj, (bytes, bytearray)): return self.pack_bytes(obj) if isinstance(obj, collections.abc.MutableSequence): return self.pack_list(obj) From b8a3a8ed7ad0fc96f769150d211c198dd85c73e4 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 21 Feb 2025 18:16:45 -0600 Subject: [PATCH 09/13] minor --- datajoint/blob.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index a7a211210..f38525477 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -8,7 +8,6 @@ import collections from decimal import Decimal import datetime -import typing as tp import uuid import numpy as np from .errors import DataJointError From 3d1f7cbf9ffa4fced86819baab6f1d1747f029c5 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 21 Feb 2025 18:32:58 -0600 Subject: [PATCH 10/13] formatting --- datajoint/blob.py | 8 ++++---- datajoint/external.py | 4 ++-- datajoint/preview.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datajoint/blob.py b/datajoint/blob.py index f38525477..6738ebc08 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -113,14 +113,14 @@ def unpack(self, blob): try: # decompress prefix = next( - p for p in compression if self._blob[self._pos :].startswith(p) + p for p in compression if self._blob[self._pos:].startswith(p) ) except StopIteration: pass # assume uncompressed but could be unrecognized compression else: self._pos += len(prefix) blob_size = self.read_value() - blob = compression[prefix](self._blob[self._pos :]) + blob = compression[prefix](self._blob[self._pos:]) assert len(blob) == blob_size self._blob = blob self._pos = 0 @@ -558,7 +558,7 @@ def pack_uuid(obj): def read_zero_terminated_string(self): target = self._blob.find(b"\0", self._pos) - data = self._blob[self._pos : target].decode() + data = self._blob[self._pos:target].decode() self._pos = target + 1 return data @@ -571,7 +571,7 @@ def read_value(self, dtype=None, count=1): def read_binary(self, size): self._pos += int(size) - return self._blob[self._pos - int(size) : self._pos] + return self._blob[self._pos - int(size):self._pos] def pack(self, obj, compress): self.protocol = b"mYm\0" # will be replaced with dj0 if new features are used diff --git a/datajoint/external.py b/datajoint/external.py index 57b75d46d..faac5fb09 100644 --- a/datajoint/external.py +++ b/datajoint/external.py @@ -22,10 +22,10 @@ def subfold(name, folds): """ - subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde'] + subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde'] """ return ( - (name[: folds[0]].lower(),) + subfold(name[folds[0] :], folds[1:]) + (name[:folds[0]].lower(),) + subfold(name[folds[0]:], folds[1:]) if folds else () ) diff --git a/datajoint/preview.py b/datajoint/preview.py index 775570432..472eddc38 100644 --- a/datajoint/preview.py +++ b/datajoint/preview.py @@ -52,7 +52,7 @@ def repr_html(query_expression): info = heading.table_status tuples = rel.fetch(limit=config["display.limit"] + 1, format="array") has_more = len(tuples) > config["display.limit"] - tuples = tuples[0 : config["display.limit"]] + tuples = tuples[0:config["display.limit"]] css = """