33from sqlalchemy .orm import relationship
44from copy import deepcopy
55
6- from .._db import session
6+ from .._db import Session
77from delphi .epidata .common .logger import get_structured_logger
88
99from typing import Set , Optional , List
@@ -35,61 +35,61 @@ def __init__(self, api_key: str, email: str = None) -> None:
3535
3636 @staticmethod
3737 def list_users () -> List ["User" ]:
38- return session .query (User ).all ()
38+ with Session () as session :
39+ return session .query (User ).all ()
3940
4041 @property
4142 def as_dict (self ):
42- user_dict = deepcopy (self .__dict__ ) # NOTE: changed from `self.__dict__.copy()` as it caused issues
43- # so we dont change the internal representation of self
44- user_dict ["roles" ] = self .get_user_roles
45- try :
46- return {k : user_dict [k ] for k in ["id" , "api_key" , "email" , "roles" , "created" , "last_time_used" ]}
47- except KeyError :
48- return {
49- "id" : self .id ,
50- "api_key" : self .api_key ,
51- "email" : self .email ,
52- "roles" : self .get_user_roles ,
53- "created" : self .created ,
54- "last_time_used" : self .last_time_used
55- }
43+ return {
44+ "id" : self .id ,
45+ "api_key" : self .api_key ,
46+ "email" : self .email ,
47+ "roles" : User .get_user_roles (self .id ),
48+ "created" : self .created ,
49+ "last_time_used" : self .last_time_used
50+ }
5651
57- @property
58- def get_user_roles (self ) -> Set [str ]:
59- return set ([role .name for role in self .roles ])
52+ @staticmethod
53+ def get_user_roles (user_id : int ) -> Set [str ]:
54+ with Session () as session :
55+ user = session .query (User ).filter (User .id == user_id ).first ()
56+ return set ([role .name for role in user .roles ])
6057
6158 def has_role (self , required_role : str ) -> bool :
62- return required_role in self .get_user_roles
59+ return required_role in User .get_user_roles ( self . id )
6360
6461 @staticmethod
65- def assign_roles (user : "User" , roles : Optional [Set [str ]]) -> None :
62+ def _assign_roles (user : "User" , roles : Optional [Set [str ]], session ) -> None :
6663 get_structured_logger ("api_user_models" ).info ("setting roles" , roles = roles , user_id = user .id , api_key = user .api_key )
6764 if roles :
65+ db_user = session .query (User ).filter (User .id == user .id ).first ()
6866 roles_to_assign = session .query (UserRole ).filter (UserRole .name .in_ (roles )).all ()
69- user .roles = roles_to_assign
70- session .commit ()
67+ db_user .roles = roles_to_assign
7168 else :
72- user .roles = []
73- session .commit ()
69+ db_user .roles = []
7470
7571 @staticmethod
7672 def find_user (* , # asterisk forces explicit naming of all arguments when calling this method
7773 user_id : Optional [int ] = None , api_key : Optional [str ] = None , user_email : Optional [str ] = None
7874 ) -> "User" :
79- user = (
80- session .query (User )
81- .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
82- .first ()
83- )
75+ # TODO
76+ with Session () as session :
77+ user = (
78+ session .query (User )
79+ .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
80+ .first ()
81+ )
8482 return user if user else None
8583
8684 @staticmethod
8785 def create_user (api_key : str , email : str , user_roles : Optional [Set [str ]] = None ) -> "User" :
86+ # TODO
8887 get_structured_logger ("api_user_models" ).info ("creating user" , api_key = api_key )
89- new_user = User (api_key = api_key , email = email )
90- session .add (new_user )
91- session .commit ()
92- User .assign_roles (new_user , user_roles )
88+ with Session () as session :
89+ new_user = User (api_key = api_key , email = email )
90+ session .add (new_user )
91+ User ._assign_roles (new_user , user_roles , session )
92+ session .commit ()
9393 return new_user
9494
9595 @staticmethod
@@ -100,23 +100,27 @@ def update_user(
100100 roles : Optional [Set [str ]]
101101 ) -> "User" :
102102 get_structured_logger ("api_user_models" ).info ("updating user" , user_id = user .id , new_api_key = api_key )
103- user = User .find_user (user_id = user .id )
104- if user :
105- update_stmt = (
106- update (User )
107- .where (User .id == user .id )
108- .values (api_key = api_key , email = email )
109- )
110- session .execute (update_stmt )
111- session .commit ()
112- User .assign_roles (user , roles )
103+ # TODO
104+ with Session () as session :
105+ user = User .find_user (user_id = user .id )
106+ if user :
107+ update_stmt = (
108+ update (User )
109+ .where (User .id == user .id )
110+ .values (api_key = api_key , email = email )
111+ )
112+ session .execute (update_stmt )
113+ User ._assign_roles (user , roles , session )
114+ session .commit ()
113115 return user
114116
115117 @staticmethod
116118 def delete_user (user_id : int ) -> None :
117119 get_structured_logger ("api_user_models" ).info ("deleting user" , user_id = user_id )
118- session .execute (delete (User ).where (User .id == user_id ))
119- session .commit ()
120+ # TODO
121+ with Session () as session :
122+ session .execute (delete (User ).where (User .id == user_id ))
123+ session .commit ()
120124
121125
122126class UserRole (Base ):
@@ -127,19 +131,23 @@ class UserRole(Base):
127131 @staticmethod
128132 def create_role (name : str ) -> None :
129133 get_structured_logger ("api_user_models" ).info ("creating user role" , role = name )
130- session .execute (
131- f"""
132- INSERT INTO user_role (name)
133- SELECT '{ name } '
134- WHERE NOT EXISTS
135- (SELECT *
136- FROM user_role
137- WHERE name='{ name } ')
138- """
139- )
140- session .commit ()
134+ # TODO
135+ with Session () as session :
136+ session .execute (
137+ f"""
138+ INSERT INTO user_role (name)
139+ SELECT '{ name } '
140+ WHERE NOT EXISTS
141+ (SELECT *
142+ FROM user_role
143+ WHERE name='{ name } ')
144+ """
145+ )
146+ session .commit ()
141147
142148 @staticmethod
143149 def list_all_roles ():
144- roles = session .query (UserRole ).all ()
150+ # TODO
151+ with Session () as session :
152+ roles = session .query (UserRole ).all ()
145153 return [role .name for role in roles ]
0 commit comments