diff --git a/.github/workflows/Lint-and-test.yml b/.github/workflows/Lint-and-test.yml new file mode 100644 index 0000000..f128a43 --- /dev/null +++ b/.github/workflows/Lint-and-test.yml @@ -0,0 +1,42 @@ +name: Lint-and-test +on: [pull_request, workflow_call] +jobs: + call-linter-workflow: + uses: ISISComputingGroup/reusable-workflows/.github/workflows/linters.yml@main + with: + compare-branch: origin/master + python-ver: '3.11' + tests: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ "ubuntu-latest" ] + # Wide matrix of versions as this may run on a RHEL node with old python versions, + # but we also want it to work on dev machines + version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] + include: + - os: "windows-latest" + version: '3.11' + fail-fast: false + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.version }} + - name: install requirements + run: pip install -e .[dev] + - name: run pytest (linux) + run: python -m pytest + results: + if: ${{ always() }} + runs-on: ubuntu-latest + name: Final Results + needs: [call-linter-workflow, tests] + steps: + - run: exit 1 + # see https://stackoverflow.com/a/67532120/4907315 + if: >- + ${{ + contains(needs.*.result, 'failure') + || contains(needs.*.result, 'cancelled') + }} diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml deleted file mode 100644 index 82ad245..0000000 --- a/.github/workflows/linters.yml +++ /dev/null @@ -1,7 +0,0 @@ -name: Linter -on: [pull_request] -jobs: - call-workflow: - uses: ISISComputingGroup/reusable-workflows/.github/workflows/linters.yml@main - with: - compare-branch: origin/master diff --git a/.gitignore b/.gitignore index 808214e..ebc0efd 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,8 @@ relPaths.sh *.kdb *.kdbx /exp_db_populator_venv +/.venv +/.coverage +*.egg-info +logs +/build diff --git a/Jenkinsfile b/Jenkinsfile deleted file mode 100644 index e0be109..0000000 --- a/Jenkinsfile +++ /dev/null @@ -1,74 +0,0 @@ -#!groovy - -pipeline { - - // agent defines where the pipeline will run. - agent { - label { - label "Python_3_tests" - // Use custom workspace to avoid issue with long filepaths on Win32 - customWorkspace "C:/Exp_DB_Populator/${env.BRANCH_NAME}" - } - } - - triggers { - pollSCM('H/2 * * * *') - } - - stages { - stage("Checkout") { - steps { - echo "Branch: ${env.BRANCH_NAME}" - checkout scm - } - } - - stage("Build") { - steps { - echo "Build Number: ${env.BUILD_NUMBER}" - script { - env.GIT_COMMIT = bat(returnStdout: true, script: '@git rev-parse HEAD').trim() - env.GIT_BRANCH = bat(returnStdout: true, script: '@git rev-parse --abbrev-ref HEAD').trim() - echo "git commit: ${env.GIT_COMMIT}" - echo "git branch: ${env.BRANCH_NAME} ${env.GIT_BRANCH}" - if (env.BRANCH_NAME != null && env.BRANCH_NAME.startsWith("Release")) { - env.IS_RELEASE = "YES" - env.RELEASE_VERSION = "${env.BRANCH_NAME}".replace('Release_', '') - echo "release version: ${env.RELEASE_VERSION}" - } - else { - env.IS_RELEASE = "NO" - env.RELEASE_VERSION = "" - } - } - - bat """ - C:/Instrument/Apps/Python3/python.exe -m pip install virtualenv - C:/Instrument/Apps/Python3/Scripts/pip.exe install virtualenv - C:/Instrument/Apps/Python3/Scripts/virtualenv.exe my_python - call my_python\\Scripts\\activate.bat - call my_python\\Scripts\\pip.exe install -r requirements.txt - python.exe run_tests.py --output_dir test-reports - """ - } - } - stage("Unit Test Results") { - steps { - junit "test-reports/**/*.xml" - } - } - } - - post { - failure { - step([$class: 'Mailer', notifyEveryUnstableBuild: true, recipients: 'icp-buildserver@lists.isis.rl.ac.uk', sendToIndividuals: true]) - } - } - - // The options directive is for configuration that applies to the whole job. - options { - buildDiscarder(logRotator(numToKeepStr:'5', daysToKeepStr: '7')) - timeout(time: 60, unit: 'MINUTES') - disableConcurrentBuilds() - } -} diff --git a/create_rb_number_populator_python_venv.sh b/create_rb_number_populator_python_venv.sh index 06cd718..06bd5de 100644 --- a/create_rb_number_populator_python_venv.sh +++ b/create_rb_number_populator_python_venv.sh @@ -3,6 +3,6 @@ venv="exp_db_populator_venv" # Name of the virtual environment /usr/local/bin/python3.8 -m venv /home/epics/RB_num_populator/$venv # create virtual environment source $venv/bin/activate # activate the virtual environment -$venv/bin/pip install -r requirements.txt # Install requirements.txt +$venv/bin/pip install -e . # Install requirements.txt deactivate # deactivate the virtual environment echo "Virtual environment created" diff --git a/main.py b/exp_db_populator/cli.py similarity index 77% rename from main.py rename to exp_db_populator/cli.py index dd04b76..00375fa 100644 --- a/main.py +++ b/exp_db_populator/cli.py @@ -3,6 +3,12 @@ from datetime import datetime from logging.handlers import TimedRotatingFileHandler +from exp_db_populator.data_types import InstList, RawDataEntry +from exp_db_populator.webservices_test_data import ( + TEST_USER_1, + create_web_data_with_experimenters_and_other_date, +) + # Loging must be handled here as some imports might log errors log_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "logs") if not os.path.exists(log_folder): @@ -21,31 +27,27 @@ import json import threading import zlib +from typing import Any import epics -from six.moves import input from exp_db_populator.gatherer import Gatherer from exp_db_populator.populator import update from exp_db_populator.webservices_reader import reformat_data -from tests.webservices_test_data import ( - TEST_USER_1, - create_web_data_with_experimenters_and_other_date, -) # PV that contains the instrument list INST_LIST_PV = "CS:INSTLIST" -def convert_inst_list(value_from_PV): +def convert_inst_list(value_from_pv: str) -> InstList: """ Converts the instrument list coming from the PV into a dictionary. Args: - value_from_PV: The raw value from the PV. + value_from_pv: The raw value from the PV. Returns: dict: The instrument information. """ - json_string = zlib.decompress(bytes.fromhex(value_from_PV)).decode("utf-8") + json_string = zlib.decompress(bytes.fromhex(value_from_pv)).decode("utf-8") return json.loads(json_string) @@ -55,30 +57,33 @@ class InstrumentPopulatorRunner: """ gatherer = None - prev_inst_list = None + prev_inst_list: InstList | None = None db_lock = threading.RLock() - def __init__(self, run_continuous=False): + def __init__(self, run_continuous: bool = False) -> None: self.run_continuous = run_continuous - def start_inst_list_monitor(self): + def start_inst_list_monitor(self) -> None: logging.info("Setting up monitors on {}".format(INST_LIST_PV)) self.inst_list_callback(char_value=epics.caget(INST_LIST_PV, as_string=True)) epics.camonitor(INST_LIST_PV, callback=self.inst_list_callback) - def inst_list_callback(self, char_value, **kw): + def inst_list_callback(self, char_value: str | None, **_kw: dict[str, Any]) -> None: """ Called when the instrument list PV changes value. Args: char_value: The string representation of the PV data. - **kw: The module will also send other info about the PV, we capture this and don't use it. + **kw: The module will also send other info about the PV, we capture this and don't + use it. """ + if char_value is None: + return new_inst_list = convert_inst_list(char_value) if new_inst_list != self.prev_inst_list: self.prev_inst_list = new_inst_list self.inst_list_changes(new_inst_list) - def remove_gatherer(self): + def remove_gatherer(self) -> None: """ Stops the gatherer and clears the cache. """ @@ -88,7 +93,7 @@ def remove_gatherer(self): self.wait_for_gatherer_to_finish() self.gatherer = None - def inst_list_changes(self, inst_list): + def inst_list_changes(self, inst_list: InstList) -> None: """ Starts a new gatherer thread. Args: @@ -102,14 +107,15 @@ def inst_list_changes(self, inst_list): new_gatherer.start() self.gatherer = new_gatherer - def wait_for_gatherer_to_finish(self): + def wait_for_gatherer_to_finish(self) -> None: """ Blocks until gatherer is finished. """ - self.gatherer.join() + if self.gatherer is not None: + self.gatherer.join() -if __name__ == "__main__": +def main_cli() -> None: parser = argparse.ArgumentParser() parser.add_argument( "--cont", @@ -135,13 +141,17 @@ def wait_for_gatherer_to_finish(self): main = InstrumentPopulatorRunner(args.cont) if args.as_instrument: - debug_inst_list = [ + debug_inst_list: InstList = [ {"name": args.as_instrument, "hostName": "localhost", "isScheduled": True} ] main.prev_inst_list = debug_inst_list main.inst_list_changes(debug_inst_list) elif args.test_data: - data = [create_web_data_with_experimenters_and_other_date([TEST_USER_1], datetime.now())] + data: list[RawDataEntry] = [ + create_web_data_with_experimenters_and_other_date([TEST_USER_1], datetime.now()) + ] + if not args.db_user or not args.db_pass: + raise ValueError("Must specify a username and password if using test data") update( "localhost", "localhost", @@ -164,8 +174,15 @@ def wait_for_gatherer_to_finish(self): main.remove_gatherer() running = False elif menu_input == "U": - main.inst_list_changes(main.prev_inst_list) + if main.prev_inst_list is None: + logging.warning("No previous instrument list") + else: + main.inst_list_changes(main.prev_inst_list) else: logging.warning("Command not recognised: {}".format(menu_input)) else: main.wait_for_gatherer_to_finish() + + +if __name__ == "__main__": + main_cli() diff --git a/exp_db_populator/data_types.py b/exp_db_populator/data_types.py index 9d536e6..b293e04 100644 --- a/exp_db_populator/data_types.py +++ b/exp_db_populator/data_types.py @@ -1,32 +1,42 @@ +from datetime import datetime +from typing import NotRequired, TypeAlias, TypedDict + from exp_db_populator.database_model import Role, User # The group in which the credentials are stored CREDS_GROUP = "ExpDatabasePopulator" +RbNumber: TypeAlias = str +SessionId: TypeAlias = str + class UserData: """ A class for holding all the data required for a row in the user table. """ - def __init__(self, name, organisation): + def __init__(self, name: str, organisation: str) -> None: self.name = name self.organisation = organisation - def __str__(self): + def __str__(self) -> str: return "User {} is from {}".format(self.name, self.organisation) @property - def user_id(self): + def user_id(self) -> tuple[User, bool]: """ - Gets the user id for the user. Will create an entry for the user in the database if one doesn't exist. + Gets the user id for the user. Will create an entry for the user in + the database if one doesn't exist. Returns: the user's id. """ return User.get_or_create(name=self.name, organisation=self.organisation)[0].userid - def __eq__(self, other): - return self.name == other.name and self.organisation == other.organisation + def __eq__(self, other: object) -> bool: + if isinstance(other, UserData): + return self.name == other.name and self.organisation == other.organisation + else: + return False class ExperimentTeamData: @@ -34,26 +44,62 @@ class ExperimentTeamData: A class for holding all the data required for a row in the experiment team table. """ - def __init__(self, user, role, rb_number, start_date): + def __init__( + self, user: UserData, role: str, rb_number: RbNumber, start_date: datetime + ) -> None: self.user = user self.role = role self.rb_number = rb_number self.start_date = start_date @property - def role_id(self): + def role_id(self) -> tuple[Role, bool]: """ - Gets the role id for the user based on the roles in the database. Will raise an exception if role is not found. + Gets the role id for the user based on the roles in the database. + Will raise an exception if role is not found. Returns: the role id. """ return Role.get(Role.name == self.role).roleid - def __eq__(self, other): - return ( - self.user == other.user - and self.role == other.role - and self.rb_number == other.rb_number - and self.start_date == other.start_date - ) + def __eq__(self, other: object) -> bool: + if isinstance(other, ExperimentTeamData): + return ( + self.user == other.user + and self.role == other.role + and self.rb_number == other.rb_number + and self.start_date == other.start_date + ) + else: + return False + + +class InstListEntry(TypedDict): + name: str + hostName: str + isScheduled: bool + + +InstList: TypeAlias = list[InstListEntry] + + +class Experimenter(TypedDict): + name: str + organisation: str + role: str + + +class RawDataEntry(TypedDict): + instrument: str + lcName: str + part: int + rbNumber: str + scheduledDate: datetime + timeAllocated: float + experimenters: NotRequired[list[Experimenter]] + + +RawData: TypeAlias = dict[str, RawDataEntry] + +Credentials: TypeAlias = tuple[str, str] | None diff --git a/exp_db_populator/database_model.py b/exp_db_populator/database_model.py index 1a5944c..12ab1b1 100644 --- a/exp_db_populator/database_model.py +++ b/exp_db_populator/database_model.py @@ -1,6 +1,15 @@ -from peewee import * - -# Model built using peewiz +from typing import Any, Literal, overload + +from peewee import ( + AutoField, + CharField, + CompositeKey, + DateTimeField, + ForeignKeyField, + IntegerField, + Model, + Proxy, +) database_proxy = Proxy() @@ -15,29 +24,62 @@ class Experiment(BaseModel): experimentid = CharField(column_name="experimentID") startdate = DateTimeField(column_name="startDate") - class Meta: + class Meta: # pyright: ignore table_name = "experiment" indexes = ((("experimentid", "startdate"), True),) primary_key = CompositeKey("experimentid", "startdate") + @overload + def __getitem__(self, itm: Literal["duration"]) -> int | None: + pass + + @overload + def __getitem__(self, itm: Literal["experimentid"]) -> str: + pass + + def __getitem__(self, itm: str) -> Any: + return super().__getitem__(itm) # pyright: ignore (pyright can't see __getitem__) + class Role(BaseModel): name = CharField(null=True) priority = IntegerField(null=True) roleid = AutoField(column_name="roleID") - class Meta: + class Meta: # pyright: ignore table_name = "role" + @overload + def __getitem__(self, itm: Literal["name"] | Literal["role"]) -> str | None: + pass + + @overload + def __getitem__(self, itm: Literal["priority"]) -> int | None: + pass + + def __getitem__(self, itm: str) -> Any: + return super().__getitem__(itm) # pyright: ignore (pyright can't see __getitem__) + class User(BaseModel): name = CharField(null=True) organisation = CharField(null=True) userid = AutoField(column_name="userID") - class Meta: + class Meta: # pyright: ignore table_name = "user" + @overload + def __getitem__(self, itm: Literal["name"] | Literal["organisation"]) -> str: + pass + + @overload + def __getitem__(self, itm: Literal["userid"]) -> int: + pass + + def __getitem__(self, itm: str) -> Any: + return super().__getitem__(itm) # pyright: ignore (pyright can't see __getitem__) + class Experimentteams(BaseModel): experimentid = ForeignKeyField( @@ -52,7 +94,7 @@ class Experimentteams(BaseModel): ) userid = ForeignKeyField(column_name="userID", field="userid", model=User) - class Meta: + class Meta: # pyright: ignore table_name = "experimentteams" indexes = ( (("experimentid", "startdate"), False), diff --git a/exp_db_populator/gatherer.py b/exp_db_populator/gatherer.py index 885e7d2..082aeaa 100644 --- a/exp_db_populator/gatherer.py +++ b/exp_db_populator/gatherer.py @@ -2,13 +2,14 @@ import threading from time import sleep +from exp_db_populator.data_types import InstList, RawDataEntry from exp_db_populator.populator import update from exp_db_populator.webservices_reader import gather_data, reformat_data POLLING_TIME = 3600 # Time in seconds between polling the website -def correct_name(old_name): +def correct_name(old_name: str) -> str: """ Some names are different between IBEX and the web data, this function converts these. Args: @@ -19,7 +20,7 @@ def correct_name(old_name): return "ENGIN-X" if old_name == "ENGINX" else old_name -def filter_instrument_data(raw_data, inst_name): +def filter_instrument_data(raw_data: list[RawDataEntry], inst_name: str) -> list[RawDataEntry]: """ Gets the data associated with the specified instrument. Args: @@ -39,7 +40,9 @@ class Gatherer(threading.Thread): running = True - def __init__(self, inst_list, db_lock, run_continuous=False): + def __init__( + self, inst_list: InstList, db_lock: threading.RLock, run_continuous: bool = False + ) -> None: threading.Thread.__init__(self) self.daemon = True self.inst_list = inst_list @@ -47,19 +50,21 @@ def __init__(self, inst_list, db_lock, run_continuous=False): self.db_lock = db_lock logging.info("Starting gatherer") - def run(self): + def run(self) -> None: """ Periodically runs to gather new data and populate the databases. """ while self.running: - all_data = gather_data() + all_data: list[RawDataEntry] = gather_data() + for inst in self.inst_list: if inst["isScheduled"]: name, host = correct_name(inst["name"]), inst["hostName"] instrument_list = filter_instrument_data(all_data, name) if not instrument_list: logging.error( - f"Unable to update {name}, no data found. Expired data will still be cleared." + f"Unable to update {name}, no data found. " + f"Expired data will still be cleared." ) data_to_populate = None else: diff --git a/exp_db_populator/populator.py b/exp_db_populator/populator.py index 237f032..feed28f 100644 --- a/exp_db_populator/populator.py +++ b/exp_db_populator/populator.py @@ -1,47 +1,58 @@ import logging +import threading from datetime import datetime, timedelta from peewee import MySQLDatabase, chunked -from exp_db_populator.data_types import CREDS_GROUP +from exp_db_populator.data_types import CREDS_GROUP, Credentials, RawDataEntry from exp_db_populator.database_model import Experiment, Experimentteams, User, database_proxy try: from exp_db_populator.passwords.password_reader import get_credentials except ImportError: - logging.warn( - "Password submodule not found, will not be able to write to databases, " - "unless username/password are specified manually" + err = ( + "Password submodule not found, will not be able to write to " + "databases, unless username/password are specified manually" ) + logging.warn(err) -AGE_OF_EXPIRATION = 100 # How old (in days) the startdate of an experiment must be before it is removed from the database -POLLING_TIME = 3600 # Time in seconds between polling the website + def get_credentials(group_str: str, entry_str: str) -> Credentials: + raise EnvironmentError(err) -def remove_users_not_referenced(): +# How old (in days) the startdate of an experiment must be before it is removed from the database +AGE_OF_EXPIRATION = 100 + +# Time in seconds between polling the website +POLLING_TIME = 3600 + + +def remove_users_not_referenced() -> None: all_team_user_ids = Experimentteams.select(Experimentteams.userid) - User.delete().where(User.userid.not_in(all_team_user_ids)).execute() + User.delete().where(User.userid.not_in(all_team_user_ids)).execute() # pyright: ignore (doesn't understand peewee) -def remove_experiments_not_referenced(): +def remove_experiments_not_referenced() -> None: all_team_experiments = Experimentteams.select(Experimentteams.experimentid) - Experiment.delete().where(Experiment.experimentid.not_in(all_team_experiments)).execute() + Experiment.delete().where(Experiment.experimentid.not_in(all_team_experiments)).execute() # pyright: ignore (doesn't understand peewee) -def remove_old_experiment_teams(age): +def remove_old_experiment_teams(age: float) -> None: date = datetime.now() - timedelta(days=age) - Experimentteams.delete().where(Experimentteams.startdate < date).execute() + Experimentteams.delete().where(Experimentteams.startdate < date).execute() # pyright: ignore (doesn't understand peewee) + + +def create_database(instrument_host: str, credentials: Credentials) -> MySQLDatabase: + credentials = credentials or get_credentials(CREDS_GROUP, "ExpDatabaseWrite") + if credentials is None: + raise ValueError("Cannot connect to db, no credentials.") -def create_database(instrument_host, credentials): - if not credentials: - username, password = get_credentials(CREDS_GROUP, "ExpDatabaseWrite") - else: - username, password = credentials + username, password = credentials return MySQLDatabase("exp_data", user=username, password=password, host=instrument_host) -def cleanup_old_data(): +def cleanup_old_data() -> None: """ Removes old data from the database. """ @@ -50,13 +61,14 @@ def cleanup_old_data(): remove_users_not_referenced() -def populate(experiments, experiment_teams): +def populate(experiments: list[RawDataEntry], experiment_teams: list) -> None: """ Populates the database with experiment data. Args: experiments (list[dict]): A list of dictionaries containing information on experiments. - experiment_teams (list[exp_db_populator.data_types.ExperimentTeamData]): A list containing the users for all new experiments. + experiment_teams (list[exp_db_populator.data_types.ExperimentTeamData]): A list containing + the users for all new experiments. """ if not experiments or not experiment_teams: raise KeyError("Experiment without team or vice versa") @@ -79,13 +91,13 @@ def populate(experiments, experiment_teams): def update( - instrument_name, - instrument_host, - db_lock, - instrument_data, - run_continuous=False, - credentials=None, -): + instrument_name: str, + instrument_host: str, + db_lock: threading.RLock, + instrument_data: tuple[list[RawDataEntry], list[Experimentteams]] | None, + run_continuous: bool = False, + credentials: Credentials = None, +) -> None: """ Populates the database with this experiment's data. @@ -93,10 +105,11 @@ def update( instrument_name: The name of the instrument to update. instrument_host: The host name of the instrument to update. db_lock: A lock for writing to the database. - instrument_data: The data to send to the instrument, if None the data will just be cleared instead. - run_continuous: Whether or not the program is running in continuous mode. - credentials: The credentials to write to the database with, in the form (user, password). If None then the - credentials are received from the stored git repo + instrument_data: The data to send to the instrument, if None the data will just be + cleared instead. + run_continuous: Whether the program is running in continuous mode. + credentials: The credentials to write to the database with, in the form (user, password). + If None then the credentials are received from the stored git repo """ database = create_database(instrument_host, credentials) logging.info( diff --git a/exp_db_populator/webservices_reader.py b/exp_db_populator/webservices_reader.py index 6873292..5d8c996 100644 --- a/exp_db_populator/webservices_reader.py +++ b/exp_db_populator/webservices_reader.py @@ -1,46 +1,57 @@ # -*- coding: utf-8 -*- import logging import math -import ssl +import typing from datetime import datetime, timedelta +import requests from suds.client import Client -from exp_db_populator.data_types import CREDS_GROUP, ExperimentTeamData, UserData +from exp_db_populator.data_types import ( + CREDS_GROUP, + Credentials, + Experimenter, + ExperimentTeamData, + RawDataEntry, + RbNumber, + SessionId, + UserData, +) from exp_db_populator.database_model import Experiment try: from exp_db_populator.passwords.password_reader import get_credentials except ImportError: - logging.warn("Password submodule not found, will not be able to read from web") + err = "Password submodule not found, will not be able to read from web" + + logging.warn(err) + + def get_credentials(group_str: str, entry_str: str) -> Credentials: + raise EnvironmentError(err) + LOCAL_ORG = "Science and Technology Facilities Council" LOCAL_ROLE = "Contact" RELEVANT_DATE_RANGE = 100 # How many days of data to gather (either side of now) DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S" -BUS_APPS_SITE = "https://api.facilities.rl.ac.uk/ws/" -BUS_APPS_AUTH = BUS_APPS_SITE + "UserOfficeWebService?wsdl" -BUS_APPS_API = BUS_APPS_SITE + "ScheduleWebService?wsdl" +BUS_APPS_SITE = "https://api.facilities.rl.ac.uk/" +BUS_APPS_AUTH = BUS_APPS_SITE + "users-service/v1/sessions" +BUS_APPS_API = BUS_APPS_SITE + "ws/ScheduleWebService?wsdl" -# This is a workaround because the web service does not have a valid certificate -if hasattr(ssl, "_create_unverified_context"): - ssl._create_default_https_context = ssl._create_unverified_context +SUCCESSFUL_LOGIN_STATUS_CODE = 201 -def get_start_and_end(date, time_range_days): +def get_start_and_end(date: datetime, time_range_days: int) -> tuple[datetime, datetime]: days = timedelta(days=time_range_days) return date - days, date + days -def get_experimenters(team): - try: - return team.experimenters - except AttributeError: - return [] +def get_experimenters(team: RawDataEntry) -> list[Experimenter]: + return team.get("experimenters", []) -def create_date_range(client): +def create_date_range(client: Client) -> typing.Any: # noqa ANN401 rpc call """ Creates a date range in a format for the web client to understand. """ @@ -51,16 +62,28 @@ def create_date_range(client): return date_range -def connect(): +def connect() -> tuple[Client, SessionId]: """ Connects to the busapps website. Returns: tuple: the client and the associated session id. """ try: - username, password = get_credentials(CREDS_GROUP, "WebRead") + creds = get_credentials(CREDS_GROUP, "WebRead") + if creds is None: + raise EnvironmentError("No credentials provided") + + username, password = creds + + response = requests.post(BUS_APPS_AUTH, json={"username": username, "password": password}) + + if response.status_code != SUCCESSFUL_LOGIN_STATUS_CODE: + raise IOError( + f"Failed to authenticate to busapps web service, " + f"code={response.status_code}, resp={response.text}" + ) - session_id = Client(BUS_APPS_AUTH).service.login(username, password) + session_id = response.json()["sessionId"] client = Client(BUS_APPS_API) return client, session_id @@ -69,7 +92,7 @@ def connect(): raise -def get_all_data_from_web(client, session_id): +def get_all_data_from_web(client: Client, session_id: SessionId) -> list[RawDataEntry]: """ Args: client: The client that has connected to the web. @@ -89,7 +112,9 @@ def get_all_data_from_web(client, session_id): raise -def create_exp_team(user, role, rb_number, date): +def create_exp_team( + user: UserData, role: str, rb_number: RbNumber, date: datetime +) -> ExperimentTeamData: # IBEX calls them users, BusApps calls them members if role == "Member": role = "User" @@ -97,15 +122,18 @@ def create_exp_team(user, role, rb_number, date): return ExperimentTeamData(user, role, rb_number, date) -def reformat_data(instrument_data_list): +def reformat_data( + instrument_data_list: list[RawDataEntry], +) -> tuple[list, list]: """ Reformats the data from the way the website returns it to the way the database wants it. Args: instrument_data_list (list): List of an instrument's data from the website. Returns: - tuple (list, list): A list of the experiments and their associated data and a list of the experiment teams, - and a dictionary of rb_numbers and their associated instrument.. + tuple (list, list): A list of the experiments and their associated data and a + list of the experiment teams, and a dictionary of rb_numbers and their associated + instrument. """ try: experiments = [] @@ -114,9 +142,9 @@ def reformat_data(instrument_data_list): for data in instrument_data_list: experiments.append( { - Experiment.experimentid: data["rbNumber"], - Experiment.startdate: data["scheduledDate"], - Experiment.duration: math.ceil(data["timeAllocated"]), + Experiment.experimentid: typing.cast(RbNumber, data["rbNumber"]), + Experiment.startdate: typing.cast(str, data["scheduledDate"]), + Experiment.duration: math.ceil(typing.cast(float, data["timeAllocated"])), } ) @@ -139,7 +167,7 @@ def reformat_data(instrument_data_list): raise -def gather_data(): +def gather_data() -> list[RawDataEntry]: client, session_id = connect() data = get_all_data_from_web(client, session_id) return data diff --git a/tests/webservices_test_data.py b/exp_db_populator/webservices_test_data.py similarity index 66% rename from tests/webservices_test_data.py rename to exp_db_populator/webservices_test_data.py index 31e3502..d534d83 100644 --- a/tests/webservices_test_data.py +++ b/exp_db_populator/webservices_test_data.py @@ -1,6 +1,7 @@ from datetime import datetime +from typing import Any -from mock import MagicMock +from exp_db_populator.data_types import RawDataEntry TEST_INSTRUMENT = "test_instrument" TEST_OTHER_INSTRUMENT = "test_other_instrument" @@ -27,21 +28,17 @@ ] -def get_test_experiment_team(experimenters): +def get_test_experiment_team(experimenters: list) -> dict[str, Any]: team_dict = { "experimenters": experimenters, "instrument": TEST_INSTRUMENT, "part": 6, "rbNumber": TEST_RBNUMBER, } + return team_dict - team = MagicMock() - team.experimenters = experimenters - team.__getitem__.side_effect = team_dict.__getitem__ - return team - -def create_data(rb, start, duration): +def create_data(rb: str, start: datetime, duration: float) -> RawDataEntry: return { "instrument": TEST_INSTRUMENT, "lcName": TEST_CONTACT_NAME, @@ -55,19 +52,15 @@ def create_data(rb, start, duration): TEST_DATA = [create_data(TEST_RBNUMBER, TEST_DATE, TEST_TIMEALLOCATED)] -def create_web_data_with_experimenters(experimenters): +def create_web_data_with_experimenters(experimenters: list) -> RawDataEntry: data_dict = create_data(TEST_RBNUMBER, TEST_DATE, TEST_TIMEALLOCATED) - - data = MagicMock() - data.experimenters = experimenters - data.__getitem__.side_effect = data_dict.__getitem__ - return data + data_dict["experimenters"] = experimenters + return data_dict -def create_web_data_with_experimenters_and_other_date(experimenters, date): +def create_web_data_with_experimenters_and_other_date( + experimenters: list, date: datetime +) -> RawDataEntry: data_dict = create_data(TEST_RBNUMBER, date, TEST_TIMEALLOCATED) - - data = MagicMock() - data.experimenters = experimenters - data.__getitem__.side_effect = data_dict.__getitem__ - return data + data_dict["experimenters"] = experimenters + return data_dict diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f0aad54 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,82 @@ +[build-system] +requires = ["setuptools>=64", "setuptools_scm>=8"] +build-backend = "setuptools.build_meta" + + +[project] +name = "ExperimentDatabasePopulator" +dynamic = ["version"] +description = "Experiment Database Populator" +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} + +authors = [ + {name = "ISIS Experiment Controls", email = "ISISExperimentControls@stfc.ac.uk" } +] +maintainers = [ + {name = "ISIS Experiment Controls", email = "ISISExperimentControls@stfc.ac.uk" } +] + +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3 :: Only", +] + +dependencies = [ + "requests", + "suds", + "pykeepass", + "peewee", + "pymysql[rsa]", + "cryptography", + "pyepics", + "mock", # Needed at runtime, not dev-only +] + +[project.optional-dependencies] +dev = [ + "pyright", + "pytest", + "pytest-cov", + "ruff>=0.9", +] + +[project.urls] +"Homepage" = "https://github.com/isiscomputinggroup/ExperimentDatabasePopulator" +"Bug Reports" = "https://github.com/isiscomputinggroup/ExperimentDatabasePopulator/issues" +"Source" = "https://github.com/isiscomputinggroup/ExperimentDatabasePopulator" + +[project.scripts] +exp_db_populator = "exp_db_populator.cli:main_cli" + +[tool.pytest.ini_options] +testpaths = "tests" +addopts = "--cov --cov-report=html -vv" + +[tool.coverage.run] +branch = true +source = ["exp_db_populator"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "if typing.TYPE_CHECKING:", + "@abstractmethod", +] + +[tool.coverage.html] +directory = "coverage_html_report" + +[tool.setuptools_scm] + +[tool.setuptools.packages.find] +include = ["exp_db_populator"] +namespaces = false + +[tool.pyright] +include = ["exp_db_populator"] +exclude = ["exp_db_populator/passwords/password_reader.py"] diff --git a/rb_number_populator.sh b/rb_number_populator.sh index 79cf580..4febb30 100755 --- a/rb_number_populator.sh +++ b/rb_number_populator.sh @@ -3,5 +3,5 @@ venv="exp_db_populator_venv" # Name of the virtual environment . /home/epics/EPICS/config_env.sh source /home/epics/RB_num_populator/$venv/bin/activate # activate the virtual environment -/home/epics/RB_num_populator/$venv/bin/python3.8 /home/epics/RB_num_populator/main.py +exp_db_populator deactivate # deactivate the virtual environment diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 3c0dd22..0000000 --- a/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -suds-jurko -pykeepass -peewee -pymysql -mock -xmlrunner -pyepics -six -cryptography \ No newline at end of file diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..d122719 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,42 @@ +# Exclude a variety of commonly ignored directories. +exclude = [ + ".pyi", + "exp_db_populator/passwords" +] + +src = ["exp_db_populator"] + +# Set the maximum line length to 100. +line-length = 100 +indent-width = 4 + +[lint] +extend-select = [ + "N", # pep8-naming + # "D", # pydocstyle (can use this later but for now causes too many errors) + "I", # isort (for imports) + "E501", # Line too long ({width} > {limit}) + "E", + "F", + "ANN", +] +ignore = [ + "D406", # Section name should end with a newline ("{name}") + "D407", # Missing dashed underline after section ("{name}") +] +[lint.per-file-ignores] +"tests/*" = [ + "N802", + "D100", + "D101", + "D102", + "E501", + "ANN", +] +"exp_db_populator/cli.py" = [ + "E402" +] + +[lint.pydocstyle] +# Use Google-style docstrings. +convention = "google" diff --git a/run_tests.py b/run_tests.py deleted file mode 100644 index a0397f6..0000000 --- a/run_tests.py +++ /dev/null @@ -1,37 +0,0 @@ -import argparse -import os -import sys - -# Standard imports -import unittest - -import xmlrunner - -DEFAULT_DIRECTORY = os.path.join(".", "test-reports") - - -if __name__ == "__main__": - # get output directory from command line arguments - parser = argparse.ArgumentParser() - parser.add_argument( - "-o", - "--output_dir", - nargs=1, - type=str, - default=[DEFAULT_DIRECTORY], - help="The directory to save the test reports", - ) - args = parser.parse_args() - xml_dir = args.output_dir[0] - - # Load tests from test suites - test_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "tests")) - test_suite = unittest.TestLoader().discover(test_dir, pattern="test_*.py") - - print("\n\n------ BEGINNING Experiment Database Populator UNIT TESTS ------") - ret_vals = list() - ret_vals.append(xmlrunner.XMLTestRunner(output=xml_dir).run(test_suite)) - print("------ UNIT TESTS COMPLETE ------\n\n") - - # Return failure exit code if a test failed - sys.exit(False in ret_vals) diff --git a/tests/test_data_types.py b/tests/test_data_types.py index 9093599..c30a9f4 100644 --- a/tests/test_data_types.py +++ b/tests/test_data_types.py @@ -1,10 +1,15 @@ import unittest -from peewee import SqliteDatabase - import exp_db_populator.database_model as model from exp_db_populator.data_types import ExperimentTeamData, UserData -from tests.webservices_test_data import * +from exp_db_populator.webservices_test_data import ( + TEST_DATE, + TEST_PI_NAME, + TEST_PI_ORG, + TEST_PI_ROLE, + TEST_RBNUMBER, +) +from peewee import SqliteDatabase class UserDataTests(unittest.TestCase): diff --git a/tests/test_gatherer.py b/tests/test_gatherer.py index a1d2cb2..d1f1cf9 100644 --- a/tests/test_gatherer.py +++ b/tests/test_gatherer.py @@ -1,9 +1,8 @@ import threading import unittest -from mock import patch - from exp_db_populator.gatherer import Gatherer, filter_instrument_data +from mock import patch class GathererTests(unittest.TestCase): diff --git a/tests/test_main.py b/tests/test_main.py index 2c5bf7a..13284d2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,16 +1,15 @@ import unittest -from mock import Mock, patch - +from exp_db_populator.cli import InstrumentPopulatorRunner from exp_db_populator.gatherer import Gatherer -from main import InstrumentPopulatorRunner +from mock import Mock, patch class MainTest(unittest.TestCase): def setUp(self): self.inst_pop_runner = InstrumentPopulatorRunner() - @patch("main.Gatherer") + @patch("exp_db_populator.cli.Gatherer") def test_GIVEN_no_gatherer_running_WHEN_instrument_list_has_new_instrument_THEN_gatherer_starts( self, gatherer ): @@ -23,7 +22,7 @@ def test_GIVEN_no_gatherer_running_WHEN_instrument_list_has_new_instrument_THEN_ self.assertEqual(new_gather, self.inst_pop_runner.gatherer) - @patch("main.InstrumentPopulatorRunner.remove_gatherer") + @patch("exp_db_populator.cli.InstrumentPopulatorRunner.remove_gatherer") def test_WHEN_instrument_list_updated_THEN_gatherer_stopped_and_cleared(self, stop): new_name, new_host = "TEST", "NDXTEST" self.inst_pop_runner.inst_list_changes( diff --git a/tests/test_populator.py b/tests/test_populator.py index 406daa8..d0b2d0e 100644 --- a/tests/test_populator.py +++ b/tests/test_populator.py @@ -1,10 +1,8 @@ import threading import unittest +from datetime import datetime from time import sleep -from mock import Mock, patch -from peewee import SqliteDatabase - import exp_db_populator.database_model as model from exp_db_populator.data_types import ExperimentTeamData, UserData from exp_db_populator.populator import ( @@ -14,7 +12,16 @@ remove_users_not_referenced, update, ) -from tests.webservices_test_data import * +from exp_db_populator.webservices_test_data import ( + TEST_DATE, + TEST_INSTRUMENT, + TEST_PI_ROLE, + TEST_RBNUMBER, + TEST_TIMEALLOCATED, + TEST_USER_PI, +) +from mock import Mock, patch +from peewee import SqliteDatabase class PopulatorTests(unittest.TestCase): @@ -58,15 +65,15 @@ def test_GIVEN_user_and_unrelated_experiment_teams_WHEN_unreferenced_removed_THE self, ): model.User.create(name="Delete me", organisation="STFC") - KEEP_NAME = "Keep Me" - self.create_full_record(user_name=KEEP_NAME) + keep_name = "Keep Me" + self.create_full_record(user_name=keep_name) self.assertEqual(2, model.User.select().count()) remove_users_not_referenced() users = model.User.select() self.assertEqual(1, users.count()) - self.assertEqual(KEEP_NAME, users[0].name) + self.assertEqual(keep_name, users[0].name) def test_GIVEN_user_and_related_experiment_teams_WHEN_unreferenced_removed_THEN_user_remains( self, @@ -90,15 +97,15 @@ def test_GIVEN_experiment_and_unrelated_experiment_teams_WHEN_unreferenced_remov self, ): model.Experiment.create(experimentid=TEST_RBNUMBER, duration=2, startdate=TEST_DATE) - KEEP_RB = "20000" - self.create_full_record(rb_number=KEEP_RB) + keep_rb = "20000" + self.create_full_record(rb_number=keep_rb) self.assertEqual(2, model.Experiment.select().count()) remove_experiments_not_referenced() exps = model.Experiment.select() self.assertEqual(1, exps.count()) - self.assertEqual(KEEP_RB, exps[0].experimentid) + self.assertEqual(keep_rb, exps[0].experimentid) def test_GIVEN_experiment_and_related_experiment_teams_WHEN_unreferenced_removed_THEN_experiment_remains( self, @@ -133,7 +140,7 @@ def create_experiment_teams_dictionary(self): exp_team_data.user.user_id = 1 return [exp_team_data] - def test_WHEN_populate_called_with_experiments_and_no_teams_THEN_exception_raised(self): + def test_WHEN_populate_called_with_experiments_and_no_teams_2THEN_exception_raised(self): experiments = self.create_experiments_dictionary() self.assertRaises(KeyError, populate, experiments, []) diff --git a/tests/test_webservices_reader.py b/tests/test_webservices_reader.py index 56e439c..2055005 100644 --- a/tests/test_webservices_reader.py +++ b/tests/test_webservices_reader.py @@ -1,8 +1,6 @@ import unittest from datetime import datetime, timedelta -from mock import MagicMock - from exp_db_populator.data_types import ExperimentTeamData, UserData from exp_db_populator.database_model import Experiment from exp_db_populator.webservices_reader import ( @@ -13,7 +11,27 @@ get_start_and_end, reformat_data, ) -from tests.webservices_test_data import * +from exp_db_populator.webservices_test_data import ( + TEST_CONTACT_NAME, + TEST_CONTACTS, + TEST_DATA, + TEST_DATE, + TEST_PI_NAME, + TEST_PI_ORG, + TEST_PI_ROLE, + TEST_RBNUMBER, + TEST_TIMEALLOCATED, + TEST_USER_1, + TEST_USER_1_NAME, + TEST_USER_1_ORG, + TEST_USER_1_ROLE, + TEST_USER_PI, + create_data, + create_web_data_with_experimenters, + create_web_data_with_experimenters_and_other_date, + get_test_experiment_team, +) +from mock import MagicMock class WebServicesReaderTests(unittest.TestCase): @@ -26,12 +44,10 @@ def test_WHEN_get_start_and_end_date_of_100_THEN_stat_before_end(self): self.assertTrue(start < end) def test_GIVEN_experimenters_WHEN_get_experimenters_THEN_get_experimenters(self): - team = MagicMock() - team.experimenters = ["TEST"] - self.assertEqual(["TEST"], get_experimenters(team)) + self.assertEqual(["TEST"], get_experimenters({"experimenters": ["TEST"]})) def test_GIVEN_no_experimenters_WHEN_get_experimenters_THEN_empty_list(self): - team = MagicMock(spec=["NOT_EXPERIMENTEERS"]) + team = {} self.assertEqual([], get_experimenters(team)) def test_GIVEN_no_data_set_WHEN_data_formatted_THEN_no_data_set(self):