diff --git a/docs/conf.py b/docs/conf.py index 6b8bfe1d5..650e63cca 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,32 +18,32 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.ifconfig'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.ifconfig"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Splunk SDK for Python' -copyright = u'2024, Splunk Inc' +project = "Splunk SDK for Python" +copyright = "2024, Splunk Inc" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -56,37 +56,37 @@ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- @@ -95,125 +95,128 @@ # a list of builtin themes. # agogo, default, epub, haiku, nature, pyramid, scrolls, sphinxdoc, traditional -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. 
-#html_theme_options = {}
+# html_theme_options = {}

# Add any paths that contain custom themes here, relative to this directory.
-#html_theme_path = []
+# html_theme_path = []

# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
html_title = "Splunk SDK for Python API Reference"

# A shorter title for the navigation bar. Default is the same as html_title.
-#html_short_title = "Splunk SDK for Python Reference"
+# html_short_title = "Splunk SDK for Python Reference"

# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-#html_logo = None
+# html_logo = None

# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
-#html_favicon = None
+# html_favicon = None

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['css']
+html_static_path = ["css"]

# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
-#html_last_updated_fmt = '%b %d, %Y'
+# html_last_updated_fmt = '%b %d, %Y'

# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
-#html_use_smartypants = True
+# html_use_smartypants = True

# Custom sidebar templates, maps document names to template names.
-#html_sidebars = {
+# html_sidebars = {
html_sidebars = {
-    '**': ['globaltoc.html', 'searchbox.html'],
+    "**": ["globaltoc.html", "searchbox.html"],
}

# Additional templates that should be rendered to pages, maps page names to
# template names.
-#html_additional_pages = {}
+# html_additional_pages = {}

# If false, no module index is generated.
-#html_domain_indices = True
+# html_domain_indices = True

# If false, no index is generated.
-#html_use_index = True
+# html_use_index = True

# If true, the index is split into individual pages for each letter.
-#html_split_index = False
+# html_split_index = False

# If true, links to the reST sources are added to the pages.
-#html_show_sourcelink = True
+# html_show_sourcelink = True

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-#html_show_sphinx = False
+# html_show_sphinx = False

# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-#html_show_copyright = False
+# html_show_copyright = False

# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
-#html_use_opensearch = ''
+# html_use_opensearch = ''

# This is the file name suffix for HTML files (e.g. ".xhtml").
-#html_file_suffix = None
+# html_file_suffix = None

# Output file base name for HTML help builder.
-htmlhelp_basename = 'SplunkPythonSDKdoc'
+htmlhelp_basename = "SplunkPythonSDKdoc"


# -- Options for LaTeX output --------------------------------------------------

latex_elements = {
-# The paper size ('letterpaper' or 'a4paper').
-#'papersize': 'letterpaper',
-
-# The font size ('10pt', '11pt' or '12pt').
-#'pointsize': '10pt',
-
-# Additional stuff for the LaTeX preamble.
-#'preamble': '',
+    # The paper size ('letterpaper' or 'a4paper').
+    #'papersize': 'letterpaper',
+    # The font size ('10pt', '11pt' or '12pt').
+ #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'SplunkPythonSDK.tex', u'Splunk SDK for Python Documentation', - u'Splunk Inc.', 'manual'), + ( + "index", + "SplunkPythonSDK.tex", + "Splunk SDK for Python Documentation", + "Splunk Inc.", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- @@ -221,12 +224,17 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'splunkpythonsdk', u'Splunk SDK for Python API Documentation', - [u'Splunk Inc.'], 1) + ( + "index", + "splunkpythonsdk", + "Splunk SDK for Python API Documentation", + ["Splunk Inc."], + 1, + ) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ @@ -235,18 +243,24 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'SplunkPythonSDK', u'Splunk SDK for Python API Documentation', - u'Splunk Inc.', 'SplunkPythonSDK', 'API reference for Splunk SDK for Python.', - 'Miscellaneous'), + ( + "index", + "SplunkPythonSDK", + "Splunk SDK for Python API Documentation", + "Splunk Inc.", + "SplunkPythonSDK", + "API reference for Splunk SDK for Python.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. 
-#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' -autoclass_content = 'both' +autoclass_content = "both" diff --git a/scripts/build-env.py b/scripts/build-env.py index fcf55ae14..7a2703833 100644 --- a/scripts/build-env.py +++ b/scripts/build-env.py @@ -22,18 +22,22 @@ from string import Template DEFAULT_CONFIG = { - 'host': 'localhost', - 'port': '8089', - 'username': 'admin', - 'password': 'changed!', - 'scheme': 'https', - 'version': '8.0' + "host": "localhost", + "port": "8089", + "username": "admin", + "password": "changed!", + "scheme": "https", + "version": "8.0", } -DEFAULT_ENV_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '.env') +DEFAULT_ENV_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "..", ".env" +) ENV_TEMPLATE_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'templates/env.template') + os.path.dirname(os.path.realpath(__file__)), "templates/env.template" +) + # { # "server_roles": { @@ -61,27 +65,29 @@ def build_config(json_string): try: spec_config = json.loads(json_string) - server_config = spec_config['server_roles']['standalone'][0] - splunk_config = server_config['splunk'] + server_config = spec_config["server_roles"]["standalone"][0] + splunk_config = server_config["splunk"] - host, port = parse_hostport(server_config['ports']['8089/tcp']) + host, port = parse_hostport(server_config["ports"]["8089/tcp"]) return { - 'host': host, - 'port': port, - 'username': splunk_config['user_roles']['admin']['username'], - 'password': splunk_config['user_roles']['admin']['password'], - 'version': splunk_config['version'], + "host": host, + "port": port, + "username": splunk_config["user_roles"]["admin"]["username"], + "password": splunk_config["user_roles"]["admin"]["password"], + "version": splunk_config["version"], } except Exception as e: - raise ValueError('Invalid configuration JSON string') from e + raise ValueError("Invalid configuration JSON string") from e + # Source: https://stackoverflow.com/a/53172593 def parse_hostport(host_port): # urlparse() and urlsplit() insists on absolute URLs starting with "//" - result = urllib.parse.urlsplit('//' + host_port) + result = urllib.parse.urlsplit("//" + host_port) return result.hostname, result.port + def run(variable, env_path=None): # read JSON from input # parse the JSON @@ -90,7 +96,7 @@ def run(variable, env_path=None): config = {**DEFAULT_CONFIG, **input_config} # build a env file - with open(ENV_TEMPLATE_PATH, 'r') as f: + with open(ENV_TEMPLATE_PATH, "r") as f: template = Template(f.read()) env_string = template.substitute(config) @@ -101,12 +107,13 @@ def run(variable, env_path=None): return # write the .env file - with open(env_path, 'w') as f: + with open(env_path, "w") as f: f.write(env_string) + if sys.stdin.isatty(): DATA = None else: DATA = sys.stdin.read() -run(DATA, sys.argv[1] if len(sys.argv) > 1 else None) \ No newline at end of file +run(DATA, sys.argv[1] if len(sys.argv) > 1 else None) diff --git a/setup.py b/setup.py index d19a09eb1..65b56812d 100755 --- a/setup.py +++ b/setup.py @@ -20,28 +20,17 @@ setup( author="Splunk, Inc.", - author_email="devinfo@splunk.com", - description="The Splunk Software Development Kit for Python.", - license="http://www.apache.org/licenses/LICENSE-2.0", - name="splunk-sdk", - - packages = ["splunklib", - "splunklib.modularinput", - "splunklib.searchcommands"], - + packages=["splunklib", "splunklib.modularinput", "splunklib.searchcommands"], install_requires=[ - "deprecation", - ], - + 
"deprecation", + ], url="http://github.com/splunk/splunk-sdk-python", - version=splunklib.__version__, - - classifiers = [ + classifiers=[ "Programming Language :: Python", "Development Status :: 6 - Mature", "Environment :: Other Environment", diff --git a/sitecustomize.py b/sitecustomize.py index eb94c154b..6a23233ad 100644 --- a/sitecustomize.py +++ b/sitecustomize.py @@ -18,6 +18,7 @@ try: import coverage + coverage.process_startup() except: pass diff --git a/splunklib/__init__.py b/splunklib/__init__.py index 5f83b2ac4..5eb9236f6 100644 --- a/splunklib/__init__.py +++ b/splunklib/__init__.py @@ -16,18 +16,20 @@ import logging -DEFAULT_LOG_FORMAT = '%(asctime)s, Level=%(levelname)s, Pid=%(process)s, Logger=%(name)s, File=%(filename)s, ' \ - 'Line=%(lineno)s, %(message)s' -DEFAULT_DATE_FORMAT = '%Y-%m-%d %H:%M:%S %Z' +DEFAULT_LOG_FORMAT = ( + "%(asctime)s, Level=%(levelname)s, Pid=%(process)s, Logger=%(name)s, File=%(filename)s, " + "Line=%(lineno)s, %(message)s" +) +DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S %Z" # To set the logging level of splunklib # ex. To enable debug logs, call this method with parameter 'logging.DEBUG' # default logging level is set to 'WARNING' -def setup_logging(level, log_format=DEFAULT_LOG_FORMAT, date_format=DEFAULT_DATE_FORMAT): - logging.basicConfig(level=level, - format=log_format, - datefmt=date_format) +def setup_logging( + level, log_format=DEFAULT_LOG_FORMAT, date_format=DEFAULT_DATE_FORMAT +): + logging.basicConfig(level=level, format=log_format, datefmt=date_format) __version_info__ = (2, 1, 0) diff --git a/splunklib/binding.py b/splunklib/binding.py index 1470ebdaa..d8cf9121c 100644 --- a/splunklib/binding.py +++ b/splunklib/binding.py @@ -55,12 +55,26 @@ "_encode", "_make_cookie_header", "_NoAuthenticationToken", - "namespace" + "namespace", ] -SENSITIVE_KEYS = ['Authorization', 'Cookie', 'action.email.auth_password', 'auth', 'auth_password', 'clear_password', 'clientId', - 'crc-salt', 'encr_password', 'oldpassword', 'passAuth', 'password', 'session', 'suppressionKey', - 'token'] +SENSITIVE_KEYS = [ + "Authorization", + "Cookie", + "action.email.auth_password", + "auth", + "auth_password", + "clear_password", + "clientId", + "crc-salt", + "encr_password", + "oldpassword", + "passAuth", + "password", + "session", + "suppressionKey", + "token", +] # If you change these, update the docstring # on _authority as well. @@ -82,9 +96,9 @@ def new_f(*args, **kwargs): def mask_sensitive_data(data): - ''' + """ Masked sensitive fields data for logging purpose - ''' + """ if not isinstance(data, dict): try: data = json.loads(data) @@ -193,7 +207,7 @@ class UrlEncoded(str): 'ab c' + UrlEncoded('de f') == UrlEncoded('ab cde f') """ - def __new__(self, val='', skip_encode=False, encode_slash=False): + def __new__(self, val="", skip_encode=False, encode_slash=False): if isinstance(val, UrlEncoded): # Don't urllib.quote something already URL encoded. return val @@ -326,11 +340,14 @@ def wrapper(self, *args, **kwargs): # an AuthenticationError and give up. with _handle_auth_error("Autologin failed."): self.login() - with _handle_auth_error("Authentication Failed! If session token is used, it seems to have been expired."): + with _handle_auth_error( + "Authentication Failed! If session token is used, it seems to have been expired." 
+ ): return request_fun(self, *args, **kwargs) elif he.status == 401 and not self.autologin: raise AuthenticationError( - "Request failed: Session is not logged in.", he) + "Request failed: Session is not logged in.", he + ) else: raise @@ -376,10 +393,10 @@ def _authority(scheme=DEFAULT_SCHEME, host=DEFAULT_HOST, port=DEFAULT_PORT): """ # check if host is an IPv6 address and not enclosed in [ ] - if ':' in host and not (host.startswith('[') and host.endswith(']')): + if ":" in host and not (host.startswith("[") and host.endswith("]")): # IPv6 addresses must be enclosed in [ ] in order to be well # formed. - host = '[' + host + ']' + host = "[" + host + "]" return UrlEncoded(f"{scheme}://{host}:{port}", skip_encode=True) @@ -436,11 +453,11 @@ def namespace(sharing=None, owner=None, app=None, **kwargs): n = binding.namespace(sharing="global", app="search") """ if sharing in ["system"]: - return record({'sharing': sharing, 'owner': "nobody", 'app': "system"}) + return record({"sharing": sharing, "owner": "nobody", "app": "system"}) if sharing in ["global", "app"]: - return record({'sharing': sharing, 'owner': "nobody", 'app': app}) + return record({"sharing": sharing, "owner": "nobody", "app": app}) if sharing in ["user", None]: - return record({'sharing': sharing, 'owner': owner, 'app': app}) + return record({"sharing": sharing, "owner": owner, "app": app}) raise ValueError("Invalid value for argument: 'sharing'") @@ -510,10 +527,16 @@ class Context: """ def __init__(self, handler=None, **kwargs): - self.http = HttpLib(handler, kwargs.get("verify", False), key_file=kwargs.get("key_file"), - cert_file=kwargs.get("cert_file"), context=kwargs.get("context"), - # Default to False for backward compat - retries=kwargs.get("retries", 0), retryDelay=kwargs.get("retryDelay", 10)) + self.http = HttpLib( + handler, + kwargs.get("verify", False), + key_file=kwargs.get("key_file"), + cert_file=kwargs.get("cert_file"), + context=kwargs.get("context"), + # Default to False for backward compat + retries=kwargs.get("retries", 0), + retryDelay=kwargs.get("retryDelay", 10), + ) self.token = kwargs.get("token", _NoAuthenticationToken) if self.token is None: # In case someone explicitly passes token=None self.token = _NoAuthenticationToken @@ -531,7 +554,10 @@ def __init__(self, handler=None, **kwargs): self._self_signed_certificate = kwargs.get("self_signed_certificate", True) # Store any cookies in the self.http._cookies dict - if "cookie" in kwargs and kwargs['cookie'] not in [None, _NoAuthenticationToken]: + if "cookie" in kwargs and kwargs["cookie"] not in [ + None, + _NoAuthenticationToken, + ]: _parse_cookies(kwargs["cookie"], self.http._cookies) def get_cookies(self): @@ -566,21 +592,23 @@ def _auth_headers(self): if self.has_cookies(): return [("Cookie", _make_cookie_header(list(self.get_cookies().items())))] elif self.basic and (self.username and self.password): - token = f'Basic {b64encode(("%s:%s" % (self.username, self.password)).encode("utf-8")).decode("ascii")}' + token = f"Basic {b64encode(('%s:%s' % (self.username, self.password)).encode('utf-8')).decode('ascii')}" elif self.bearerToken: - token = f'Bearer {self.bearerToken}' + token = f"Bearer {self.bearerToken}" elif self.token is _NoAuthenticationToken: token = [] else: # Ensure the token is properly formatted - if self.token.startswith('Splunk '): + if self.token.startswith("Splunk "): token = self.token else: - token = f'Splunk {self.token}' + token = f"Splunk {self.token}" if token: header.append(("Authorization", token)) if 
self.get_cookies(): - header.append(("Cookie", _make_cookie_header(list(self.get_cookies().items())))) + header.append( + ("Cookie", _make_cookie_header(list(self.get_cookies().items()))) + ) return header @@ -610,7 +638,9 @@ def connect(self): context = ssl.create_default_context() context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 context.check_hostname = not self._self_signed_certificate - context.verify_mode = ssl.CERT_NONE if self._self_signed_certificate else ssl.CERT_REQUIRED + context.verify_mode = ( + ssl.CERT_NONE if self._self_signed_certificate else ssl.CERT_REQUIRED + ) sock = context.wrap_socket(sock, server_hostname=self.host) sock.connect((socket.gethostbyname(self.host), self.port)) return sock @@ -667,15 +697,20 @@ def delete(self, path_segment, owner=None, app=None, sharing=None, **query): c.logout() c.delete('apps/local') # raises AuthenticationError """ - path = self.authority + self._abspath(path_segment, owner=owner, - app=app, sharing=sharing) - logger.debug("DELETE request to %s (body: %s)", path, mask_sensitive_data(query)) + path = self.authority + self._abspath( + path_segment, owner=owner, app=app, sharing=sharing + ) + logger.debug( + "DELETE request to %s (body: %s)", path, mask_sensitive_data(query) + ) response = self.http.delete(path, self._auth_headers, **query) return response @_authentication @_log_duration - def get(self, path_segment, owner=None, app=None, headers=None, sharing=None, **query): + def get( + self, path_segment, owner=None, app=None, headers=None, sharing=None, **query + ): """Performs a GET operation from the REST path segment with the given namespace and query. @@ -730,8 +765,9 @@ def get(self, path_segment, owner=None, app=None, headers=None, sharing=None, ** if headers is None: headers = [] - path = self.authority + self._abspath(path_segment, owner=owner, - app=app, sharing=sharing) + path = self.authority + self._abspath( + path_segment, owner=owner, app=app, sharing=sharing + ) logger.debug("GET request to %s (body: %s)", path, mask_sensitive_data(query)) all_headers = headers + self.additional_headers + self._auth_headers response = self.http.get(path, all_headers, **query) @@ -739,7 +775,9 @@ def get(self, path_segment, owner=None, app=None, headers=None, sharing=None, ** @_authentication @_log_duration - def post(self, path_segment, owner=None, app=None, sharing=None, headers=None, **query): + def post( + self, path_segment, owner=None, app=None, sharing=None, headers=None, **query + ): """Performs a POST operation from the REST path segment with the given namespace and query. @@ -809,7 +847,9 @@ def post(self, path_segment, owner=None, app=None, sharing=None, headers=None, * if headers is None: headers = [] - path = self.authority + self._abspath(path_segment, owner=owner, app=app, sharing=sharing) + path = self.authority + self._abspath( + path_segment, owner=owner, app=app, sharing=sharing + ) logger.debug("POST request to %s (body: %s)", path, mask_sensitive_data(query)) all_headers = headers + self.additional_headers + self._auth_headers @@ -818,8 +858,16 @@ def post(self, path_segment, owner=None, app=None, sharing=None, headers=None, * @_authentication @_log_duration - def request(self, path_segment, method="GET", headers=None, body={}, - owner=None, app=None, sharing=None): + def request( + self, + path_segment, + method="GET", + headers=None, + body={}, + owner=None, + app=None, + sharing=None, + ): """Issues an arbitrary HTTP request to the REST path segment. This method is named to match ``httplib.request``. 
This function @@ -872,27 +920,28 @@ def request(self, path_segment, method="GET", headers=None, body={}, if headers is None: headers = [] - path = self.authority \ - + self._abspath(path_segment, owner=owner, - app=app, sharing=sharing) + path = self.authority + self._abspath( + path_segment, owner=owner, app=app, sharing=sharing + ) all_headers = headers + self.additional_headers + self._auth_headers - logger.debug("%s request to %s (headers: %s, body: %s)", - method, path, str(mask_sensitive_data(dict(all_headers))), mask_sensitive_data(body)) + logger.debug( + "%s request to %s (headers: %s, body: %s)", + method, + path, + str(mask_sensitive_data(dict(all_headers))), + mask_sensitive_data(body), + ) if body: body = _encode(**body) if method == "GET": - path = path + UrlEncoded('?' + body, skip_encode=True) - message = {'method': method, - 'headers': all_headers} + path = path + UrlEncoded("?" + body, skip_encode=True) + message = {"method": method, "headers": all_headers} else: - message = {'method': method, - 'headers': all_headers, - 'body': body} + message = {"method": method, "headers": all_headers, "body": body} else: - message = {'method': method, - 'headers': all_headers} + message = {"method": method, "headers": all_headers} response = self.http.request(path, message) @@ -918,15 +967,15 @@ def login(self): # Then issue requests... """ - if self.has_cookies() and \ - (not self.username and not self.password): + if self.has_cookies() and (not self.username and not self.password): # If we were passed session cookie(s), but no username or # password, then login is a nop, since we're automatically # logged in. return - if self.token is not _NoAuthenticationToken and \ - (not self.username and not self.password): + if self.token is not _NoAuthenticationToken and ( + not self.username and not self.password + ): # If we were passed a session token, but no username or # password, then login is a nop, since we're automatically # logged in. @@ -948,7 +997,8 @@ def login(self): username=self.username, password=self.password, headers=self.additional_headers, - cookie="1") # In Splunk 6.2+, passing "cookie=1" will return the "set-cookie" header + cookie="1", + ) # In Splunk 6.2+, passing "cookie=1" will return the "set-cookie" header body = response.body.read() session = XML(body).findtext("./sessionKey") @@ -966,8 +1016,7 @@ def logout(self): self.http._cookies = {} return self - def _abspath(self, path_segment, - owner=None, app=None, sharing=None): + def _abspath(self, path_segment, owner=None, app=None, sharing=None): """Qualifies *path_segment* into an absolute path for a URL. If *path_segment* is already absolute, returns it unchanged. @@ -1004,7 +1053,7 @@ def _abspath(self, path_segment, skip_encode = isinstance(path_segment, UrlEncoded) # If path_segment is absolute, escape all forbidden characters # in it and return it. - if path_segment.startswith('/'): + if path_segment.startswith("/"): return UrlEncoded(path_segment, skip_encode=skip_encode) # path_segment is relative, so we need a namespace to build an @@ -1023,7 +1072,9 @@ def _abspath(self, path_segment, oname = "nobody" if ns.owner is None else ns.owner aname = "system" if ns.app is None else ns.app - path = UrlEncoded(f"/servicesNS/{oname}/{aname}/{path_segment}", skip_encode=skip_encode) + path = UrlEncoded( + f"/servicesNS/{oname}/{aname}/{path_segment}", skip_encode=skip_encode + ) return path @@ -1136,6 +1187,7 @@ def __init__(self, message, cause): # } # + # Encode the given kwargs as a query string. 
This wrapper will also _encode # a list value as a sequence of assignments to the corresponding arg name, # for example an argument such as 'foo=[1,2,3]' will be encoded as @@ -1155,10 +1207,16 @@ def _spliturl(url): parsed_url = parse.urlparse(url) host = parsed_url.hostname port = parsed_url.port - path = '?'.join((parsed_url.path, parsed_url.query)) if parsed_url.query else parsed_url.path + path = ( + "?".join((parsed_url.path, parsed_url.query)) + if parsed_url.query + else parsed_url.path + ) # Strip brackets if its an IPv6 address - if host.startswith('[') and host.endswith(']'): host = host[1:-1] - if port is None: port = DEFAULT_PORT + if host.startswith("[") and host.endswith("]"): + host = host[1:-1] + if port is None: + port = DEFAULT_PORT return parsed_url.scheme, host, port, path @@ -1207,10 +1265,20 @@ class HttpLib: If using the default handler, SSL verification can be disabled by passing verify=False. """ - def __init__(self, custom_handler=None, verify=False, key_file=None, cert_file=None, context=None, retries=0, - retryDelay=10): + def __init__( + self, + custom_handler=None, + verify=False, + key_file=None, + cert_file=None, + context=None, + retries=0, + retryDelay=10, + ): if custom_handler is None: - self.handler = handler(verify=verify, key_file=key_file, cert_file=cert_file, context=context) + self.handler = handler( + verify=verify, key_file=key_file, cert_file=cert_file, context=context + ) else: self.handler = custom_handler self._cookies = {} @@ -1234,15 +1302,16 @@ def delete(self, url, headers=None, **kwargs): its structure). :rtype: ``dict`` """ - if headers is None: headers = [] + if headers is None: + headers = [] if kwargs: # url is already a UrlEncoded. We have to manually declare # the query to be encoded or it will get automatically URL # encoded by being appended to url. - url = url + UrlEncoded('?' + _encode(**kwargs), skip_encode=True) + url = url + UrlEncoded("?" + _encode(**kwargs), skip_encode=True) message = { - 'method': "DELETE", - 'headers': headers, + "method": "DELETE", + "headers": headers, } return self.request(url, message) @@ -1263,13 +1332,14 @@ def get(self, url, headers=None, **kwargs): its structure). :rtype: ``dict`` """ - if headers is None: headers = [] + if headers is None: + headers = [] if kwargs: # url is already a UrlEncoded. We have to manually declare # the query to be encoded or it will get automatically URL # encoded by being appended to url. - url = url + UrlEncoded('?' + _encode(**kwargs), skip_encode=True) - return self.request(url, {'method': "GET", 'headers': headers}) + url = url + UrlEncoded("?" + _encode(**kwargs), skip_encode=True) + return self.request(url, {"method": "GET", "headers": headers}) def post(self, url, headers=None, **kwargs): """Sends a POST request to a URL. @@ -1289,29 +1359,26 @@ def post(self, url, headers=None, **kwargs): its structure). :rtype: ``dict`` """ - if headers is None: headers = [] + if headers is None: + headers = [] # We handle GET-style arguments and an unstructured body. This is here # to support the receivers/stream endpoint. - if 'body' in kwargs: + if "body" in kwargs: # We only use application/x-www-form-urlencoded if there is no other # Content-Type header present. This can happen in cases where we # send requests as application/json, e.g. for KV Store. 
if len([x for x in headers if x[0].lower() == "content-type"]) == 0: headers.append(("Content-Type", "application/x-www-form-urlencoded")) - body = kwargs.pop('body') + body = kwargs.pop("body") if isinstance(body, dict): - body = _encode(**body).encode('utf-8') + body = _encode(**body).encode("utf-8") if len(kwargs) > 0: - url = url + UrlEncoded('?' + _encode(**kwargs), skip_encode=True) + url = url + UrlEncoded("?" + _encode(**kwargs), skip_encode=True) else: - body = _encode(**kwargs).encode('utf-8') - message = { - 'method': "POST", - 'headers': headers, - 'body': body - } + body = _encode(**kwargs).encode("utf-8") + message = {"method": "POST", "headers": headers, "body": body} return self.request(url, message) def request(self, url, message, **kwargs): @@ -1372,10 +1439,10 @@ class ResponseReader(io.RawIOBase): def __init__(self, response, connection=None): self._response = response self._connection = connection - self._buffer = b'' + self._buffer = b"" def __str__(self): - return str(self.read(), 'UTF-8') + return str(self.read(), "UTF-8") @property def empty(self): @@ -1410,18 +1477,18 @@ def read(self, size=None): """ r = self._buffer - self._buffer = b'' + self._buffer = b"" if size is not None: size -= len(r) r = r + self._response.read(size) return r def readable(self): - """ Indicates that the response reader is readable.""" + """Indicates that the response reader is readable.""" return True def readinto(self, byte_array): - """ Read data into a byte array, upto the size of the byte array. + """Read data into a byte array, upto the size of the byte array. :param byte_array: A byte array/memory view to pour bytes into. :type byte_array: ``bytearray`` or ``memoryview`` @@ -1452,18 +1519,21 @@ def handler(key_file=None, cert_file=None, timeout=None, verify=False, context=N def connect(scheme, host, port): kwargs = {} - if timeout is not None: kwargs['timeout'] = timeout + if timeout is not None: + kwargs["timeout"] = timeout if scheme == "http": return client.HTTPConnection(host, port, **kwargs) if scheme == "https": - if key_file is not None: kwargs['key_file'] = key_file - if cert_file is not None: kwargs['cert_file'] = cert_file + if key_file is not None: + kwargs["key_file"] = key_file + if cert_file is not None: + kwargs["cert_file"] = cert_file if not verify: - kwargs['context'] = ssl._create_unverified_context() # nosemgrep + kwargs["context"] = ssl._create_unverified_context() # nosemgrep elif context: # verify is True in elif branch and context is not None - kwargs['context'] = context + kwargs["context"] = context return client.HTTPSConnection(host, port, **kwargs) raise ValueError(f"unsupported scheme: {scheme}") @@ -1489,7 +1559,10 @@ def request(url, message, **kwargs): if timeout is not None: connection.sock.settimeout(timeout) response = connection.getresponse() - is_keepalive = "keep-alive" in response.getheader("connection", default="close").lower() + is_keepalive = ( + "keep-alive" + in response.getheader("connection", default="close").lower() + ) finally: if not is_keepalive: connection.close() diff --git a/splunklib/client.py b/splunklib/client.py index 15cef7ee4..72cefc262 100644 --- a/splunklib/client.py +++ b/splunklib/client.py @@ -70,9 +70,16 @@ from . 
import data
from .data import record
-from .binding import (AuthenticationError, Context, HTTPError, UrlEncoded,
-                      _encode, _make_cookie_header, _NoAuthenticationToken,
-                      namespace)
+from .binding import (
+    AuthenticationError,
+    Context,
+    HTTPError,
+    UrlEncoded,
+    _encode,
+    _make_cookie_header,
+    _NoAuthenticationToken,
+    namespace,
+)

logger = logging.getLogger(__name__)

@@ -83,7 +90,7 @@
    "IncomparableException",
    "Service",
    "namespace",
-    "AuthenticationError"
+    "AuthenticationError",
]

PATH_APPS = "apps/local/"
@@ -175,7 +182,7 @@ def _trailing(template, *targets):
        n = s.find(t)
        if n == -1:
            raise ValueError("Target " + t + " not found in template.")
-        s = s[n + len(t):]
+        s = s[n + len(t) :]
    return s


@@ -183,13 +190,17 @@ def _filter_content(content, *args):
    if len(args) > 0:
        return record((k, content[k]) for k in args)
-    return record((k, v) for k, v in content.items()
-                  if k not in ['eai:acl', 'eai:attributes', 'type'])
+    return record(
+        (k, v)
+        for k, v in content.items()
+        if k not in ["eai:acl", "eai:attributes", "type"]
+    )


# Construct a resource path from the given base path + resource name
def _path(base, name):
-    if not base.endswith('/'): base = base + '/'
+    if not base.endswith("/"):
+        base = base + "/"
    return base + name
@@ -197,26 +208,27 @@
# this will ultimately be sent to an xml ElementTree so we
# should use the xmlcharrefreplace option
def _load_atom(response, match=None):
-    return data.load(response.body.read()
-                     .decode('utf-8', 'xmlcharrefreplace'), match)
+    return data.load(response.body.read().decode("utf-8", "xmlcharrefreplace"), match)


# Load an array of atom entries from the body of the given response
def _load_atom_entries(response):
    r = _load_atom(response)
-    if 'feed' in r:
+    if "feed" in r:
        # Need this to handle a random case in the REST API
-        if r.feed.get('totalResults') in [0, '0']:
+        if r.feed.get("totalResults") in [0, "0"]:
            return []
-        entries = r.feed.get('entry', None)
-        if entries is None: return None
+        entries = r.feed.get("entry", None)
+        if entries is None:
+            return None
        return entries if isinstance(entries, list) else [entries]
    # Unlike most other endpoints, the jobs endpoint does not return
    # its state wrapped in another element, but at the top level.
    # For example, in XML, it returns <entry>...</entry> instead of
    # <feed><entry>...</entry></feed>.
- entries = r.get('entry', None) - if entries is None: return None + entries = r.get("entry", None) + if entries is None: + return None return entries if isinstance(entries, list) else [entries] @@ -224,63 +236,69 @@ def _load_atom_entries(response): def _load_sid(response, output_mode): if output_mode == "json": json_obj = json.loads(response.body.read()) - return json_obj.get('sid') + return json_obj.get("sid") return _load_atom(response).response.sid # Parse the given atom entry record into a generic entity state record def _parse_atom_entry(entry): - title = entry.get('title', None) + title = entry.get("title", None) - elink = entry.get('link', []) + elink = entry.get("link", []) elink = elink if isinstance(elink, list) else [elink] links = record((link.rel, link.href) for link in elink) # Retrieve entity content values - content = entry.get('content', {}) + content = entry.get("content", {}) # Host entry metadata metadata = _parse_atom_metadata(content) # Filter some of the noise out of the content record - content = record((k, v) for k, v in content.items() - if k not in ['eai:acl', 'eai:attributes']) + content = record( + (k, v) for k, v in content.items() if k not in ["eai:acl", "eai:attributes"] + ) - if 'type' in content: - if isinstance(content['type'], list): - content['type'] = [t for t in content['type'] if t != 'text/xml'] + if "type" in content: + if isinstance(content["type"], list): + content["type"] = [t for t in content["type"] if t != "text/xml"] # Unset type if it was only 'text/xml' - if len(content['type']) == 0: - content.pop('type', None) + if len(content["type"]) == 0: + content.pop("type", None) # Flatten 1 element list - if len(content['type']) == 1: - content['type'] = content['type'][0] + if len(content["type"]) == 1: + content["type"] = content["type"][0] else: - content.pop('type', None) + content.pop("type", None) - return record({ - 'title': title, - 'links': links, - 'access': metadata.access, - 'fields': metadata.fields, - 'content': content, - 'updated': entry.get("updated") - }) + return record( + { + "title": title, + "links": links, + "access": metadata.access, + "fields": metadata.fields, + "content": content, + "updated": entry.get("updated"), + } + ) # Parse the metadata fields out of the given atom entry content record def _parse_atom_metadata(content): # Hoist access metadata - access = content.get('eai:acl', None) + access = content.get("eai:acl", None) # Hoist content metadata (and cleanup some naming) - attributes = content.get('eai:attributes', {}) - fields = record({ - 'required': attributes.get('requiredFields', []), - 'optional': attributes.get('optionalFields', []), - 'wildcard': attributes.get('wildcardFields', [])}) + attributes = content.get("eai:attributes", {}) + fields = record( + { + "required": attributes.get("requiredFields", []), + "optional": attributes.get("optionalFields", []), + "wildcard": attributes.get("wildcardFields", []), + } + ) - return record({'access': access, 'fields': fields}) + return record({"access": access, "fields": fields}) # kwargs: scheme, host, port, app, owner, username, password @@ -533,7 +551,9 @@ def modular_input_kinds(self): """ if self.splunk_version >= (5,): return ReadOnlyCollection(self, PATH_MODULAR_INPUTS, item=ModularInputKind) - raise IllegalOperationException("Modular inputs are not supported before Splunk version 5.") + raise IllegalOperationException( + "Modular inputs are not supported before Splunk version 5." 
+        )

    @property
    def storage_passwords(self):
@@ -583,7 +603,11 @@ def restart(self, timeout=None):
        :param timeout: A timeout period, in seconds.
        :type timeout: ``integer``
        """
-        msg = {"value": "Restart requested by " + self.username + "via the Splunk SDK for Python"}
+        msg = {
+            "value": "Restart requested by "
+            + self.username
+            + " via the Splunk SDK for Python"
+        }
        # This message will be deleted once the server actually restarts.
        self.messages.create(name="restart_required", **msg)
        result = self.post("/services/server/control/restart")
@@ -608,15 +632,15 @@
        """
        response = self.get("messages").body.read()
-        messages = data.load(response)['feed']
-        if 'entry' not in messages:
+        messages = data.load(response)["feed"]
+        if "entry" not in messages:
            result = False
        else:
-            if isinstance(messages['entry'], dict):
-                titles = [messages['entry']['title']]
+            if isinstance(messages["entry"], dict):
+                titles = [messages["entry"]["title"]]
            else:
-                titles = [x['title'] for x in messages['entry']]
-            result = 'restart_required' in titles
+                titles = [x["title"] for x in messages["entry"]]
+            result = "restart_required" in titles
        return result

    @property
@@ -696,24 +720,26 @@ def splunk_version(self):
        :return: A ``tuple`` of ``integers``.
        """
        if self._splunk_version is None:
-            self._splunk_version = tuple(int(p) for p in self.info['version'].split('.'))
+            self._splunk_version = tuple(
+                int(p) for p in self.info["version"].split(".")
+            )
        return self._splunk_version

    @property
    def splunk_instance(self):
-        if self._instance_type is None :
+        if self._instance_type is None:
            splunk_info = self.info
-            if hasattr(splunk_info, 'instance_type') :
-                self._instance_type = splunk_info['instance_type']
+            if hasattr(splunk_info, "instance_type"):
+                self._instance_type = splunk_info["instance_type"]
            else:
-                self._instance_type = ''
+                self._instance_type = ""
        return self._instance_type

    @property
    def disable_v2_api(self):
-        if self.splunk_instance.lower() == 'cloud':
-            return self.splunk_version < (9,0,2209)
-        return self.splunk_version < (9,0,2)
+        if self.splunk_instance.lower() == "cloud":
+            return self.splunk_version < (9, 0, 2209)
+        return self.splunk_version < (9, 0, 2)

    @property
    def kvstore_owner(self):
@@ -742,7 +768,7 @@ def kvstore(self):
        :return: A :class:`KVStoreCollections` collection of :class:`KVStoreCollection` entities.
        """
-        self.namespace['owner'] = self.kvstore_owner
+        self.namespace["owner"] = self.kvstore_owner
        return KVStoreCollections(self)

    @property
@@ -779,7 +805,9 @@ def get_api_version(self, path):
        # For example, "/services/search/jobs" is using API v1
        api_version = 1

-        versionSearch = re.search(r'(?:servicesNS\/[^/]+\/[^/]+|services)\/[^/]+\/v(\d+)\/', path)
+        versionSearch = re.search(
+            r"(?:servicesNS\/[^/]+\/[^/]+|services)\/[^/]+\/v(\d+)\/", path
+        )
        if versionSearch:
            api_version = int(versionSearch.group(1))

@@ -838,13 +866,14 @@ def get(self, path_segment="", owner=None, app=None, sharing=None, **query):
        # self.path to the Endpoint is relative in the SDK, so passing
        # owner, app, sharing, etc. along will produce the correct
        # namespace in the final request.
- if path_segment.startswith('/'): + if path_segment.startswith("/"): path = path_segment else: - if not self.path.endswith('/') and path_segment != "": - self.path = self.path + '/' - path = self.service._abspath(self.path + path_segment, owner=owner, - app=app, sharing=sharing) + if not self.path.endswith("/") and path_segment != "": + self.path = self.path + "/" + path = self.service._abspath( + self.path + path_segment, owner=owner, app=app, sharing=sharing + ) # ^-- This was "%s%s" % (self.path, path_segment). # That doesn't work, because self.path may be UrlEncoded. @@ -859,13 +888,13 @@ def get(self, path_segment="", owner=None, app=None, sharing=None, **query): if api_version == 1: if isinstance(path, UrlEncoded): - path = UrlEncoded(path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True) + path = UrlEncoded( + path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True + ) else: path = path.replace(PATH_JOBS_V2, PATH_JOBS) - return self.service.get(path, - owner=owner, app=app, sharing=sharing, - **query) + return self.service.get(path, owner=owner, app=app, sharing=sharing, **query) def post(self, path_segment="", owner=None, app=None, sharing=None, **query): """Performs a POST operation on the path segment relative to this endpoint. @@ -916,12 +945,14 @@ def post(self, path_segment="", owner=None, app=None, sharing=None, **query): s.logout() apps.get() # raises AuthenticationError """ - if path_segment.startswith('/'): + if path_segment.startswith("/"): path = path_segment else: - if not self.path.endswith('/') and path_segment != "": - self.path = self.path + '/' - path = self.service._abspath(self.path + path_segment, owner=owner, app=app, sharing=sharing) + if not self.path.endswith("/") and path_segment != "": + self.path = self.path + "/" + path = self.service._abspath( + self.path + path_segment, owner=owner, app=app, sharing=sharing + ) # Get the API version from the path api_version = self.get_api_version(path) @@ -934,7 +965,9 @@ def post(self, path_segment="", owner=None, app=None, sharing=None, **query): if api_version == 1: if isinstance(path, UrlEncoded): - path = UrlEncoded(path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True) + path = UrlEncoded( + path.replace(PATH_JOBS_V2, PATH_JOBS), skip_encode=True + ) else: path = path.replace(PATH_JOBS_V2, PATH_JOBS) @@ -971,6 +1004,7 @@ class Entity(Endpoint): does not contact the server. If you think the values on the server have changed, call the :meth:`Entity.refresh` method. """ + # Not every endpoint in the API is an Entity or a Collection. For # example, a saved search at saved/searches/{name} has an additional # method saved/searches/{name}/scheduled_times, but this isn't an @@ -1005,8 +1039,8 @@ class Entity(Endpoint): def __init__(self, service, path, **kwargs): Endpoint.__init__(self, service, path) self._state = None - if not kwargs.get('skip_refresh', False): - self.refresh(kwargs.get('state', None)) # "Prefresh" + if not kwargs.get("skip_refresh", False): + self.refresh(kwargs.get("state", None)) # "Prefresh" def __contains__(self, item): try: @@ -1036,7 +1070,9 @@ def __eq__(self, other): Makes no roundtrips to the server. """ - raise IncomparableException(f"Equality is undefined for objects of class {self.__class__.__name__}") + raise IncomparableException( + f"Equality is undefined for objects of class {self.__class__.__name__}" + ) def __getattr__(self, key): # Called when an attribute was not found by the normal method. 
In this @@ -1058,10 +1094,11 @@ def __getitem__(self, key): def _load_atom_entry(self, response): elem = _load_atom(response, XNAME_ENTRY) if isinstance(elem, list): - apps = [ele.entry.content.get('eai:appName') for ele in elem] + apps = [ele.entry.content.get("eai:appName") for ele in elem] raise AmbiguousReferenceException( - f"Fetch from server returned multiple entries for name '{elem[0].entry.title}' in apps {apps}.") + f"Fetch from server returned multiple entries for name '{elem[0].entry.title}' in apps {apps}." + ) return elem.entry # Load the entity state record from the given response @@ -1096,13 +1133,17 @@ def _proper_namespace(self, owner=None, app=None, sharing=None): :return: """ if owner is None and app is None and sharing is None: # No namespace provided - if self._state is not None and 'access' in self._state: - return (self._state.access.owner, - self._state.access.app, - self._state.access.sharing) - return (self.service.namespace['owner'], - self.service.namespace['app'], - self.service.namespace['sharing']) + if self._state is not None and "access" in self._state: + return ( + self._state.access.owner, + self._state.access.app, + self._state.access.sharing, + ) + return ( + self.service.namespace["owner"], + self.service.namespace["app"], + self.service.namespace["sharing"], + ) return owner, app, sharing def delete(self): @@ -1115,7 +1156,9 @@ def get(self, path_segment="", owner=None, app=None, sharing=None, **query): def post(self, path_segment="", owner=None, app=None, sharing=None, **query): owner, app, sharing = self._proper_namespace(owner, app, sharing) - return super().post(path_segment, owner=owner, app=app, sharing=sharing, **query) + return super().post( + path_segment, owner=owner, app=app, sharing=sharing, **query + ) def refresh(self, state=None): """Refreshes the state of this entity. @@ -1198,14 +1241,15 @@ def name(self): return self.state.title def read(self, response): - """ Reads the current state of the entity from the server. """ + """Reads the current state of the entity from the server.""" results = self._load_state(response) # In lower layers of the SDK, we end up trying to URL encode # text to be dispatched via HTTP. However, these links are already # URL encoded when they arrive, and we need to mark them as such. - unquoted_links = dict((k, UrlEncoded(v, skip_encode=True)) - for k, v in results['links'].items()) - results['links'] = unquoted_links + unquoted_links = dict( + (k, UrlEncoded(v, skip_encode=True)) for k, v in results["links"].items() + ) + results["links"] = unquoted_links return results def reload(self): @@ -1249,7 +1293,8 @@ def state(self): :return: A ``dict`` containing fields and metadata for the entity. """ - if self._state is None: self.refresh() + if self._state is None: + self.refresh() return self._state def update(self, **kwargs): @@ -1282,8 +1327,10 @@ def update(self, **kwargs): # expected behavior of updating this Entity. Therefore, we # check for 'name' in kwargs and throw an error if it is # there. - if 'name' in kwargs: - raise IllegalOperationException('Cannot update the name of an Entity via the REST API.') + if "name" in kwargs: + raise IllegalOperationException( + "Cannot update the name of an Entity via the REST API." 
+            )
        self.post(**kwargs)
        return self

@@ -1375,7 +1422,8 @@ def __getitem__(self, key):
        entries = self._load_list(response)
        if len(entries) > 1:
            raise AmbiguousReferenceException(
-                f"Found multiple entities named '{key}'; please specify a namespace.")
+                f"Found multiple entities named '{key}'; please specify a namespace."
+            )
        if len(entries) == 0:
            raise KeyError(key)
        return entries[0]
@@ -1445,10 +1493,10 @@ def _entity_path(self, state):
        # overloaded by Configurations, which has to switch its
        # entities' endpoints from its own properties/ to configs/.
        raw_path = parse.unquote(state.links.alternate)
-        if 'servicesNS/' in raw_path:
-            return _trailing(raw_path, 'servicesNS/', '/', '/')
-        if 'services/' in raw_path:
-            return _trailing(raw_path, 'services/')
+        if "servicesNS/" in raw_path:
+            return _trailing(raw_path, "servicesNS/", "/", "/")
+        if "services/" in raw_path:
+            return _trailing(raw_path, "services/")
        return raw_path

    def _load_list(self, response):
@@ -1476,14 +1524,12 @@
        # splunkd returns something that doesn't match
        # <feed><entry></entry></feed>.
        entries = _load_atom_entries(response)
-        if entries is None: return []
+        if entries is None:
+            return []
        entities = []
        for entry in entries:
            state = _parse_atom_entry(entry)
-            entity = self.item(
-                self.service,
-                self._entity_path(state),
-                state=state)
+            entity = self.item(self.service, self._entity_path(state), state=state)
            entities.append(entity)
        return entities

@@ -1577,7 +1623,14 @@ def iter(self, offset=0, count=None, pagesize=None, **kwargs):
            if pagesize is None or N < pagesize:
                break
            offset += N
-            logger.debug("pagesize=%d, fetched=%d, offset=%d, N=%d, kwargs=%s", pagesize, fetched, offset, N, kwargs)
+            logger.debug(
+                "pagesize=%d, fetched=%d, offset=%d, N=%d, kwargs=%s",
+                pagesize,
+                fetched,
+                offset,
+                N,
+                kwargs,
+            )

    # kwargs: count, offset, search, sort_dir, sort_key, sort_mode
    def list(self, count=None, **kwargs):
@@ -1687,11 +1740,11 @@ def create(self, name, **params):
        """
        if not isinstance(name, str):
            raise InvalidNameException(f"{name} is not a valid name for an entity.")
-        if 'namespace' in params:
-            namespace = params.pop('namespace')
-            params['owner'] = namespace.owner
-            params['app'] = namespace.app
-            params['sharing'] = namespace.sharing
+        if "namespace" in params:
+            namespace = params.pop("namespace")
+            params["owner"] = namespace.owner
+            params["app"] = namespace.app
+            params["sharing"] = namespace.sharing
        response = self.post(name=name, **params)
        atom = _load_atom(response, XNAME_ENTRY)
        if atom is None:
@@ -1700,10 +1753,7 @@
            return self[name]
        entry = atom.entry
        state = _parse_atom_entry(entry)
-        entity = self.item(
-            self.service,
-            self._entity_path(state),
-            state=state)
+        entity = self.item(self.service, self._entity_path(state), state=state)
        return entity

    def delete(self, name, **params):
@@ -1732,11 +1782,11 @@
            assert 'my_saved_search' not in saved_searches
        """
        name = UrlEncoded(name, encode_slash=True)
-        if 'namespace' in params:
-            namespace = params.pop('namespace')
-            params['owner'] = namespace.owner
-            params['app'] = namespace.app
-            params['sharing'] = namespace.sharing
+        if "namespace" in params:
+            namespace = params.pop("namespace")
+            params["owner"] = namespace.owner
+            params["app"] = namespace.app
+            params["sharing"] = namespace.sharing
        try:
            self.service.delete(_path(self.path, name), **params)
        except HTTPError as he:
@@ -1799,15 +1849,14 @@ def get(self, name="", owner=None, app=None, sharing=None, **query):


class
ConfigurationFile(Collection): - """This class contains all of the stanzas from one configuration file. - """ + """This class contains all of the stanzas from one configuration file.""" # __init__'s arguments must match those of an Entity, not a # Collection, since it is being created as the elements of a # Configurations, which is a Collection subclass. def __init__(self, service, path, **kwargs): Collection.__init__(self, service, path, item=Stanza) - self.name = kwargs['state']['title'] + self.name = kwargs["state"]["title"] class Configurations(Collection): @@ -1821,7 +1870,7 @@ class Configurations(Collection): def __init__(self, service): Collection.__init__(self, service, PATH_PROPERTIES, item=ConfigurationFile) - if self.service.namespace.owner == '-' or self.service.namespace.app == '-': + if self.service.namespace.owner == "-" or self.service.namespace.app == "-": raise ValueError("Configurations cannot have wildcards in namespace.") def __getitem__(self, key): @@ -1834,7 +1883,9 @@ def __getitem__(self, key): # that multiple entities means a name collision, so we have to override it here. try: self.get(key) - return ConfigurationFile(self.service, PATH_CONF % key, state={'title': key}) + return ConfigurationFile( + self.service, PATH_CONF % key, state={"title": key} + ) except HTTPError as he: if he.status == 404: # No entity matching key raise KeyError(key) @@ -1853,7 +1904,7 @@ def __contains__(self, key): raise def create(self, name): - """ Creates a configuration file named *name*. + """Creates a configuration file named *name*. If there is already a configuration file with that name, the existing file is returned. @@ -1872,18 +1923,24 @@ def create(self, name): if response.status == 303: return self[name] if response.status == 201: - return ConfigurationFile(self.service, PATH_CONF % name, item=Stanza, state={'title': name}) - raise ValueError(f"Unexpected status code {response.status} returned from creating a stanza") + return ConfigurationFile( + self.service, PATH_CONF % name, item=Stanza, state={"title": name} + ) + raise ValueError( + f"Unexpected status code {response.status} returned from creating a stanza" + ) def delete(self, key): """Raises `IllegalOperationException`.""" - raise IllegalOperationException("Cannot delete configuration files from the REST API.") + raise IllegalOperationException( + "Cannot delete configuration files from the REST API." + ) def _entity_path(self, state): # Overridden to make all the ConfigurationFile objects # returned refer to the configs/ path instead of the # properties/ path used by Configrations. - return PATH_CONF % state['title'] + return PATH_CONF % state["title"] class Stanza(Entity): @@ -1905,35 +1962,39 @@ def __len__(self): # The stanza endpoint returns all the keys at the same level in the XML as the eai information # and 'disabled', so to get an accurate length, we have to filter those out and have just # the stanza keys. - return len([x for x in self._state.content.keys() - if not x.startswith('eai') and x != 'disabled']) + return len( + [ + x + for x in self._state.content.keys() + if not x.startswith("eai") and x != "disabled" + ] + ) class StoragePassword(Entity): - """This class contains a storage password. 
-    """
+    """This class contains a storage password."""

    def __init__(self, service, path, **kwargs):
-        state = kwargs.get('state', None)
-        kwargs['skip_refresh'] = kwargs.get('skip_refresh', state is not None)
+        state = kwargs.get("state", None)
+        kwargs["skip_refresh"] = kwargs.get("skip_refresh", state is not None)
        super().__init__(service, path, **kwargs)
        self._state = state

    @property
    def clear_password(self):
-        return self.content.get('clear_password')
+        return self.content.get("clear_password")

    @property
    def encrypted_password(self):
-        return self.content.get('encr_password')
+        return self.content.get("encr_password")

    @property
    def realm(self):
-        return self.content.get('realm')
+        return self.content.get("realm")

    @property
    def username(self):
-        return self.content.get('username')
+        return self.content.get("username")


class StoragePasswords(Collection):
@@ -1942,12 +2003,12 @@ class StoragePasswords
    """

    def __init__(self, service):
-        if service.namespace.owner == '-' or service.namespace.app == '-':
+        if service.namespace.owner == "-" or service.namespace.app == "-":
            raise ValueError("StoragePasswords cannot have wildcards in namespace.")
        super().__init__(service, PATH_STORAGE_PASSWORDS, item=StoragePassword)

    def create(self, password, username, realm=None):
-        """ Creates a storage password.
+        """Creates a storage password.

        A `StoragePassword` can be identified by <username>, or by <realm>:<username>
        if the optional realm parameter is also provided.
@@ -1970,11 +2031,15 @@
        response = self.post(password=password, realm=realm, name=username)

        if response.status != 201:
-            raise ValueError(f"Unexpected status code {response.status} returned from creating a stanza")
+            raise ValueError(
+                f"Unexpected status code {response.status} returned from creating a stanza"
+            )

        entries = _load_atom_entries(response)
        state = _parse_atom_entry(entries[0])
-        storage_password = StoragePassword(self.service, self._entity_path(state), state=state, skip_refresh=True)
+        storage_password = StoragePassword(
+            self.service, self._entity_path(state), state=state, skip_refresh=True
+        )

        return storage_password

@@ -1999,7 +2064,11 @@ def delete(self, username, realm=None):
            name = username
        else:
            # Encode each component separately
-            name = UrlEncoded(realm, encode_slash=True) + ":" + UrlEncoded(username, encode_slash=True)
+            name = (
+                UrlEncoded(realm, encode_slash=True)
+                + ":"
+                + UrlEncoded(username, encode_slash=True)
+            )

        # Append the : expected at the end of the name
        if name[-1] != ":":
@@ -2032,7 +2101,7 @@ def count(self):
        :return: The triggered alert count.
        :rtype: ``integer``
        """
-        return int(self.content.get('triggered_alert_count', 0))
+        return int(self.content.get("triggered_alert_count", 0))


class Indexes(Collection):
@@ -2041,16 +2110,16 @@
    """

    def get_default(self):
-        """ Returns the name of the default index.
+        """Returns the name of the default index.

        :return: The name of the default index.

        """
-        index = self['_audit']
-        return index['defaultDatabase']
+        index = self["_audit"]
+        return index["defaultDatabase"]

    def delete(self, name):
-        """ Deletes a given index.
+        """Deletes a given index.

        **Note**: This method is only supported in Splunk 5.0 and later.
@@ -2060,8 +2129,10 @@ def delete(self, name): if self.service.splunk_version >= (5,): Collection.delete(self, name) else: - raise IllegalOperationException("Deleting indexes via the REST API is " - "not supported before Splunk version 5.") + raise IllegalOperationException( + "Deleting indexes via the REST API is " + "not supported before Splunk version 5." + ) class Index(Entity): @@ -2084,13 +2155,22 @@ def attach(self, host=None, source=None, sourcetype=None): :return: A writable socket. """ - args = {'index': self.name} - if host is not None: args['host'] = host - if source is not None: args['source'] = source - if sourcetype is not None: args['sourcetype'] = sourcetype - path = UrlEncoded(PATH_RECEIVERS_STREAM + "?" + parse.urlencode(args), skip_encode=True) + args = {"index": self.name} + if host is not None: + args["host"] = host + if source is not None: + args["source"] = source + if sourcetype is not None: + args["sourcetype"] = sourcetype + path = UrlEncoded( + PATH_RECEIVERS_STREAM + "?" + parse.urlencode(args), skip_encode=True + ) - cookie_header = self.service.token if self.service.token is _NoAuthenticationToken else self.service.token.replace("Splunk ", "") + cookie_header = ( + self.service.token + if self.service.token is _NoAuthenticationToken + else self.service.token.replace("Splunk ", "") + ) cookie_or_auth_header = f"Authorization: Splunk {cookie_header}\r\n" # If we have cookie(s), use them instead of "Authorization: ..." @@ -2102,12 +2182,14 @@ def attach(self, host=None, source=None, sourcetype=None): # the connection open and use the Splunk extension headers to note # the input mode sock = self.service.connect() - headers = [f"POST {str(self.service._abspath(path))} HTTP/1.1\r\n".encode('utf-8'), - f"Host: {self.service.host}:{int(self.service.port)}\r\n".encode('utf-8'), - b"Accept-Encoding: identity\r\n", - cookie_or_auth_header.encode('utf-8'), - b"X-Splunk-Input-Mode: Streaming\r\n", - b"\r\n"] + headers = [ + f"POST {str(self.service._abspath(path))} HTTP/1.1\r\n".encode("utf-8"), + f"Host: {self.service.host}:{int(self.service.port)}\r\n".encode("utf-8"), + b"Accept-Encoding: identity\r\n", + cookie_or_auth_header.encode("utf-8"), + b"X-Splunk-Input-Mode: Streaming\r\n", + b"\r\n", + ] for h in headers: sock.write(h) @@ -2161,8 +2243,8 @@ def clean(self, timeout=60): """ self.refresh() - tds = self['maxTotalDataSizeMB'] - ftp = self['frozenTimePeriodInSecs'] + tds = self["maxTotalDataSizeMB"] + ftp = self["frozenTimePeriodInSecs"] was_disabled_initially = self.disabled try: if not was_disabled_initially and self.service.splunk_version < (5,): @@ -2175,13 +2257,14 @@ def clean(self, timeout=60): # Wait until event count goes to 0. start = datetime.now() diff = timedelta(seconds=timeout) - while self.content.totalEventCount != '0' and datetime.now() < start + diff: + while self.content.totalEventCount != "0" and datetime.now() < start + diff: sleep(1) self.refresh() - if self.content.totalEventCount != '0': + if self.content.totalEventCount != "0": raise OperationError( - f"Cleaning index {self.name} took longer than {timeout} seconds; timing out.") + f"Cleaning index {self.name} took longer than {timeout} seconds; timing out." + ) finally: # Restore original values self.update(maxTotalDataSizeMB=tds, frozenTimePeriodInSecs=ftp) @@ -2213,10 +2296,13 @@ def submit(self, event, host=None, source=None, sourcetype=None): :return: The :class:`Index`. 
""" - args = {'index': self.name} - if host is not None: args['host'] = host - if source is not None: args['source'] = source - if sourcetype is not None: args['sourcetype'] = sourcetype + args = {"index": self.name} + if host is not None: + args["host"] = host + if source is not None: + args["source"] = source + if sourcetype is not None: + args["sourcetype"] = sourcetype self.service.post(PATH_RECEIVERS_SIMPLE, body=event, **args) return self @@ -2236,8 +2322,8 @@ def upload(self, filename, **kwargs): :return: The :class:`Index`. """ - kwargs['index'] = self.name - path = 'data/inputs/oneshot' + kwargs["index"] = self.name + path = "data/inputs/oneshot" self.service.post(path, name=filename, **kwargs) return self @@ -2255,20 +2341,20 @@ def __init__(self, service, path, kind=None, **kwargs): # and "splunktcp" (which is "tcp/cooked"). Entity.__init__(self, service, path, **kwargs) if kind is None: - path_segments = path.split('/') - i = path_segments.index('inputs') + 1 - if path_segments[i] == 'tcp': - self.kind = path_segments[i] + '/' + path_segments[i + 1] + path_segments = path.split("/") + i = path_segments.index("inputs") + 1 + if path_segments[i] == "tcp": + self.kind = path_segments[i] + "/" + path_segments[i + 1] else: self.kind = path_segments[i] else: self.kind = kind # Handle old input kind names. - if self.kind == 'tcp': - self.kind = 'tcp/raw' - if self.kind == 'splunktcp': - self.kind = 'tcp/cooked' + if self.kind == "tcp": + self.kind = "tcp/raw" + if self.kind == "splunktcp": + self.kind = "tcp/cooked" def update(self, **kwargs): """Updates the server with any changes you've made to the current input @@ -2283,7 +2369,7 @@ def update(self, **kwargs): """ # UDP and TCP inputs require special handling due to their restrictToHost # field. For all other inputs kinds, we can dispatch to the superclass method. - if self.kind not in ['tcp', 'splunktcp', 'tcp/raw', 'tcp/cooked', 'udp']: + if self.kind not in ["tcp", "splunktcp", "tcp/raw", "tcp/cooked", "udp"]: return super().update(**kwargs) else: # The behavior of restrictToHost is inconsistent across input kinds and versions of Splunk. @@ -2300,10 +2386,12 @@ def update(self, **kwargs): # cause it to change in Splunk 5.0 and 5.0.1. to_update = kwargs.copy() - if 'restrictToHost' in kwargs: - raise IllegalOperationException("Cannot set restrictToHost on an existing input with the SDK.") - if 'restrictToHost' in self._state.content and self.kind != 'udp': - to_update['restrictToHost'] = self._state.content['restrictToHost'] + if "restrictToHost" in kwargs: + raise IllegalOperationException( + "Cannot set restrictToHost on an existing input with the SDK." + ) + if "restrictToHost" in self._state.content and self.kind != "udp": + to_update["restrictToHost"] = self._state.content["restrictToHost"] # Do the actual update operation. return super().update(**to_update) @@ -2333,7 +2421,9 @@ def __getitem__(self, key): response = self.get(self.kindpath(kind) + "/" + key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException(f"Found multiple inputs of kind {kind} named {key}.") + raise AmbiguousReferenceException( + f"Found multiple inputs of kind {kind} named {key}." 
+ ) if len(entries) == 0: raise KeyError((key, kind)) return entries[0] @@ -2352,13 +2442,18 @@ def __getitem__(self, key): response = self.get(kind + "/" + key) entries = self._load_list(response) if len(entries) > 1: - raise AmbiguousReferenceException(f"Found multiple inputs of kind {kind} named {key}.") + raise AmbiguousReferenceException( + f"Found multiple inputs of kind {kind} named {key}." + ) if len(entries) == 0: pass else: - if candidate is not None: # Already found at least one candidate + if ( + candidate is not None + ): # Already found at least one candidate raise AmbiguousReferenceException( - f"Found multiple inputs named {key}, please specify a kind") + f"Found multiple inputs named {key}, please specify a kind" + ) candidate = entries[0] except HTTPError as he: if he.status == 404: @@ -2441,7 +2536,9 @@ def create(self, name, kind, **kwargs): name = UrlEncoded(name, encode_slash=True) path = _path( self.path + kindpath, - f"{kwargs['restrictToHost']}:{name}" if 'restrictToHost' in kwargs else name + f"{kwargs['restrictToHost']}:{name}" + if "restrictToHost" in kwargs + else name, ) return Input(self.service, path, kind) @@ -2521,16 +2618,16 @@ def _get_kind_list(self, subpath=None): subpath = [] kinds = [] - response = self.get('/'.join(subpath)) + response = self.get("/".join(subpath)) content = _load_atom_entries(response) for entry in content: this_subpath = subpath + [entry.title] # The "all" endpoint doesn't work yet. # The "tcp/ssl" endpoint is not a real input collection. - if entry.title == 'all' or this_subpath == ['tcp', 'ssl']: + if entry.title == "all" or this_subpath == ["tcp", "ssl"]: continue - if 'create' in [x.rel for x in entry.link]: - path = '/'.join(subpath + [entry.title]) + if "create" in [x.rel for x in entry.link]: + path = "/".join(subpath + [entry.title]) kinds.append(path) else: subkinds = self._get_kind_list(subpath + [entry.title]) @@ -2576,10 +2673,10 @@ def kindpath(self, kind): :return: The relative endpoint path. 
:rtype: ``string`` """ - if kind == 'tcp': - return UrlEncoded('tcp/raw', skip_encode=True) - if kind == 'splunktcp': - return UrlEncoded('tcp/cooked', skip_encode=True) + if kind == "tcp": + return UrlEncoded("tcp/raw", skip_encode=True) + if kind == "splunktcp": + return UrlEncoded("tcp/cooked", skip_encode=True) return UrlEncoded(kind, skip_encode=True) def list(self, *kinds, **kwargs): @@ -2664,7 +2761,7 @@ def list(self, *kinds, **kwargs): entities.append(entity) return entities - search = kwargs.get('search', '*') + search = kwargs.get("search", "*") entities = [] for kind in kinds: @@ -2679,7 +2776,8 @@ def list(self, *kinds, **kwargs): raise entries = _load_atom_entries(response) - if entries is None: continue # No inputs to process + if entries is None: + continue # No inputs to process for entry in entries: state = _parse_atom_entry(entry) # Unquote the URL, since all URL encoded in the SDK @@ -2688,25 +2786,25 @@ def list(self, *kinds, **kwargs): path = parse.unquote(state.links.alternate) entity = Input(self.service, path, kind, state=state) entities.append(entity) - if 'offset' in kwargs: - entities = entities[kwargs['offset']:] - if 'count' in kwargs: - entities = entities[:kwargs['count']] - if kwargs.get('sort_mode', None) == 'alpha': - sort_field = kwargs.get('sort_field', 'name') - if sort_field == 'name': + if "offset" in kwargs: + entities = entities[kwargs["offset"] :] + if "count" in kwargs: + entities = entities[: kwargs["count"]] + if kwargs.get("sort_mode", None) == "alpha": + sort_field = kwargs.get("sort_field", "name") + if sort_field == "name": f = lambda x: x.name.lower() else: f = lambda x: x[sort_field].lower() entities = sorted(entities, key=f) - if kwargs.get('sort_mode', None) == 'alpha_case': - sort_field = kwargs.get('sort_field', 'name') - if sort_field == 'name': + if kwargs.get("sort_mode", None) == "alpha_case": + sort_field = kwargs.get("sort_field", "name") + if sort_field == "name": f = lambda x: x.name else: f = lambda x: x[sort_field] entities = sorted(entities, key=f) - if kwargs.get('sort_dir', 'asc') == 'desc': + if kwargs.get("sort_dir", "asc") == "desc": entities = list(reversed(entities)) return entities @@ -2715,7 +2813,7 @@ def __iter__(self, **kwargs): yield item def iter(self, **kwargs): - """ Iterates over the collection of inputs. + """Iterates over the collection of inputs. :param kwargs: Additional arguments (optional): @@ -2739,7 +2837,7 @@ def iter(self, **kwargs): yield item def oneshot(self, path, **kwargs): - """ Creates a oneshot data input, which is an upload of a single file + """Creates a oneshot data input, which is an upload of a single file for one-time indexing. :param path: The path and filename. @@ -2748,7 +2846,7 @@ def oneshot(self, path, **kwargs): available parameters, see `Input parameters `_ on Splunk Developer Portal. :type kwargs: ``dict`` """ - self.post('oneshot', name=path, **kwargs) + self.post("oneshot", name=path, **kwargs) class Job(Entity): @@ -2815,7 +2913,7 @@ def events(self, **kwargs): :return: The ``InputStream`` IO handle to this job's events. 
""" - kwargs['segmentation'] = kwargs.get('segmentation', 'none') + kwargs["segmentation"] = kwargs.get("segmentation", "none") # Search API v1(GET) and v2(POST) if self.service.disable_v2_api: @@ -2838,7 +2936,7 @@ def is_done(self): """ if not self.is_ready(): return False - done = (self._state.content['isDone'] == '1') + done = self._state.content["isDone"] == "1" return done def is_ready(self): @@ -2852,7 +2950,7 @@ def is_ready(self): if response.status == 204: return False self._state = self.read(response) - ready = self._state.content['dispatchState'] not in ['QUEUED', 'PARSING'] + ready = self._state.content["dispatchState"] not in ["QUEUED", "PARSING"] return ready @property @@ -2907,7 +3005,7 @@ def results(self, **query_params): :return: The ``InputStream`` IO handle to this job's results. """ - query_params['segmentation'] = query_params.get('segmentation', 'none') + query_params["segmentation"] = query_params.get("segmentation", "none") # Search API v1(GET) and v2(POST) if self.service.disable_v2_api: @@ -2952,7 +3050,7 @@ def preview(self, **query_params): :return: The ``InputStream`` IO handle to this job's preview results. """ - query_params['segmentation'] = query_params.get('segmentation', 'none') + query_params["segmentation"] = query_params.get("segmentation", "none") # Search API v1(GET) and v2(POST) if self.service.disable_v2_api: @@ -2983,7 +3081,7 @@ def set_priority(self, value): :return: The :class:`Job`. """ - self.post('control', action="setpriority", priority=value) + self.post("control", action="setpriority", priority=value) return self def summary(self, **kwargs): @@ -3060,19 +3158,17 @@ def __init__(self, service): def _load_list(self, response): # Overridden because Job takes a sid instead of a path. entries = _load_atom_entries(response) - if entries is None: return [] + if entries is None: + return [] entities = [] for entry in entries: state = _parse_atom_entry(entry) - entity = self.item( - self.service, - entry['content']['sid'], - state=state) + entity = self.item(self.service, entry["content"]["sid"], state=state) entities.append(entity) return entities def create(self, query, **kwargs): - """ Creates a search using a search query and any additional parameters + """Creates a search using a search query and any additional parameters you provide. :param query: The search query. @@ -3086,7 +3182,9 @@ def create(self, query, **kwargs): :return: The :class:`Job`. """ if kwargs.get("exec_mode", None) == "oneshot": - raise TypeError("Cannot specify exec_mode=oneshot; use the oneshot method instead.") + raise TypeError( + "Cannot specify exec_mode=oneshot; use the oneshot method instead." + ) response = self.post(search=query, **kwargs) sid = _load_sid(response, kwargs.get("output_mode", None)) return Job(self.service, sid) @@ -3132,10 +3230,8 @@ def export(self, query, **params): """ if "exec_mode" in params: raise TypeError("Cannot specify an exec_mode to export.") - params['segmentation'] = params.get('segmentation', 'none') - return self.post(path_segment="export", - search=query, - **params).body + params["segmentation"] = params.get("segmentation", "none") + return self.post(path_segment="export", search=query, **params).body def itemmeta(self): """There is no metadata available for class:``Jobs``. 
@@ -3195,10 +3291,8 @@ def oneshot(self, query, **params): """ if "exec_mode" in params: raise TypeError("Cannot specify an exec_mode to oneshot.") - params['segmentation'] = params.get('segmentation', 'none') - return self.post(search=query, - exec_mode="oneshot", - **params).body + params["segmentation"] = params.get("segmentation", "none") + return self.post(search=query, exec_mode="oneshot", **params).body class Loggers(Collection): @@ -3238,15 +3332,15 @@ class ModularInputKind(Entity): """ def __contains__(self, name): - args = self.state.content['endpoints']['args'] + args = self.state.content["endpoints"]["args"] if name in args: return True return Entity.__contains__(self, name) def __getitem__(self, name): - args = self.state.content['endpoint']['args'] + args = self.state.content["endpoint"]["args"] if name in args: - return args['item'] + return args["item"] return Entity.__getitem__(self, name) @property @@ -3263,11 +3357,13 @@ def arguments(self): :return: A dictionary describing the arguments this modular input kind takes. :rtype: ``dict`` """ - return self.state.content['endpoint']['args'] + return self.state.content["endpoint"]["args"] def update(self, **kwargs): """Raises an error. Modular input kinds are read only.""" - raise IllegalOperationException("Modular input kinds cannot be updated via the REST API.") + raise IllegalOperationException( + "Modular input kinds cannot be updated via the REST API." + ) class SavedSearch(Entity): @@ -3292,7 +3388,7 @@ def alert_count(self): :return: The number of alerts fired by this saved search. :rtype: ``integer`` """ - return int(self._state.content.get('triggered_alert_count', 0)) + return int(self._state.content.get("triggered_alert_count", 0)) def dispatch(self, **kwargs): """Runs the saved search and returns the resulting search job. @@ -3318,15 +3414,20 @@ def fired_alerts(self): :return: A collection of fired alerts. :rtype: :class:`AlertGroup` """ - if self['is_scheduled'] == '0': - raise IllegalOperationException('Unscheduled saved searches have no alerts.') + if self["is_scheduled"] == "0": + raise IllegalOperationException( + "Unscheduled saved searches have no alerts." + ) c = Collection( self.service, - self.service._abspath(PATH_FIRED_ALERTS + self.name, - owner=self._state.access.owner, - app=self._state.access.app, - sharing=self._state.access.sharing), - item=AlertGroup) + self.service._abspath( + PATH_FIRED_ALERTS + self.name, + owner=self._state.access.owner, + app=self._state.access.app, + sharing=self._state.access.sharing, + ), + item=AlertGroup, + ) return c def history(self, **kwargs): @@ -3339,7 +3440,8 @@ def history(self, **kwargs): """ response = self.get("history", **kwargs) entries = _load_atom_entries(response) - if entries is None: return [] + if entries is None: + return [] jobs = [] for entry in entries: job = Job(self.service, entry.title) @@ -3363,11 +3465,12 @@ def update(self, search=None, **kwargs): # Updates to a saved search *require* that the search string be # passed, so we pass the current search string if a value wasn't # provided by the caller. - if search is None: search = self.content.search + if search is None: + search = self.content.search Entity.update(self, search=search, **kwargs) return self - def scheduled_times(self, earliest_time='now', latest_time='+1h'): + def scheduled_times(self, earliest_time="now", latest_time="+1h"): """Returns the times when this search is scheduled to run. By default this method returns the times in the next hour. 
For different @@ -3382,13 +3485,12 @@ def scheduled_times(self, earliest_time='now', latest_time='+1h'): :return: The list of search times. """ - response = self.get("scheduled_times", - earliest_time=earliest_time, - latest_time=latest_time) + response = self.get( + "scheduled_times", earliest_time=earliest_time, latest_time=latest_time + ) data = self._load_atom_entry(response) rec = _parse_atom_entry(data) - times = [datetime.fromtimestamp(int(t)) - for t in rec.content.scheduled_times] + times = [datetime.fromtimestamp(int(t)) for t in rec.content.scheduled_times] return times def suppress(self, expiration): @@ -3430,11 +3532,10 @@ class SavedSearches(Collection): collection using :meth:`Service.saved_searches`.""" def __init__(self, service): - Collection.__init__( - self, service, PATH_SAVED_SEARCHES, item=SavedSearch) + Collection.__init__(self, service, PATH_SAVED_SEARCHES, item=SavedSearch) def create(self, name, search, **kwargs): - """ Creates a saved search. + """Creates a saved search. :param name: The name for the saved search. :type name: ``string`` @@ -3452,6 +3553,7 @@ def create(self, name, search, **kwargs): class Macro(Entity): """This class represents a search macro.""" + def __init__(self, service, path, **kwargs): Entity.__init__(self, service, path, **kwargs) @@ -3461,7 +3563,7 @@ def args(self): :return: The macro arguments. :rtype: ``string`` """ - return self._state.content.get('args', '') + return self._state.content.get("args", "") @property def definition(self): @@ -3469,7 +3571,7 @@ def definition(self): :return: The macro definition. :rtype: ``string`` """ - return self._state.content.get('definition', '') + return self._state.content.get("definition", "") @property def errormsg(self): @@ -3477,7 +3579,7 @@ def errormsg(self): :return: The validation error message for the macro. :rtype: ``string`` """ - return self._state.content.get('errormsg', '') + return self._state.content.get("errormsg", "") @property def iseval(self): @@ -3485,7 +3587,7 @@ def iseval(self): :return: The iseval value for the macro. :rtype: ``string`` """ - return self._state.content.get('iseval', '0') + return self._state.content.get("iseval", "0") def update(self, definition=None, **kwargs): """Updates the server with any changes you've made to the current macro @@ -3500,7 +3602,8 @@ def update(self, definition=None, **kwargs): # Updates to a macro *require* that the definition be # passed, so we pass the current definition if a value wasn't # provided by the caller. - if definition is None: definition = self.content.definition + if definition is None: + definition = self.content.definition Entity.update(self, definition=definition, **kwargs) return self @@ -3510,18 +3613,18 @@ def validation(self): :return: The validation expression for the macro. :rtype: ``string`` """ - return self._state.content.get('validation', '') + return self._state.content.get("validation", "") class Macros(Collection): """This class represents a collection of macros. Retrieve this collection using :meth:`Service.macros`.""" + def __init__(self, service): - Collection.__init__( - self, service, PATH_MACROS, item=Macro) + Collection.__init__(self, service, PATH_MACROS, item=Macro) def create(self, name, definition, **kwargs): - """ Creates a macro. + """Creates a macro. :param name: The name for the macro. :type name: ``string`` :param definition: The macro definition. @@ -3557,8 +3660,7 @@ def update(self, **kwargs): class User(Entity): - """This class represents a Splunk user. 
- """ + """This class represents a Splunk user.""" @property def role_entities(self): @@ -3568,7 +3670,11 @@ def role_entities(self): :rtype: ``list`` """ all_role_names = [r.name for r in self.service.roles.list()] - return [self.service.roles[name] for name in self.content.roles if name in all_role_names] + return [ + self.service.roles[name] + for name in self.content.roles + if name in all_role_names + ] # Splunk automatically lowercases new user names so we need to match that @@ -3627,13 +3733,12 @@ def create(self, username, password, roles, **params): entry = _load_atom(response, XNAME_ENTRY).entry state = _parse_atom_entry(entry) entity = self.item( - self.service, - parse.unquote(state.links.alternate), - state=state) + self.service, parse.unquote(state.links.alternate), state=state + ) return entity def delete(self, name): - """ Deletes the user and returns the resulting collection of users. + """Deletes the user and returns the resulting collection of users. :param name: The name of the user to delete. :type name: ``string`` @@ -3645,8 +3750,7 @@ def delete(self, name): class Role(Entity): - """This class represents a user role. - """ + """This class represents a user role.""" def grant(self, *capabilities_to_grant): """Grants additional capabilities to this role. @@ -3668,7 +3772,7 @@ def grant(self, *capabilities_to_grant): for capability in capabilities_to_grant: if capability not in possible_capabilities: raise NoSuchCapability(capability) - new_capabilities = self['capabilities'] + list(capabilities_to_grant) + new_capabilities = self["capabilities"] + list(capabilities_to_grant) self.post(capabilities=new_capabilities) return self @@ -3693,13 +3797,13 @@ def revoke(self, *capabilities_to_revoke): for capability in capabilities_to_revoke: if capability not in possible_capabilities: raise NoSuchCapability(capability) - old_capabilities = self['capabilities'] + old_capabilities = self["capabilities"] new_capabilities = [] for c in old_capabilities: if c not in capabilities_to_revoke: new_capabilities.append(c) if not new_capabilities: - new_capabilities = '' # Empty lists don't get passed in the body, so we have to force an empty argument. + new_capabilities = "" # Empty lists don't get passed in the body, so we have to force an empty argument. self.post(capabilities=new_capabilities) return self @@ -3752,13 +3856,12 @@ def create(self, name, **params): entry = _load_atom(response, XNAME_ENTRY).entry state = _parse_atom_entry(entry) entity = self.item( - self.service, - parse.unquote(state.links.alternate), - state=state) + self.service, parse.unquote(state.links.alternate), state=state + ) return entity def delete(self, name): - """ Deletes the role and returns the resulting collection of roles. + """Deletes the role and returns the resulting collection of roles. :param name: The name of the role to delete. :type name: ``string`` @@ -3777,10 +3880,10 @@ def setupInfo(self): :return: The setup information. 
""" - return self.content.get('eai:setup', None) + return self.content.get("eai:setup", None) def package(self): - """ Creates a compressed package of the app for archiving.""" + """Creates a compressed package of the app for archiving.""" return self._run_action("package") def updateInfo(self): @@ -3790,7 +3893,9 @@ def updateInfo(self): class KVStoreCollections(Collection): def __init__(self, service): - Collection.__init__(self, service, 'storage/collections/config', item=KVStoreCollection) + Collection.__init__( + self, service, "storage/collections/config", item=KVStoreCollection + ) def __getitem__(self, item): res = Collection.__getitem__(self, item) @@ -3816,9 +3921,9 @@ def create(self, name, accelerated_fields={}, fields={}, **kwargs): for k, v in accelerated_fields.items(): if isinstance(v, dict): v = json.dumps(v) - kwargs['accelerated_fields.' + k] = v + kwargs["accelerated_fields." + k] = v for k, v in fields.items(): - kwargs['field.' + k] = v + kwargs["field." + k] = v return self.post(name=name, **kwargs) @@ -3842,7 +3947,9 @@ def update_accelerated_field(self, name, value): :return: Result of POST request """ kwargs = {} - kwargs['accelerated_fields.' + name] = json.dumps(value) if isinstance(value, dict) else value + kwargs["accelerated_fields." + name] = ( + json.dumps(value) if isinstance(value, dict) else value + ) return self.post(**kwargs) def update_field(self, name, value): @@ -3856,7 +3963,7 @@ def update_field(self, name, value): :return: Result of POST request """ kwargs = {} - kwargs['field.' + name] = value + kwargs["field." + name] = value return self.post(**kwargs) @@ -3865,22 +3972,45 @@ class KVStoreCollectionData: Retrieve using :meth:`KVStoreCollection.data` """ - JSON_HEADER = [('Content-Type', 'application/json')] + + JSON_HEADER = [("Content-Type", "application/json")] def __init__(self, collection): self.service = collection.service self.collection = collection self.owner, self.app, self.sharing = collection._proper_namespace() - self.path = 'storage/collections/data/' + UrlEncoded(self.collection.name, encode_slash=True) + '/' + self.path = ( + "storage/collections/data/" + + UrlEncoded(self.collection.name, encode_slash=True) + + "/" + ) def _get(self, url, **kwargs): - return self.service.get(self.path + url, owner=self.owner, app=self.app, sharing=self.sharing, **kwargs) + return self.service.get( + self.path + url, + owner=self.owner, + app=self.app, + sharing=self.sharing, + **kwargs, + ) def _post(self, url, **kwargs): - return self.service.post(self.path + url, owner=self.owner, app=self.app, sharing=self.sharing, **kwargs) + return self.service.post( + self.path + url, + owner=self.owner, + app=self.app, + sharing=self.sharing, + **kwargs, + ) def _delete(self, url, **kwargs): - return self.service.delete(self.path + url, owner=self.owner, app=self.app, sharing=self.sharing, **kwargs) + return self.service.delete( + self.path + url, + owner=self.owner, + app=self.app, + sharing=self.sharing, + **kwargs, + ) def query(self, **query): """ @@ -3897,7 +4027,7 @@ def query(self, **query): if isinstance(query[key], dict): query[key] = json.dumps(value) - return json.loads(self._get('', **query).body.read().decode('utf-8')) + return json.loads(self._get("", **query).body.read().decode("utf-8")) def query_by_id(self, id): """ @@ -3909,7 +4039,11 @@ def query_by_id(self, id): :return: Document with id :rtype: ``dict`` """ - return json.loads(self._get(UrlEncoded(str(id), encode_slash=True)).body.read().decode('utf-8')) + return json.loads( + 
self._get(UrlEncoded(str(id), encode_slash=True)) + .body.read() + .decode("utf-8") + ) def insert(self, data): """ @@ -3924,7 +4058,10 @@ def insert(self, data): if isinstance(data, dict): data = json.dumps(data) return json.loads( - self._post('', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + self._post("", headers=KVStoreCollectionData.JSON_HEADER, body=data) + .body.read() + .decode("utf-8") + ) def delete(self, query=None): """ @@ -3935,7 +4072,7 @@ def delete(self, query=None): :return: Result of DELETE request """ - return self._delete('', **({'query': query}) if query else {}) + return self._delete("", **({"query": query}) if query else {}) def delete_by_id(self, id): """ @@ -3962,8 +4099,15 @@ def update(self, id, data): """ if isinstance(data, dict): data = json.dumps(data) - return json.loads(self._post(UrlEncoded(str(id), encode_slash=True), headers=KVStoreCollectionData.JSON_HEADER, - body=data).body.read().decode('utf-8')) + return json.loads( + self._post( + UrlEncoded(str(id), encode_slash=True), + headers=KVStoreCollectionData.JSON_HEADER, + body=data, + ) + .body.read() + .decode("utf-8") + ) def batch_find(self, *dbqueries): """ @@ -3976,12 +4120,17 @@ def batch_find(self, *dbqueries): :rtype: ``array`` of ``array`` """ if len(dbqueries) < 1: - raise Exception('Must have at least one query.') + raise Exception("Must have at least one query.") data = json.dumps(dbqueries) return json.loads( - self._post('batch_find', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + self._post( + "batch_find", headers=KVStoreCollectionData.JSON_HEADER, body=data + ) + .body.read() + .decode("utf-8") + ) def batch_save(self, *documents): """ @@ -3994,9 +4143,14 @@ def batch_save(self, *documents): :rtype: ``dict`` """ if len(documents) < 1: - raise Exception('Must have at least one document.') + raise Exception("Must have at least one document.") data = json.dumps(documents) return json.loads( - self._post('batch_save', headers=KVStoreCollectionData.JSON_HEADER, body=data).body.read().decode('utf-8')) + self._post( + "batch_save", headers=KVStoreCollectionData.JSON_HEADER, body=data + ) + .body.read() + .decode("utf-8") + ) diff --git a/splunklib/data.py b/splunklib/data.py index 34f3ffac1..1f026ed83 100644 --- a/splunklib/data.py +++ b/splunklib/data.py @@ -59,8 +59,8 @@ def hasattrs(element): def localname(xname): - rcurly = xname.find('}') - return xname if rcurly == -1 else xname[rcurly + 1:] + rcurly = xname.find("}") + return xname if rcurly == -1 else xname[rcurly + 1 :] def load(text, match=None): @@ -75,13 +75,12 @@ def load(text, match=None): :param match: A tag name or path to match (optional). :type match: ``string`` """ - if text is None: return None + if text is None: + return None text = text.strip() - if len(text) == 0: return None - nametable = { - 'namespaces': [], - 'names': {} - } + if len(text) == 0: + return None + nametable = {"namespaces": [], "names": {}} root = XML(text) items = [root] if match is None else root.findall(match) @@ -95,7 +94,8 @@ def load(text, match=None): # Load the attributes of the given element. 
def load_attrs(element): - if not hasattrs(element): return None + if not hasattrs(element): + return None attrs = record() for key, value in element.attrib.items(): attrs[key] = value @@ -118,8 +118,10 @@ def load_elem(element, nametable=None): name = localname(element.tag) attrs = load_attrs(element) value = load_value(element, nametable) - if attrs is None: return name, value - if value is None: return name, attrs + if attrs is None: + return name, value + if value is None: + return name, attrs # If value is simple, merge into attrs dict using special key if isinstance(value, str): attrs["$text"] = value @@ -151,8 +153,10 @@ def load_list(element, nametable=None): # Load the given root element. def load_root(element, nametable=None): tag = element.tag - if isdict(tag): return load_dict(element, nametable) - if islist(tag): return load_list(element, nametable) + if isdict(tag): + return load_dict(element, nametable) + if islist(tag): + return load_list(element, nametable) k, v = load_elem(element, nametable) return Record.fromkv(k, v) @@ -176,8 +180,10 @@ def load_value(element, nametable=None): if count == 1: child = children[0] tag = child.tag - if isdict(tag): return load_dict(child, nametable) - if islist(tag): return load_list(child, nametable) + if isdict(tag): + return load_dict(child, nametable) + if islist(tag): + return load_list(child, nametable) value = record() for child in children: @@ -213,10 +219,12 @@ class Record(dict): one is placed into a nested dictionary, so you can write ``r.bar.qux`` or ``r['bar.qux']`` interchangeably. """ - sep = '.' + + sep = "." def __call__(self, *args): - if len(args) == 0: return self + if len(args) == 0: + return self return Record((key, self[key]) for key in args) def __getattr__(self, name): @@ -245,8 +253,8 @@ def __getitem__(self, key): for k, v in self.items(): if not k.startswith(key): continue - suffix = k[len(key):] - if '.' in suffix: + suffix = k[len(key) :] + if "." in suffix: ks = suffix.split(self.sep) z = result for x in ks[:-1]: @@ -268,5 +276,6 @@ def record(value=None): :param value: An initial record value. :type value: ``dict`` """ - if value is None: value = {} + if value is None: + value = {} return Record(value) diff --git a/splunklib/modularinput/__init__.py b/splunklib/modularinput/__init__.py index ace954a02..987d1f958 100644 --- a/splunklib/modularinput/__init__.py +++ b/splunklib/modularinput/__init__.py @@ -3,6 +3,7 @@ from splunklib.modularinput import * """ + from .argument import Argument from .event import Event from .event_writer import EventWriter diff --git a/splunklib/modularinput/argument.py b/splunklib/modularinput/argument.py index ec6438750..99203ca25 100644 --- a/splunklib/modularinput/argument.py +++ b/splunklib/modularinput/argument.py @@ -14,8 +14,8 @@ import xml.etree.ElementTree as ET -class Argument: +class Argument: """Class representing an argument to a modular input kind. ``Argument`` is meant to be used with ``Scheme`` to generate an XML @@ -45,8 +45,16 @@ class Argument: data_type_number = "NUMBER" data_type_string = "STRING" - def __init__(self, name, description=None, validation=None, - data_type=data_type_string, required_on_edit=False, required_on_create=False, title=None): + def __init__( + self, + name, + description=None, + validation=None, + data_type=data_type_string, + required_on_edit=False, + required_on_create=False, + title=None, + ): """ :param name: ``string``, identifier for this argument in Splunk. 
:param description: ``string``, human-readable description of the argument. @@ -91,7 +99,7 @@ def add_to_document(self, parent): subelements = [ ("data_type", self.data_type), ("required_on_edit", self.required_on_edit), - ("required_on_create", self.required_on_create) + ("required_on_create", self.required_on_create), ] for name, value in subelements: diff --git a/splunklib/modularinput/event.py b/splunklib/modularinput/event.py index bebd61e46..4d243c753 100644 --- a/splunklib/modularinput/event.py +++ b/splunklib/modularinput/event.py @@ -23,8 +23,19 @@ class Event: To write an input to a stream, call the ``write_to`` function, passing in a stream. """ - def __init__(self, data=None, stanza=None, time=None, host=None, index=None, source=None, - sourcetype=None, done=True, unbroken=True): + + def __init__( + self, + data=None, + stanza=None, + time=None, + host=None, + index=None, + source=None, + sourcetype=None, + done=True, + unbroken=True, + ): """There are no required parameters for constructing an Event **Example with minimal configuration**:: @@ -78,7 +89,9 @@ def write_to(self, stream): :param stream: stream to write XML to. """ if self.data is None: - raise ValueError("Events must have at least the data field set to be written to XML.") + raise ValueError( + "Events must have at least the data field set to be written to XML." + ) event = ET.Element("event") if self.stanza is not None: @@ -95,7 +108,7 @@ def write_to(self, stream): ("sourcetype", self.sourceType), ("index", self.index), ("host", self.host), - ("data", self.data) + ("data", self.data), ] for node, value in subelements: if value is not None: diff --git a/splunklib/modularinput/event_writer.py b/splunklib/modularinput/event_writer.py index 7ea37ca81..51c3cb0fd 100644 --- a/splunklib/modularinput/event_writer.py +++ b/splunklib/modularinput/event_writer.py @@ -33,7 +33,7 @@ class EventWriter: ERROR = "ERROR" FATAL = "FATAL" - def __init__(self, output = sys.stdout, error = sys.stderr): + def __init__(self, output=sys.stdout, error=sys.stderr): """ :param output: Where to write the output; defaults to sys.stdout. :param error: Where to write any errors; defaults to sys.stderr. @@ -76,7 +76,9 @@ def log_exception(self, message, exception=None, severity=None): :param severity: ``string``, severity of message, see severities defined as class constants. Default severity: ERROR """ if exception is not None: - tb_str = traceback.format_exception(type(exception), exception, exception.__traceback__) + tb_str = traceback.format_exception( + type(exception), exception, exception.__traceback__ + ) else: tb_str = traceback.format_exc() diff --git a/splunklib/modularinput/input_definition.py b/splunklib/modularinput/input_definition.py index 190192f7b..9886374ca 100644 --- a/splunklib/modularinput/input_definition.py +++ b/splunklib/modularinput/input_definition.py @@ -15,6 +15,7 @@ import xml.etree.ElementTree as ET from .utils import parse_xml_data + class InputDefinition: """``InputDefinition`` encodes the XML defining inputs that Splunk passes to a modular input script. 
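Tying the modular-input classes above together, a sketch that emits a single event the way a script's stream-events phase would (the stanza name and sourcetype are illustrative assumptions)::

    import sys

    from splunklib.modularinput import Event, EventWriter

    writer = EventWriter(output=sys.stdout, error=sys.stderr)

    event = Event(
        data="hello from a modular input",  # write_to() requires data to be set
        stanza="my_scheme://example",       # assumed stanza name
        sourcetype="my_sourcetype",         # assumed sourcetype
    )
    writer.write_event(event)  # wraps the event XML in the enclosing <stream> document
    writer.close()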
@@ -24,7 +25,8 @@ class InputDefinition: i = InputDefinition() """ - def __init__ (self): + + def __init__(self): self.metadata = {} self.inputs = {} diff --git a/splunklib/modularinput/scheme.py b/splunklib/modularinput/scheme.py index a3b086826..76b13a631 100644 --- a/splunklib/modularinput/scheme.py +++ b/splunklib/modularinput/scheme.py @@ -66,7 +66,7 @@ def to_xml(self): subelements = [ ("use_external_validation", self.use_external_validation), ("use_single_instance", self.use_single_instance), - ("streaming_mode", self.streaming_mode) + ("streaming_mode", self.streaming_mode), ] for name, value in subelements: ET.SubElement(root, name).text = str(value).lower() diff --git a/splunklib/modularinput/script.py b/splunklib/modularinput/script.py index a6c961931..89a08edc2 100644 --- a/splunklib/modularinput/script.py +++ b/splunklib/modularinput/script.py @@ -75,7 +75,8 @@ def run_script(self, args, event_writer, input_stream): if scheme is None: event_writer.log( EventWriter.FATAL, - "Modular input script returned a null scheme.") + "Modular input script returned a null scheme.", + ) return 1 event_writer.write_xml_document(scheme.to_xml()) return 0 @@ -91,8 +92,10 @@ def run_script(self, args, event_writer, input_stream): event_writer.write_xml_document(root) return 1 - event_writer.log(EventWriter.ERROR, "Invalid arguments to modular input script:" + ' '.join( args)) + event_writer.log( + EventWriter.ERROR, + "Invalid arguments to modular input script: " + " ".join(args), + ) return 1 except Exception as e: @@ -101,7 +104,7 @@ def run_script(self, args, event_writer, input_stream): @property def service(self): - """ Returns a Splunk service object for this script invocation. + """Returns a Splunk service object for this script invocation. The service object is created from the Splunkd URI and session key passed to the command invocation on the modular input stream. 
It is diff --git a/splunklib/modularinput/utils.py b/splunklib/modularinput/utils.py index dad73dd07..2218c0d27 100644 --- a/splunklib/modularinput/utils.py +++ b/splunklib/modularinput/utils.py @@ -42,11 +42,16 @@ def xml_compare(expected, found): return False # compare elements, if there is no text node, return True - if (expected.text is None or expected.text.strip() == "") \ - and (found.text is None or found.text.strip() == ""): + if (expected.text is None or expected.text.strip() == "") and ( + found.text is None or found.text.strip() == "" + ): return True - return expected.tag == found.tag and expected.text == found.text \ - and expected.attrib == found.attrib + return ( + expected.tag == found.tag + and expected.text == found.text + and expected.attrib == found.attrib + ) + def parse_parameters(param_node): if param_node.tag == "param": @@ -58,15 +63,14 @@ return parameters raise ValueError(f"Invalid configuration scheme, {param_node.tag} tag unexpected.") + def parse_xml_data(parent_node, child_node_tag): data = {} for child in parent_node: child_name = child.get("name") if child.tag == child_node_tag: if child_node_tag == "stanza": - data[child_name] = { - "__app": child.get("app", None) - } + data[child_name] = {"__app": child.get("app", None)} for param in child: data[child_name][param.get("name")] = parse_parameters(param) elif "item" == parent_node.tag: diff --git a/splunklib/modularinput/validation_definition.py b/splunklib/modularinput/validation_definition.py index b71e1e7c3..c90dc2aae 100644 --- a/splunklib/modularinput/validation_definition.py +++ b/splunklib/modularinput/validation_definition.py @@ -27,6 +27,7 @@ class ValidationDefinition: v = ValidationDefinition() """ + def __init__(self): self.metadata = {} self.parameters = {} diff --git a/splunklib/results.py b/splunklib/results.py index 30476c846..8eed6fe2c 100644 --- a/splunklib/results.py +++ b/splunklib/results.py @@ -40,11 +40,7 @@ from collections import OrderedDict from json import loads as json_loads -__all__ = [ - "ResultsReader", - "Message", - "JSONResultsReader" -] +__all__ = ["ResultsReader", "Message", "JSONResultsReader"] import deprecation @@ -151,7 +147,9 @@ def read(self, n=None): return response -@deprecation.deprecated(details="Use the JSONResultsReader function instead in conjuction with the 'output_mode' query param set to 'json'") +@deprecation.deprecated( + details="Use the JSONResultsReader function instead in conjunction with the 'output_mode' query param set to 'json'" +) class ResultsReader: """This class returns dictionaries and Splunk messages from an XML results stream. @@ -209,37 +207,36 @@ def __iter__(self): def __next__(self): return next(self._gen) - def _parse_results(self, stream): """Parse results and messages out of *stream*.""" result = None values = None try: - for event, elem in et.iterparse(stream, events=('start', 'end')): - if elem.tag == 'results' and event == 'start': + for event, elem in et.iterparse(stream, events=("start", "end")): + if elem.tag == "results" and event == "start": # The wrapper element is a <results>. We # don't care about it except to tell us whether these # are preview results, or the final results from the # search. 
- is_preview = elem.attrib['preview'] == '1' + is_preview = elem.attrib["preview"] == "1" self.is_preview = is_preview - if elem.tag == 'result': - if event == 'start': + if elem.tag == "result": + if event == "start": result = OrderedDict() - elif event == 'end': + elif event == "end": yield result result = None elem.clear() - elif elem.tag == 'field' and result is not None: + elif elem.tag == "field" and result is not None: # We need the 'result is not None' check because # 'field' is also the element name in the <meta> # header that gives field order, which is not what we # want at all. - if event == 'start': + if event == "start": values = [] - elif event == 'end': - field_name = elem.attrib['k'] + elif event == "end": + field_name = elem.attrib["k"] if len(values) == 1: result[field_name] = values[0] else: @@ -251,22 +248,22 @@ # streaming. elem.clear() - elif elem.tag in ('text', 'v') and event == 'end': + elif elem.tag in ("text", "v") and event == "end": text = "".join(elem.itertext()) values.append(text) elem.clear() - elif elem.tag == 'msg': - if event == 'start': - msg_type = elem.attrib['type'] - elif event == 'end': + elif elem.tag == "msg": + if event == "start": + msg_type = elem.attrib["type"] + elif event == "end": text = elem.text if elem.text is not None else "" yield Message(msg_type, text) elem.clear() except SyntaxError as pe: # This is here to handle the same incorrect return from # splunk that is described in __init__. - if 'no element found' in pe.msg: + if "no element found" in pe.msg: return else: raise @@ -327,7 +324,8 @@ def _parse_results(self, stream): text = None for line in stream.readlines(): strip_line = line.strip() - if strip_line.__len__() == 0: continue + if strip_line.__len__() == 0: + continue parsed_line = json_loads(strip_line) if "preview" in parsed_line: self.is_preview = parsed_line["preview"] diff --git a/splunklib/searchcommands/decorators.py b/splunklib/searchcommands/decorators.py index 1393d789a..6d2f7a282 100644 --- a/splunklib/searchcommands/decorators.py +++ b/splunklib/searchcommands/decorators.py @@ -24,7 +24,7 @@ class Configuration: - """ Defines the configuration settings for a search command. + """Defines the configuration settings for a search command. Documents, validates, and ensures that only relevant configuration settings are applied. Adds a :code:`name` class variable to search command classes that don't have one. The :code:`name` is derived from the name of the class. @@ -33,6 +33,7 @@ `__ """ + def __init__(self, o=None, **kwargs): # # The o argument enables the configuration decorator to be used with or without parentheses. For example, it @@ -53,39 +54,40 @@ self.settings = kwargs def __call__(self, o): - if isfunction(o): # We must wait to finalize configuration as the class containing this function is under construction # at the time this call to decorate a member function. This will be handled in the call to # o.ConfigurationSettings.fix_up(o) in the elif clause of this code block. o._settings = self.settings elif isclass(o): - # Set command name name = o.__name__ - if name.endswith('Command'): - name = name[:-len('Command')] + if name.endswith("Command"): + name = name[: -len("Command")] o.name = str(name.lower()) # Construct ConfigurationSettings instance for the command class o.ConfigurationSettings = ConfigurationSettingsType( - module=o.__module__ + '.' 
+ o.__name__, - name='ConfigurationSettings', - bases=(o.ConfigurationSettings,)) + module=o.__module__ + "." + o.__name__, + name="ConfigurationSettings", + bases=(o.ConfigurationSettings,), + ) ConfigurationSetting.fix_up(o.ConfigurationSettings, self.settings) o.ConfigurationSettings.fix_up(o) Option.fix_up(o) else: - raise TypeError(f'Incorrect usage: Configuration decorator applied to {type(o)}') + raise TypeError( + f"Incorrect usage: Configuration decorator applied to {type(o)}" + ) return o class ConfigurationSetting(property): - """ Generates a :class:`property` representing the named configuration setting + """Generates a :class:`property` representing the named configuration setting This is a convenience function designed to reduce the amount of boiler-plate code you must write; most notably for property setters. @@ -105,7 +107,17 @@ class ConfigurationSetting(property): :rtype: property """ - def __init__(self, fget=None, fset=None, fdel=None, doc=None, name=None, readonly=None, value=None): + + def __init__( + self, + fget=None, + fset=None, + fdel=None, + doc=None, + name=None, + readonly=None, + value=None, + ): property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) self._readonly = readonly self._value = value @@ -125,23 +137,22 @@ def setter(self, function): @staticmethod def fix_up(cls, values): - - is_configuration_setting = lambda attribute: isinstance(attribute, ConfigurationSetting) + is_configuration_setting = lambda attribute: isinstance( + attribute, ConfigurationSetting + ) definitions = getmembers(cls, is_configuration_setting) i = 0 for name, setting in definitions: - if setting._name is None: setting._name = name = str(name) else: name = setting._name validate, specification = setting._get_specification() - backing_field_name = '_' + name + backing_field_name = "_" + name if setting.fget is None and setting.fset is None and setting.fdel is None: - value = setting._value if setting._readonly or value is not None: @@ -155,14 +166,17 @@ def fget(bfn, value): if not setting._readonly: def fset(bfn, validate, specification, name): - return lambda this, value: setattr(this, bfn, validate(specification, name, value)) + return lambda this, value: setattr( + this, bfn, validate(specification, name, value) + ) - setting = setting.setter(fset(backing_field_name, validate, specification, name)) + setting = setting.setter( + fset(backing_field_name, validate, specification, name) + ) setattr(cls, name, setting) def is_supported_by_protocol(supporting_protocols): - def is_supported_by_protocol(version): return version in supporting_protocols @@ -170,7 +184,9 @@ def is_supported_by_protocol(version): del setting._name, setting._value, setting._readonly - setting.is_supported_by_protocol = is_supported_by_protocol(specification.supporting_protocols) + setting.is_supported_by_protocol = is_supported_by_protocol( + specification.supporting_protocols + ) setting.supporting_protocols = specification.supporting_protocols setting.backing_field_name = backing_field_name definitions[i] = setting @@ -184,15 +200,17 @@ def is_supported_by_protocol(version): continue if setting.fset is None: - raise ValueError(f'The value of configuration setting {name} is fixed') + raise ValueError(f"The value of configuration setting {name} is fixed") setattr(cls, backing_field_name, validate(specification, name, value)) del values[name] if len(values) > 0: settings = sorted(list(values.items())) - settings = [f'{n_v[0]}={n_v[1]}' for n_v in settings] - raise AttributeError('Inapplicable 
configuration settings: ' + ', '.join(settings)) + settings = [f"{n_v[0]}={n_v[1]}" for n_v in settings] + raise AttributeError( + "Inapplicable configuration settings: " + ", ".join(settings) + ) cls.configuration_setting_definitions = definitions @@ -203,19 +221,20 @@ def _copy_extra_attributes(self, other): return other def _get_specification(self): - name = self._name try: specification = ConfigurationSettingsType.specification_matrix[name] except KeyError: - raise AttributeError(f'Unknown configuration setting: {name}={repr(self._value)}') + raise AttributeError( + f"Unknown configuration setting: {name}={repr(self._value)}" + ) return ConfigurationSettingsType.validate_configuration_setting, specification class Option(property): - """ Represents a search command option. + """Represents a search command option. Required options must be specified on the search command line. @@ -267,7 +286,18 @@ def __init__(self) self._logging_configuration = None """ - def __init__(self, fget=None, fset=None, fdel=None, doc=None, name=None, default=None, require=None, validate=None): + + def __init__( + self, + fget=None, + fset=None, + fdel=None, + doc=None, + name=None, + default=None, + require=None, + validate=None, + ): property.__init__(self, fget, fset, fdel, doc) self.name = name self.default = default @@ -290,21 +320,19 @@ def setter(self, function): @classmethod def fix_up(cls, command_class): - is_option = lambda attribute: isinstance(attribute, Option) definitions = getmembers(command_class, is_option) validate_option_name = OptionName() i = 0 for name, option in definitions: - if option.name is None: option.name = name # no validation required else: validate_option_name(option.name) if option.fget is None and option.fset is None and option.fdel is None: - backing_field_name = '_' + name + backing_field_name = "_" + name def fget(bfn): return lambda this: getattr(this, bfn, None) @@ -344,11 +372,12 @@ def _copy_extra_attributes(self, other): # region Types class Item: - """ Presents an instance/class view over a search command `Option`. + """Presents an instance/class view over a search command `Option`. This class is used by SearchCommand.process to parse and report on option values. """ + def __init__(self, command, option): self._command = command self._option = option @@ -357,12 +386,12 @@ def __init__(self, command, option): self._format = str if validator is None else validator.format def __repr__(self): - return '(' + repr(self.name) + ', ' + repr(self._format(self.value)) + ')' + return "(" + repr(self.name) + ", " + repr(self._format(self.value)) + ")" def __str__(self): value = self.value - value = 'None' if value is None else json_encode_string(self._format(value)) - return self.name + '=' + value + value = "None" if value is None else json_encode_string(self._format(value)) + return self.name + "=" + value # region Properties @@ -372,9 +401,7 @@ def is_required(self): @property def is_set(self): - """ Indicates whether an option value was provided as argument. - - """ + """Indicates whether an option value was provided as argument.""" return self._is_set @property @@ -405,28 +432,43 @@ def reset(self): # endregion class View(OrderedDict): - """ Presents an ordered dictionary view of the set of :class:`Option` arguments to a search command. + """Presents an ordered dictionary view of the set of :class:`Option` arguments to a search command. This class is used by SearchCommand.process to parse and report on option values. 
""" + def __init__(self, command): definitions = type(command).option_definitions item_class = Option.Item - OrderedDict.__init__(self, ((option.name, item_class(command, option)) for (name, option) in definitions)) + OrderedDict.__init__( + self, + ( + (option.name, item_class(command, option)) + for (name, option) in definitions + ), + ) def __repr__(self): - text = 'Option.View([' + ','.join([repr(item) for item in self.values()]) + '])' + text = ( + "Option.View([" + + ",".join([repr(item) for item in self.values()]) + + "])" + ) return text def __str__(self): - text = ' '.join([str(item) for item in self.values() if item.is_set]) + text = " ".join([str(item) for item in self.values() if item.is_set]) return text # region Methods def get_missing(self): - missing = [item.name for item in self.values() if item.is_required and not item.is_set] + missing = [ + item.name + for item in self.values() + if item.is_required and not item.is_set + ] return missing if len(missing) > 0 else None def reset(self): @@ -435,8 +477,7 @@ def reset(self): # endregion - # endregion -__all__ = ['Configuration', 'Option'] +__all__ = ["Configuration", "Option"] diff --git a/splunklib/searchcommands/environment.py b/splunklib/searchcommands/environment.py index 35f1deaf3..7f8cb6d3f 100644 --- a/splunklib/searchcommands/environment.py +++ b/splunklib/searchcommands/environment.py @@ -15,16 +15,14 @@ # under the License. - from logging import getLogger, root, StreamHandler from logging.config import fileConfig from os import chdir, environ, path, getcwd import sys - def configure_logging(logger_name, filename=None): - """ Configure logging and return the named logger and the location of the logging configuration file loaded. + """Configure logging and return the named logger and the location of the logging configuration file loaded. 
This function expects a Splunk app directory structure:: @@ -66,13 +64,17 @@ """ if filename is None: if logger_name is None: - probing_paths = [path.join('local', 'logging.conf'), path.join('default', 'logging.conf')] + probing_paths = [ + path.join("local", "logging.conf"), + path.join("default", "logging.conf"), + ] else: probing_paths = [ - path.join('local', logger_name + '.logging.conf'), - path.join('default', logger_name + '.logging.conf'), - path.join('local', 'logging.conf'), - path.join('default', 'logging.conf')] + path.join("local", logger_name + ".logging.conf"), + path.join("default", logger_name + ".logging.conf"), + path.join("local", "logging.conf"), + path.join("default", "logging.conf"), + ] for relative_path in probing_paths: configuration_file = path.join(app_root, relative_path) if path.exists(configuration_file): @@ -80,14 +82,16 @@ break elif not path.isabs(filename): found = False - for conf in 'local', 'default': + for conf in "local", "default": configuration_file = path.join(app_root, conf, filename) if path.exists(configuration_file): filename = configuration_file found = True break if not found: - raise ValueError(f'Logging configuration file "{filename}" not found in local or default directory') + raise ValueError( + f'Logging configuration file "{filename}" not found in local or default directory' + ) elif not path.exists(filename): raise ValueError(f'Logging configuration file "{filename}" not found') @@ -99,7 +103,7 @@ working_directory = getcwd() chdir(app_root) try: - fileConfig(filename, {'SPLUNK_HOME': splunk_home}) + fileConfig(filename, {"SPLUNK_HOME": splunk_home}) finally: chdir(working_directory) _current_logging_configuration_file = filename @@ -112,11 +116,17 @@ _current_logging_configuration_file = None -splunk_home = path.abspath(path.join(getcwd(), environ.get('SPLUNK_HOME', ''))) -app_file = getattr(sys.modules['__main__'], '__file__', sys.executable) +splunk_home = path.abspath(path.join(getcwd(), environ.get("SPLUNK_HOME", ""))) +app_file = getattr(sys.modules["__main__"], "__file__", sys.executable) app_root = path.dirname(path.abspath(path.dirname(app_file))) -splunklib_logger, logging_configuration = configure_logging('splunklib') +splunklib_logger, logging_configuration = configure_logging("splunklib") -__all__ = ['app_file', 'app_root', 'logging_configuration', 'splunk_home', 'splunklib_logger'] +__all__ = [ + "app_file", + "app_root", + "logging_configuration", + "splunk_home", + "splunklib_logger", +] diff --git a/splunklib/searchcommands/eventing_command.py b/splunklib/searchcommands/eventing_command.py index d42d056df..d9f90b780 100644 --- a/splunklib/searchcommands/eventing_command.py +++ b/splunklib/searchcommands/eventing_command.py @@ -15,13 +15,12 @@ # under the License. - from .decorators import ConfigurationSetting from .search_command import SearchCommand class EventingCommand(SearchCommand): - """ Applies a transformation to search results as they travel through the events pipeline. + """Applies a transformation to search results as they travel through the events pipeline. Eventing commands typically filter, group, order, and/or augment event records. Examples of eventing commands from Splunk's built-in command set include sort_, dedup_, and cluster_. 
Each execution of an eventing command @@ -38,15 +37,16 @@ class EventingCommand(SearchCommand): Splunk 6.3 or later. """ + # region Methods def transform(self, records): - """ Generator function that processes and yields event records to the Splunk events pipeline. + """Generator function that processes and yields event records to the Splunk events pipeline. You must override this method. """ - raise NotImplementedError('EventingCommand.transform(self, records)') + raise NotImplementedError("EventingCommand.transform(self, records)") def _execute(self, ifile, process): SearchCommand._execute(self, ifile, self.transform) @@ -54,12 +54,12 @@ def _execute(self, ifile, process): # endregion class ConfigurationSettings(SearchCommand.ConfigurationSettings): - """ Represents the configuration settings that apply to a :class:`EventingCommand`. + """Represents the configuration settings that apply to a :class:`EventingCommand`.""" - """ # region SCP v1/v2 properties - required_fields = ConfigurationSetting(doc=''' + required_fields = ConfigurationSetting( + doc=""" List of required fields for this search which back-propagates to the generating search. Setting this value enables selected fields mode under SCP 2. Under SCP 1 you must also specify @@ -68,13 +68,15 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Default: :const:`None`, which implicitly selects all fields. - ''') + """ + ) # endregion # region SCP v1 properties - clear_required_fields = ConfigurationSetting(doc=''' + clear_required_fields = ConfigurationSetting( + doc=""" :const:`True`, if required_fields represent the *only* fields required. If :const:`False`, required_fields are additive to any fields that may be required by subsequent commands. @@ -82,22 +84,28 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Default: :const:`False` - ''') + """ + ) - retainsevents = ConfigurationSetting(readonly=True, value=True, doc=''' + retainsevents = ConfigurationSetting( + readonly=True, + value=True, + doc=""" :const:`True`, if the command retains events the way the sort/dedup/cluster commands do. If :const:`False`, the command transforms events the way the stats command does. Fixed: :const:`True` - ''') + """, + ) # endregion # region SCP v2 properties - maxinputs = ConfigurationSetting(doc=''' + maxinputs = ConfigurationSetting( + doc=""" Specifies the maximum number of events that can be passed to the command for each invocation. This limit cannot exceed the value of `maxresultrows` as defined in limits.conf_. Under SCP 1 you must @@ -109,16 +117,21 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): .. _limits.conf: http://docs.splunk.com/Documentation/Splunk/latest/admin/Limitsconf - ''') + """ + ) - type = ConfigurationSetting(readonly=True, value='events', doc=''' + type = ConfigurationSetting( + readonly=True, + value="events", + doc=""" Command type Fixed: :const:`'events'`. Supported by: SCP 2 - ''') + """, + ) # endregion @@ -126,18 +139,19 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): @classmethod def fix_up(cls, command): - """ Verifies :code:`command` class structure. 
- - """ + """Verifies :code:`command` class structure.""" if command.transform == EventingCommand.transform: - raise AttributeError('No EventingCommand.transform override') + raise AttributeError("No EventingCommand.transform override") SearchCommand.ConfigurationSettings.fix_up(command) # TODO: Stop looking like a dictionary because we don't obey the semantics # N.B.: Does not use Python 2 dict copy semantics def iteritems(self): iteritems = SearchCommand.ConfigurationSettings.iteritems(self) - return [(name_value[0], 'events' if name_value[0] == 'type' else name_value[1]) for name_value in iteritems] + return [ + (name_value[0], "events" if name_value[0] == "type" else name_value[1]) + for name_value in iteritems + ] # N.B.: Does not use Python 3 dict view semantics diff --git a/splunklib/searchcommands/external_search_command.py b/splunklib/searchcommands/external_search_command.py index a8929f8d3..cceeb5083 100644 --- a/splunklib/searchcommands/external_search_command.py +++ b/splunklib/searchcommands/external_search_command.py @@ -21,21 +21,19 @@ from . import splunklib_logger as logger -if sys.platform == 'win32': +if sys.platform == "win32": from signal import signal, CTRL_BREAK_EVENT, SIGBREAK, SIGINT, SIGTERM from subprocess import Popen import atexit - # P1 [ ] TODO: Add ExternalSearchCommand class documentation class ExternalSearchCommand: def __init__(self, path, argv=None, environ=None): - - if not isinstance(path, (bytes,str)): - raise ValueError(f'Expected a string value for path, not {repr(path)}') + if not isinstance(path, (bytes, str)): + raise ValueError(f"Expected a string value for path, not {repr(path)}") self._logger = getLogger(self.__class__.__name__) self._path = str(path) @@ -49,22 +47,26 @@ def __init__(self, path, argv=None, environ=None): @property def argv(self): - return getattr(self, '_argv') + return getattr(self, "_argv") @argv.setter def argv(self, value): if not (value is None or isinstance(value, (list, tuple))): - raise ValueError(f'Expected a list, tuple or value of None for argv, not {repr(value)}') + raise ValueError( + f"Expected a list, tuple or value of None for argv, not {repr(value)}" + ) self._argv = value @property def environ(self): - return getattr(self, '_environ') + return getattr(self, "_environ") @environ.setter def environ(self, value): if not (value is None or isinstance(value, dict)): - raise ValueError(f'Expected a dictionary value for environ, not {repr(value)}') + raise ValueError( + f"Expected a dictionary value for environ, not {repr(value)}" + ) self._environ = value @property @@ -87,15 +89,17 @@ def execute(self): self._execute(self._path, self._argv, self._environ) except: error_type, error, tb = sys.exc_info() - message = f'Command execution failed: {str(error)}' - self._logger.error(message + '\nTraceback:\n' + ''.join(traceback.format_tb(tb))) + message = f"Command execution failed: {str(error)}" + self._logger.error( + message + "\nTraceback:\n" + "".join(traceback.format_tb(tb)) + ) sys.exit(1) - if sys.platform == 'win32': + if sys.platform == "win32": @staticmethod def _execute(path, argv=None, environ=None): - """ Executes an external search command. + """Executes an external search command. :param path: Path to the external search command. 
:type path: unicode @@ -113,40 +117,62 @@ def _execute(path, argv=None, environ=None): :return: None """ - search_path = os.getenv('PATH') if environ is None else environ.get('PATH') + search_path = os.getenv("PATH") if environ is None else environ.get("PATH") found = ExternalSearchCommand._search_path(path, search_path) if found is None: - raise ValueError(f'Cannot find command on path: {path}') + raise ValueError(f"Cannot find command on path: {path}") path = found logger.debug(f'starting command="{path}", arguments={argv}') def terminate(signal_number): - sys.exit(f'External search command is terminating on receipt of signal={signal_number}.') + sys.exit( + f"External search command is terminating on receipt of signal={signal_number}." + ) def terminate_child(): if p.pid is not None and p.returncode is None: - logger.debug('terminating command="%s", arguments=%d, pid=%d', path, argv, p.pid) + logger.debug( + 'terminating command="%s", arguments=%d, pid=%d', + path, + argv, + p.pid, + ) os.kill(p.pid, CTRL_BREAK_EVENT) - p = Popen(argv, executable=path, env=environ, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr) + p = Popen( + argv, + executable=path, + env=environ, + stdin=sys.stdin, + stdout=sys.stdout, + stderr=sys.stderr, + ) atexit.register(terminate_child) signal(SIGBREAK, terminate) signal(SIGINT, terminate) signal(SIGTERM, terminate) - logger.debug('started command="%s", arguments=%s, pid=%d', path, argv, p.pid) + logger.debug( + 'started command="%s", arguments=%s, pid=%d', path, argv, p.pid + ) p.wait() - logger.debug('finished command="%s", arguments=%s, pid=%d, returncode=%d', path, argv, p.pid, p.returncode) + logger.debug( + 'finished command="%s", arguments=%s, pid=%d, returncode=%d', + path, + argv, + p.pid, + p.returncode, + ) if p.returncode != 0: sys.exit(p.returncode) @staticmethod def _search_path(executable, paths): - """ Locates an executable program file. + """Locates an executable program file. :param executable: The name of the executable program to locate. :type executable: unicode @@ -174,7 +200,9 @@ def _search_path(executable, paths): if not paths: return None - directories = [directory for directory in paths.split(';') if len(directory)] + directories = [ + directory for directory in paths.split(";") if len(directory) + ] if len(directories) == 0: return None @@ -195,8 +223,9 @@ def _search_path(executable, paths): return None - _executable_extensions = ('.COM', '.EXE') + _executable_extensions = (".COM", ".EXE") else: + @staticmethod def _execute(path, argv, environ): if environ is None: diff --git a/splunklib/searchcommands/generating_command.py b/splunklib/searchcommands/generating_command.py index 36b014c3e..d2d129316 100644 --- a/splunklib/searchcommands/generating_command.py +++ b/splunklib/searchcommands/generating_command.py @@ -24,7 +24,7 @@ class GeneratingCommand(SearchCommand): - """ Generates events based on command arguments. + """Generates events based on command arguments. Generating commands receive no input and must be the first command on a pipeline. There are three pipelines: streams, events, and reports. The streams pipeline generates or processes time-ordered event records on an @@ -182,18 +182,19 @@ class SomeCommand(GeneratingCommand) streaming = false """ + # region Methods def generate(self): - """ A generator that yields records to the Splunk processing pipeline + """A generator that yields records to the Splunk processing pipeline You must override this method. 
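# Editor's note: a minimal, hypothetical sketch of the generate() override
# documented here. The command name and event shape are invented; a
# GeneratingCommand subclass simply yields one dict per generated event.

import sys
import time

from splunklib.searchcommands import Configuration, GeneratingCommand, dispatch


@Configuration()
class RepeatCommand(GeneratingCommand):
    """Generates five synthetic events and then stops."""

    def generate(self):
        for event_no in range(5):
            yield {"_time": time.time(), "event_no": event_no, "_raw": f"event {event_no}"}


dispatch(RepeatCommand, sys.argv, sys.stdin, sys.stdout, __name__)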
""" - raise NotImplementedError('GeneratingCommand.generate(self)') + raise NotImplementedError("GeneratingCommand.generate(self)") def _execute(self, ifile, process): - """ Execution loop + """Execution loop :param ifile: Input file object. Unused. :type ifile: file @@ -225,8 +226,10 @@ def _execute_chunk_v2(self, process, chunk): else: self._finished = True - def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_input=True): - """ Process data. + def process( + self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_input=True + ): + """Process data. :param argv: Command line arguments. :type argv: list or tuple @@ -250,20 +253,26 @@ def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_ # so ensure that allow_empty_input is always True if not allow_empty_input: - raise ValueError("allow_empty_input cannot be False for Generating Commands") - return super().process(argv=argv, ifile=ifile, ofile=ofile, allow_empty_input=True) + raise ValueError( + "allow_empty_input cannot be False for Generating Commands" + ) + return super().process( + argv=argv, ifile=ifile, ofile=ofile, allow_empty_input=True + ) # endregion # region Types class ConfigurationSettings(SearchCommand.ConfigurationSettings): - """ Represents the configuration settings for a :code:`GeneratingCommand` class. + """Represents the configuration settings for a :code:`GeneratingCommand` class.""" - """ # region SCP v1/v2 Properties - generating = ConfigurationSetting(readonly=True, value=True, doc=''' + generating = ConfigurationSetting( + readonly=True, + value=True, + doc=""" Tells Splunk that this command generates events, but does not process inputs. Generating commands must appear at the front of the search pipeline identified by :meth:`type`. @@ -272,31 +281,37 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 1, SCP 2 - ''') + """, + ) # endregion # region SCP v1 Properties - generates_timeorder = ConfigurationSetting(doc=''' + generates_timeorder = ConfigurationSetting( + doc=""" :const:`True`, if the command generates new events. Default: :const:`False` Supported by: SCP 1 - ''') + """ + ) - local = ConfigurationSetting(doc=''' + local = ConfigurationSetting( + doc=""" :const:`True`, if the command should run locally on the search head. Default: :const:`False` Supported by: SCP 1 - ''') + """ + ) - retainsevents = ConfigurationSetting(doc=''' + retainsevents = ConfigurationSetting( + doc=""" :const:`True`, if the command retains events the way the sort, dedup, and cluster commands do, or whether it transforms them the way the stats command does. @@ -304,22 +319,27 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 1 - ''') + """ + ) - streaming = ConfigurationSetting(doc=''' + streaming = ConfigurationSetting( + doc=""" :const:`True`, if the command is streamable. Default: :const:`True` Supported by: SCP 1 - ''') + """ + ) # endregion # region SCP v2 Properties - distributed = ConfigurationSetting(value=False, doc=''' + distributed = ConfigurationSetting( + value=False, + doc=""" True, if this command should be distributed to indexers. This value is ignored unless :meth:`type` is equal to :const:`streaming`. 
It is only this command type that @@ -329,9 +349,12 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 2 - ''') + """, + ) - type = ConfigurationSetting(value='streaming', doc=''' + type = ConfigurationSetting( + value="streaming", + doc=""" A command type name. ==================== ====================================================================================== @@ -346,7 +369,8 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 2 - ''') + """, + ) # endregion @@ -354,11 +378,9 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): @classmethod def fix_up(cls, command): - """ Verifies :code:`command` class structure. - - """ + """Verifies :code:`command` class structure.""" if command.generate == GeneratingCommand.generate: - raise AttributeError('No GeneratingCommand.generate override') + raise AttributeError("No GeneratingCommand.generate override") # TODO: Stop looking like a dictionary because we don't obey the semantics # N.B.: Does not use Python 2 dict copy semantics @@ -366,9 +388,18 @@ def iteritems(self): iteritems = SearchCommand.ConfigurationSettings.iteritems(self) version = self.command.protocol_version if version == 2: - iteritems = [name_value1 for name_value1 in iteritems if name_value1[0] != 'distributed'] - if not self.distributed and self.type == 'streaming': - iteritems = [(name_value[0], 'stateful') if name_value[0] == 'type' else (name_value[0], name_value[1]) for name_value in iteritems] + iteritems = [ + name_value1 + for name_value1 in iteritems + if name_value1[0] != "distributed" + ] + if not self.distributed and self.type == "streaming": + iteritems = [ + (name_value[0], "stateful") + if name_value[0] == "type" + else (name_value[0], name_value[1]) + for name_value in iteritems + ] return iteritems # N.B.: Does not use Python 3 dict view semantics diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index 5f20c3faf..40b9107c9 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -29,15 +29,15 @@ from json.encoder import encode_basestring_ascii as json_encode_string - - from . import environment -csv.field_size_limit(10485760) # The default value is 128KB; upping to 10MB. See SPL-12117 for background on this issue +csv.field_size_limit( + 10485760 +) # The default value is 128KB; upping to 10MB. See SPL-12117 for background on this issue def set_binary_mode(fh): - """ Helper method to set up binary mode for file handles. + """Helper method to set up binary mode for file handles. Emphasis being sys.stdin, sys.stdout, sys.stderr. For python3, we want to return .buffer """ @@ -47,13 +47,13 @@ def set_binary_mode(fh): return fh # check for buffer - if hasattr(fh, 'buffer'): + if hasattr(fh, "buffer"): return fh.buffer return fh class CommandLineParser: - r""" Parses the arguments to a search command. + r"""Parses the arguments to a search command. A search command line is described by the following syntax. @@ -86,9 +86,10 @@ class CommandLineParser: setting the built-in `log_level` immediately changes the `log_level`. """ + @classmethod def parse(cls, command, argv): - """ Splits an argument list into an options dictionary and a fieldname + """Splits an argument list into an options dictionary and a fieldname list. 
The argument list, `argv`, must be of the form:: @@ -117,23 +118,24 @@ def parse(cls, command, argv): # Prepare - debug('Parsing %s command line: %r', command_class, argv) + debug("Parsing %s command line: %r", command_class, argv) command.fieldnames = None command.options.reset() - argv = ' '.join(argv) + argv = " ".join(argv) command_args = cls._arguments_re.match(argv) if command_args is None: - raise SyntaxError(f'Syntax error: {argv}') + raise SyntaxError(f"Syntax error: {argv}") # Parse options - for option in cls._options_re.finditer(command_args.group('options')): - name, value = option.group('name'), option.group('value') + for option in cls._options_re.finditer(command_args.group("options")): + name, value = option.group("name"), option.group("value") if name not in command.options: raise ValueError( - f'Unrecognized {command.name} command option: {name}={json_encode_string(value)}') + f"Unrecognized {command.name} command option: {name}={json_encode_string(value)}" + ) command.options[name].value = cls.unquote(value) missing = command.options.get_missing() @@ -141,23 +143,29 @@ def parse(cls, command, argv): if missing is not None: if len(missing) > 1: raise ValueError( - f'Values for these {command.name} command options are required: {", ".join(missing)}') - raise ValueError(f'A value for {command.name} command option {missing[0]} is required') + f"Values for these {command.name} command options are required: {', '.join(missing)}" + ) + raise ValueError( + f"A value for {command.name} command option {missing[0]} is required" + ) # Parse field names - fieldnames = command_args.group('fieldnames') + fieldnames = command_args.group("fieldnames") if fieldnames is None: command.fieldnames = [] else: - command.fieldnames = [cls.unquote(value.group(0)) for value in cls._fieldnames_re.finditer(fieldnames)] + command.fieldnames = [ + cls.unquote(value.group(0)) + for value in cls._fieldnames_re.finditer(fieldnames) + ] - debug(' %s: %s', command_class, command) + debug(" %s: %s", command_class, command) @classmethod def unquote(cls, string): - """ Removes quotes from a quoted string. + """Removes quotes from a quoted string. Splunk search command quote rules are applied. The enclosing double-quotes, if present, are removed. Escaped double-quotes ('\"' or '""') are replaced by a single double-quote ('"'). 
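# Editor's note: the quoting rules described above, illustrated with the
# module's (internal) CommandLineParser.unquote helper. The sample literals
# are invented; both '\"' and '""' collapse to a single double-quote.

from splunklib.searchcommands.internals import CommandLineParser

assert CommandLineParser.unquote('"hello world"') == "hello world"
assert CommandLineParser.unquote('"say \\"hi\\""') == 'say "hi"'
assert CommandLineParser.unquote('"say ""hi"""') == 'say "hi"'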
@@ -170,22 +178,22 @@ def unquote(cls, string): """ if len(string) == 0: - return '' + return "" if string[0] == '"': if len(string) == 1 or string[-1] != '"': - raise SyntaxError('Poorly formed string literal: ' + string) + raise SyntaxError("Poorly formed string literal: " + string) string = string[1:-1] if len(string) == 0: - return '' + return "" def replace(match): value = match.group(0) if value == '""': return '"' if len(value) < 2: - raise SyntaxError('Poorly formed string literal: ' + string) + raise SyntaxError("Poorly formed string literal: " + string) return value[1] result = re.sub(cls._escaped_character_re, replace, string) @@ -193,7 +201,8 @@ def replace(match): # region Class variables - _arguments_re = re.compile(r""" + _arguments_re = re.compile( + r""" ^\s* (?P<options> # Match a leading set of name/value pairs (?: @@ -207,24 +216,29 @@ def replace(match): (?:"(?:\\.|""|[^"])*"|(?:\\.|[^\s"])+)\s* )* )\s*$ - """, re.VERBOSE | re.UNICODE) + """, + re.VERBOSE | re.UNICODE, + ) _escaped_character_re = re.compile(r'(\\.|""|[\\"])') _fieldnames_re = re.compile(r"""("(?:\\.|""|[^"\\])+"|(?:\\.|[^\s"])+)""") - _options_re = re.compile(r""" + _options_re = re.compile( + r""" # Captures a set of name/value pairs when used with re.finditer (?P<name>(?:(?=\w)[^\d]\w*)) # name \s*=\s* # = (?P<value>"(?:\\.|""|[^"])*"|(?:\\.|[^\s"])+) # value - """, re.VERBOSE | re.UNICODE) + """, + re.VERBOSE | re.UNICODE, + ) # endregion class ConfigurationSettingsType(type): - """ Metaclass for constructing ConfigurationSettings classes. + """Metaclass for constructing ConfigurationSettings classes. Instances of :class:`ConfigurationSettingsType` construct :class:`ConfigurationSettings` classes from a base :class:`ConfigurationSettings` class and a dictionary of configuration settings. The settings in the @@ -243,12 +257,12 @@ class ConfigurationSettingsType(type): Adds a ConfigurationSettings attribute to a :meth:`ReportingCommand.map` method, if there is one.
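# Editor's note: a sketch of how the specification matrix defined below is
# consulted when a configuration setting is validated. The setting name and
# value are illustrative only.

from splunklib.searchcommands.internals import ConfigurationSettingsType

spec = ConfigurationSettingsType.specification_matrix["type"]

assert isinstance("streaming", spec.type)  # accepts bytes or str
assert spec.constraint("streaming")        # one of: events, reporting, streaming
assert spec.supporting_protocols == [2]    # recognized under SCP 2 only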
""" + def __new__(mcs, module, name, bases): mcs = super(ConfigurationSettingsType, mcs).__new__(mcs, str(name), bases, {}) return mcs def __init__(cls, module, name, bases): - super(ConfigurationSettingsType, cls).__init__(name, bases, None) cls.__module__ = module @@ -258,101 +272,88 @@ def validate_configuration_setting(specification, name, value): if isinstance(specification.type, type): type_names = specification.type.__name__ else: - type_names = ', '.join(map(lambda t: t.__name__, specification.type)) - raise ValueError(f'Expected {type_names} value, not {name}={repr(value)}') + type_names = ", ".join(map(lambda t: t.__name__, specification.type)) + raise ValueError(f"Expected {type_names} value, not {name}={repr(value)}") if specification.constraint and not specification.constraint(value): - raise ValueError(f'Illegal value: {name}={ repr(value)}') + raise ValueError(f"Illegal value: {name}={repr(value)}") return value specification = namedtuple( - 'ConfigurationSettingSpecification', ( - 'type', - 'constraint', - 'supporting_protocols')) + "ConfigurationSettingSpecification", + ("type", "constraint", "supporting_protocols"), + ) # P1 [ ] TODO: Review ConfigurationSettingsType.specification_matrix for completeness and correctness specification_matrix = { - 'clear_required_fields': specification( - type=bool, - constraint=None, - supporting_protocols=[1]), - 'distributed': specification( - type=bool, - constraint=None, - supporting_protocols=[2]), - 'generates_timeorder': specification( - type=bool, - constraint=None, - supporting_protocols=[1]), - 'generating': specification( - type=bool, - constraint=None, - supporting_protocols=[1, 2]), - 'local': specification( - type=bool, - constraint=None, - supporting_protocols=[1]), - 'maxinputs': specification( + "clear_required_fields": specification( + type=bool, constraint=None, supporting_protocols=[1] + ), + "distributed": specification( + type=bool, constraint=None, supporting_protocols=[2] + ), + "generates_timeorder": specification( + type=bool, constraint=None, supporting_protocols=[1] + ), + "generating": specification( + type=bool, constraint=None, supporting_protocols=[1, 2] + ), + "local": specification(type=bool, constraint=None, supporting_protocols=[1]), + "maxinputs": specification( type=int, constraint=lambda value: 0 <= value <= sys.maxsize, - supporting_protocols=[2]), - 'overrides_timeorder': specification( - type=bool, - constraint=None, - supporting_protocols=[1]), - 'required_fields': specification( - type=(list, set, tuple), - constraint=None, - supporting_protocols=[1, 2]), - 'requires_preop': specification( - type=bool, - constraint=None, - supporting_protocols=[1]), - 'retainsevents': specification( - type=bool, - constraint=None, - supporting_protocols=[1]), - 'run_in_preview': specification( - type=bool, - constraint=None, - supporting_protocols=[2]), - 'streaming': specification( - type=bool, - constraint=None, - supporting_protocols=[1]), - 'streaming_preop': specification( - type=(bytes, str), - constraint=None, - supporting_protocols=[1, 2]), - 'type': specification( + supporting_protocols=[2], + ), + "overrides_timeorder": specification( + type=bool, constraint=None, supporting_protocols=[1] + ), + "required_fields": specification( + type=(list, set, tuple), constraint=None, supporting_protocols=[1, 2] + ), + "requires_preop": specification( + type=bool, constraint=None, supporting_protocols=[1] + ), + "retainsevents": specification( + type=bool, constraint=None, supporting_protocols=[1] + ), + 
"run_in_preview": specification( + type=bool, constraint=None, supporting_protocols=[2] + ), + "streaming": specification( + type=bool, constraint=None, supporting_protocols=[1] + ), + "streaming_preop": specification( + type=(bytes, str), constraint=None, supporting_protocols=[1, 2] + ), + "type": specification( type=(bytes, str), - constraint=lambda value: value in ('events', 'reporting', 'streaming'), - supporting_protocols=[2])} + constraint=lambda value: value in ("events", "reporting", "streaming"), + supporting_protocols=[2], + ), + } class CsvDialect(csv.Dialect): - """ Describes the properties of Splunk CSV streams """ - delimiter = ',' + """Describes the properties of Splunk CSV streams""" + + delimiter = "," quotechar = '"' doublequote = True skipinitialspace = False - lineterminator = '\r\n' - if sys.version_info >= (3, 0) and sys.platform == 'win32': - lineterminator = '\n' + lineterminator = "\r\n" + if sys.version_info >= (3, 0) and sys.platform == "win32": + lineterminator = "\n" quoting = csv.QUOTE_MINIMAL class InputHeader(dict): - """ Represents a Splunk input header as a collection of name/value pairs. - - """ + """Represents a Splunk input header as a collection of name/value pairs.""" def __str__(self): - return '\n'.join([name + ':' + value for name, value in self.items()]) + return "\n".join([name + ":" + value for name, value in self.items()]) def read(self, ifile): - """ Reads an input header from an input file. + """Reads an input header from an input file. The input header is read as a sequence of ****:**** pairs separated by a newline. The end of the input header is signalled by an empty line or an end-of-file. @@ -363,9 +364,9 @@ def read(self, ifile): name, value = None, None for line in ifile: - if line == '\n': + if line == "\n": break - item = line.split(':', 1) + item = line.split(":", 1) if len(item) == 2: # start of a new item if name is not None: @@ -376,20 +377,18 @@ def read(self, ifile): value += urllib.parse.unquote(line) if name is not None: - self[name] = value[:-1] if value[-1] == '\n' else value + self[name] = value[:-1] if value[-1] == "\n" else value -Message = namedtuple('Message', ('type', 'text')) +Message = namedtuple("Message", ("type", "text")) class MetadataDecoder(JSONDecoder): - def __init__(self): JSONDecoder.__init__(self, object_hook=self._object_hook) @staticmethod def _object_hook(dictionary): - object_view = ObjectView(dictionary) stack = deque() stack.append((None, None, dictionary)) @@ -408,18 +407,16 @@ def _object_hook(dictionary): class MetadataEncoder(JSONEncoder): - def __init__(self): JSONEncoder.__init__(self, separators=MetadataEncoder._separators) def default(self, o): return o.__dict__ if isinstance(o, ObjectView) else JSONEncoder.default(self, o) - _separators = (',', ':') + _separators = (",", ":") class ObjectView: - def __init__(self, dictionary): self.__dict__ = dictionary @@ -434,9 +431,8 @@ def __str__(self): class Recorder: - def __init__(self, path, f): - self._recording = gzip.open(path + '.gz', 'wb') + self._recording = gzip.open(path + ".gz", "wb") self._file = f def __getattr__(self, name): @@ -472,7 +468,6 @@ def write(self, text): class RecordWriter: - def __init__(self, ofile, maxresultrows=None): self._maxresultrows = 50000 if maxresultrows is None else maxresultrows @@ -515,7 +510,7 @@ def pending_record_count(self): def _record_count(self): warnings.warn( "_record_count will be deprecated soon. 
Use pending_record_count instead.", - PendingDeprecationWarning + PendingDeprecationWarning, ) return self.pending_record_count @@ -527,14 +522,14 @@ def committed_record_count(self): def _total_record_count(self): warnings.warn( "_total_record_count will be deprecated soon. Use committed_record_count instead.", - PendingDeprecationWarning + PendingDeprecationWarning, ) return self.committed_record_count def write(self, data): bytes_type = bytes if sys.version_info >= (3, 0) else str if not isinstance(data, bytes_type): - data = data.encode('utf-8') + data = data.encode("utf-8") self.ofile.write(data) def flush(self, finished=None, partial=None): @@ -546,7 +541,9 @@ def flush(self, finished=None, partial=None): def write_message(self, message_type, message_text, *args, **kwargs): self._ensure_validity() - self._inspector.setdefault('messages', []).append((message_type, message_text.format(*args, **kwargs))) + self._inspector.setdefault("messages", []).append( + (message_type, message_text.format(*args, **kwargs)) + ) def write_record(self, record): self._ensure_validity() @@ -568,16 +565,17 @@ def _clear(self): def _ensure_validity(self): if self._finished is True: assert self._record_count == 0 and len(self._inspector) == 0 - raise RuntimeError('I/O operation on closed record writer') + raise RuntimeError("I/O operation on closed record writer") def _write_record(self, record): - fieldnames = self._fieldnames if fieldnames is None: self._fieldnames = fieldnames = list(record.keys()) - self._fieldnames.extend([i for i in self.custom_fields if i not in self._fieldnames]) - value_list = map(lambda fn: (str(fn), str('__mv_') + str(fn)), fieldnames) + self._fieldnames.extend( + [i for i in self.custom_fields if i not in self._fieldnames] + ) + value_list = map(lambda fn: (str(fn), str("__mv_") + str(fn)), fieldnames) self._writerow(list(chain.from_iterable(value_list))) get_value = record.get @@ -593,40 +591,45 @@ def _write_record(self, record): value_t = type(value) if issubclass(value_t, (list, tuple)): - if len(value) == 0: values += (None, None) continue if len(value) > 1: value_list = value - sv = '' - mv = '$' + sv = "" + mv = "$" for value in value_list: - if value is None: - sv += '\n' - mv += '$;$' + sv += "\n" + mv += "$;$" continue value_t = type(value) if value_t is not bytes: - if value_t is bool: value = str(value.real) elif value_t is str: value = value - elif isinstance(value, int) or value_t is float or value_t is complex: + elif ( + isinstance(value, int) + or value_t is float + or value_t is complex + ): value = str(value) elif issubclass(value_t, (dict, list, tuple)): - value = str(''.join(RecordWriter._iterencode_json(value, 0))) + value = str( + "".join(RecordWriter._iterencode_json(value, 0)) + ) else: - value = repr(value).encode('utf-8', errors='backslashreplace') + value = repr(value).encode( + "utf-8", errors="backslashreplace" + ) - sv += value + '\n' - mv += value.replace('$', '$$') + '$;$' + sv += value + "\n" + mv += value.replace("$", "$$") + "$;$" values += (sv[:-1], mv[:-2]) continue @@ -651,7 +654,7 @@ def _write_record(self, record): continue if issubclass(value_t, dict): - values += (str(''.join(RecordWriter._iterencode_json(value, 0))), None) + values += (str("".join(RecordWriter._iterencode_json(value, 0))), None) continue values += (repr(value), None) @@ -667,58 +670,57 @@ def _write_record(self, record): from _json import make_encoder except ImportError: # We may be running under PyPy 2.5 which does not include the _json module - _iterencode_json = 
JSONEncoder(separators=(',', ':')).iterencode + _iterencode_json = JSONEncoder(separators=(",", ":")).iterencode else: # Creating _iterencode_json this way yields a two-fold performance improvement on Python 2.7.9 and 2.7.10 from json.encoder import encode_basestring_ascii @staticmethod def _default(o): - raise TypeError(repr(o) + ' is not JSON serializable') + raise TypeError(repr(o) + " is not JSON serializable") _iterencode_json = make_encoder( - {}, # markers (for detecting circular references) - _default, # object_encoder + {}, # markers (for detecting circular references) + _default, # object_encoder encode_basestring_ascii, # string_encoder - None, # indent - ':', ',', # separators - False, # sort_keys - False, # skip_keys - True # allow_nan + None, # indent + ":", + ",", # separators + False, # sort_keys + False, # skip_keys + True, # allow_nan ) del make_encoder class RecordWriterV1(RecordWriter): - def flush(self, finished=None, partial=None): + RecordWriter.flush( + self, finished, partial + ) # validates arguments and the state of this instance - RecordWriter.flush(self, finished, partial) # validates arguments and the state of this instance - - if self.pending_record_count > 0 or (self._chunk_count == 0 and 'messages' in self._inspector): - - messages = self._inspector.get('messages') + if self.pending_record_count > 0 or ( + self._chunk_count == 0 and "messages" in self._inspector + ): + messages = self._inspector.get("messages") if self._chunk_count == 0: - # Messages are written to the messages header when we write the first chunk of data # Guarantee: These messages are displayed by splunkweb and the job inspector if messages is not None: - message_level = RecordWriterV1._message_level.get for level, text in messages: self.write(message_level(level, level)) - self.write('=') + self.write("=") self.write(text) - self.write('\r\n') + self.write("\r\n") - self.write('\r\n') + self.write("\r\n") elif messages is not None: - # Messages are written to the messages header when we write subsequent chunks of data # Guarantee: These messages are displayed by splunkweb and the job inspector, if and only if the # command is configured with @@ -741,19 +743,19 @@ def flush(self, finished=None, partial=None): self._finished = finished is True _message_level = { - 'DEBUG': 'debug_message', - 'ERROR': 'error_message', - 'FATAL': 'error_message', - 'INFO': 'info_message', - 'WARN': 'warn_message' + "DEBUG": "debug_message", + "ERROR": "error_message", + "FATAL": "error_message", + "INFO": "info_message", + "WARN": "warn_message", } class RecordWriterV2(RecordWriter): - def flush(self, finished=None, partial=None): - - RecordWriter.flush(self, finished, partial) # validates arguments and the state of this instance + RecordWriter.flush( + self, finished, partial + ) # validates arguments and the state of this instance if partial or not finished: # Don't flush partial chunks, since the SCP v2 protocol does not @@ -781,43 +783,51 @@ def write_chunk(self, finished=None): if len(inspector) == 0: inspector = None - metadata = [('inspector', inspector), ('finished', finished)] + metadata = [("inspector", inspector), ("finished", finished)] self._write_chunk(metadata, self._buffer.getvalue()) self._clear() def write_metadata(self, configuration): self._ensure_validity() - metadata = chain(configuration.items(), (('inspector', self._inspector if self._inspector else None),)) - self._write_chunk(metadata, '') + metadata = chain( + configuration.items(), + (("inspector", self._inspector if 
self._inspector else None),), + ) + self._write_chunk(metadata, "") self._clear() def write_metric(self, name, value): self._ensure_validity() - self._inspector['metric.' + name] = value + self._inspector["metric." + name] = value def _clear(self): super()._clear() self._fieldnames = None def _write_chunk(self, metadata, body): - if metadata: - metadata = str(''.join(self._iterencode_json(dict((n, v) for n, v in metadata if v is not None), 0))) + metadata = str( + "".join( + self._iterencode_json( + dict((n, v) for n, v in metadata if v is not None), 0 + ) + ) + ) if sys.version_info >= (3, 0): - metadata = metadata.encode('utf-8') + metadata = metadata.encode("utf-8") metadata_length = len(metadata) else: metadata_length = 0 if sys.version_info >= (3, 0): - body = body.encode('utf-8') + body = body.encode("utf-8") body_length = len(body) if not (metadata_length > 0 or body_length > 0): return - start_line = f'chunked 1.0,{metadata_length},{body_length}\n' + start_line = f"chunked 1.0,{metadata_length},{body_length}\n" self.write(start_line) self.write(metadata) self.write(body) diff --git a/splunklib/searchcommands/reporting_command.py b/splunklib/searchcommands/reporting_command.py index e455a159a..39edebc79 100644 --- a/splunklib/searchcommands/reporting_command.py +++ b/splunklib/searchcommands/reporting_command.py @@ -24,7 +24,7 @@ class ReportingCommand(SearchCommand): - """ Processes search result records and generates a reporting data structure. + """Processes search result records and generates a reporting data structure. Reporting search commands run as either reduce or map/reduce operations. The reduce part runs on a search head and is responsible for processing a single chunk of search results to produce the command's reporting data structure. @@ -46,6 +46,7 @@ class ReportingCommand(SearchCommand): Splunk 6.3 or later. """ + # region Special methods def __init__(self): @@ -55,19 +56,23 @@ def __init__(self): # region Options - phase = Option(doc=''' + phase = Option( + doc=""" **Syntax:** phase=[map|reduce] **Description:** Identifies the phase of the current map-reduce operation. - ''', default='reduce', validate=Set('map', 'reduce')) + """, + default="reduce", + validate=Set("map", "reduce"), + ) # endregion # region Methods def map(self, records): - """ Override this method to compute partial results. + """Override this method to compute partial results. 
:param records: :type records: @@ -83,28 +88,32 @@ def _has_custom_method(self, method_name): return callable(method) and (method is not base_method) def prepare(self): - if self.phase == 'map': - if self._has_custom_method('map'): - phase_method = getattr(self.__class__, 'map') + if self.phase == "map": + if self._has_custom_method("map"): + phase_method = getattr(self.__class__, "map") self._configuration = phase_method.ConfigurationSettings(self) else: self._configuration = self.ConfigurationSettings(self) return - if self.phase == 'reduce': - streaming_preop = chain((self.name, 'phase="map"', str(self._options)), self.fieldnames) - self._configuration.streaming_preop = ' '.join(streaming_preop) + if self.phase == "reduce": + streaming_preop = chain( + (self.name, 'phase="map"', str(self._options)), self.fieldnames + ) + self._configuration.streaming_preop = " ".join(streaming_preop) return - raise RuntimeError(f'Unrecognized reporting command phase: {json_encode_string(str(self.phase))}') + raise RuntimeError( + f"Unrecognized reporting command phase: {json_encode_string(str(self.phase))}" + ) def reduce(self, records): - """ Override this method to produce a reporting data structure. + """Override this method to produce a reporting data structure. You must override this method. """ - raise NotImplementedError('reduce(self, records)') + raise NotImplementedError("reduce(self, records)") def _execute(self, ifile, process): SearchCommand._execute(self, ifile, getattr(self, self.phase)) @@ -114,12 +123,12 @@ def _execute(self, ifile, process): # region Types class ConfigurationSettings(SearchCommand.ConfigurationSettings): - """ Represents the configuration settings for a :code:`ReportingCommand`. + """Represents the configuration settings for a :code:`ReportingCommand`.""" - """ # region SCP v1/v2 Properties - required_fields = ConfigurationSetting(doc=''' + required_fields = ConfigurationSetting( + doc=""" List of required fields for this search which back-propagates to the generating search. Setting this value enables selected fields mode under SCP 2. Under SCP 1 you must also specify @@ -130,9 +139,11 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 1, SCP 2 - ''') + """ + ) - requires_preop = ConfigurationSetting(doc=''' + requires_preop = ConfigurationSetting( + doc=""" Indicates whether :meth:`ReportingCommand.map` is required for proper command execution. If :const:`True`, :meth:`ReportingCommand.map` is guaranteed to be called. If :const:`False`, Splunk @@ -142,22 +153,26 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 1, SCP 2 - ''') + """ + ) - streaming_preop = ConfigurationSetting(doc=''' + streaming_preop = ConfigurationSetting( + doc=""" Denotes the requested streaming preop search string. Computed. Supported by: SCP 1, SCP 2 - ''') + """ + ) # endregion # region SCP v1 Properties - clear_required_fields = ConfigurationSetting(doc=''' + clear_required_fields = ConfigurationSetting( + doc=""" :const:`True`, if required_fields represent the *only* fields required. If :const:`False`, required_fields are additive to any fields that may be required by subsequent commands. 
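# Editor's note: a minimal, hypothetical sketch of the map/reduce split this
# class implements. Names are invented; reduce() must be overridden, while an
# optional map() runs as the streaming preop that fix_up() wires together below.

import sys

from splunklib.searchcommands import Configuration, ReportingCommand, dispatch


@Configuration(requires_preop=True)
class SumBytesCommand(ReportingCommand):
    """Sums an assumed numeric "bytes" field across all input records."""

    @Configuration()
    def map(self, records):
        # Streaming preop: runs remotely, forwarding only the field we need.
        for record in records:
            yield {"bytes": record.get("bytes", 0)}

    def reduce(self, records):
        # Runs on the search head over the mapped records.
        yield {"total": sum(float(record["bytes"]) for record in records)}


dispatch(SumBytesCommand, sys.argv, sys.stdin, sys.stdout, __name__)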
@@ -167,31 +182,41 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 1 - ''') + """ + ) - retainsevents = ConfigurationSetting(readonly=True, value=False, doc=''' + retainsevents = ConfigurationSetting( + readonly=True, + value=False, + doc=""" Signals that :meth:`ReportingCommand.reduce` transforms _raw events to produce a reporting data structure. Fixed: :const:`False` Supported by: SCP 1 - ''') + """, + ) - streaming = ConfigurationSetting(readonly=True, value=False, doc=''' + streaming = ConfigurationSetting( + readonly=True, + value=False, + doc=""" Signals that :meth:`ReportingCommand.reduce` runs on the search head. Fixed: :const:`False` Supported by: SCP 1 - ''') + """, + ) # endregion # region SCP v2 Properties - maxinputs = ConfigurationSetting(doc=''' + maxinputs = ConfigurationSetting( + doc=""" Specifies the maximum number of events that can be passed to the command for each invocation. This limit cannot exceed the value of `maxresultrows` in limits.conf_. Under SCP 1 you must specify this @@ -203,9 +228,11 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): .. _limits.conf: http://docs.splunk.com/Documentation/Splunk/latest/admin/Limitsconf - ''') + """ + ) - run_in_preview = ConfigurationSetting(doc=''' + run_in_preview = ConfigurationSetting( + doc=""" :const:`True`, if this command should be run to generate results for preview; not wait for final output. This may be important for commands that have side effects (e.g., outputlookup). @@ -214,16 +241,21 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): Supported by: SCP 2 - ''') + """ + ) - type = ConfigurationSetting(readonly=True, value='reporting', doc=''' + type = ConfigurationSetting( + readonly=True, + value="reporting", + doc=""" Command type name. Fixed: :const:`'reporting'`. Supported by: SCP 2 - ''') + """, + ) # endregion @@ -231,7 +263,7 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings): @classmethod def fix_up(cls, command): - """ Verifies :code:`command` class structure and configures the :code:`command.map` method. + """Verifies :code:`command` class structure and configures the :code:`command.map` method. Verifies that :code:`command` derives from :class:`ReportingCommand` and overrides :code:`ReportingCommand.reduce`. It then configures :code:`command.reduce`, if an overriding implementation @@ -246,16 +278,16 @@ def fix_up(cls, command): """ if not issubclass(command, ReportingCommand): - raise TypeError(f'{command} is not a ReportingCommand') + raise TypeError(f"{command} is not a ReportingCommand") if command.reduce == ReportingCommand.reduce: - raise AttributeError('No ReportingCommand.reduce override') + raise AttributeError("No ReportingCommand.reduce override") if command.map == ReportingCommand.map: cls._requires_preop = False return - f = vars(command)['map'] # Function backing the map method + f = vars(command)["map"] # Function backing the map method # EXPLANATION OF PREVIOUS STATEMENT: There is no way to add custom attributes to methods. See [Why does # setattr fail on a method](http://stackoverflow.com/questions/7891277/why-does-setattr-fail-on-a-bound-method) for a discussion of this issue. @@ -268,15 +300,14 @@ def fix_up(cls, command): # Create new StreamingCommand.ConfigurationSettings class - module = command.__module__ + '.' + command.__name__ + '.map' - name = b'ConfigurationSettings' + module = command.__module__ + "." 
+ command.__name__ + ".map" + name = b"ConfigurationSettings" bases = (StreamingCommand.ConfigurationSettings,) f.ConfigurationSettings = ConfigurationSettingsType(module, name, bases) ConfigurationSetting.fix_up(f.ConfigurationSettings, settings) del f._settings - # endregion # endregion diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py index ba2a4dc01..2c4f2ab54 100644 --- a/splunklib/searchcommands/search_command.py +++ b/splunklib/searchcommands/search_command.py @@ -48,7 +48,8 @@ Recorder, RecordWriterV1, RecordWriterV2, - json_encode_string) + json_encode_string, +) from ..client import Service from ..utils import ensure_str @@ -76,17 +77,17 @@ class SearchCommand: - """ Represents a custom search command. - - """ + """Represents a custom search command.""" def __init__(self): - # Variables that may be used, but not altered by derived classes class_name = self.__class__.__name__ - self._logger, self._logging_configuration = getLogger(class_name), environment.logging_configuration + self._logger, self._logging_configuration = ( + getLogger(class_name), + environment.logging_configuration, + ) # Variables backing option/property values @@ -108,14 +109,19 @@ def __init__(self): self._allow_empty_input = True def __str__(self): - text = ' '.join(chain((type(self).name, str(self.options)), [] if self.fieldnames is None else self.fieldnames)) + text = " ".join( + chain( + (type(self).name, str(self.options)), + [] if self.fieldnames is None else self.fieldnames, + ) + ) return text # region Options @Option def logging_configuration(self): - """ **Syntax:** logging_configuration= + """**Syntax:** logging_configuration= **Description:** Loads an alternative logging configuration file for a command invocation. The logging configuration file must be in Python @@ -126,11 +132,13 @@ def logging_configuration(self): @logging_configuration.setter def logging_configuration(self, value): - self._logger, self._logging_configuration = environment.configure_logging(self.__class__.__name__, value) + self._logger, self._logging_configuration = environment.configure_logging( + self.__class__.__name__, value + ) @Option def logging_level(self): - """ **Syntax:** logging_level=[CRITICAL|ERROR|WARNING|INFO|DEBUG|NOTSET] + """**Syntax:** logging_level=[CRITICAL|ERROR|WARNING|INFO|DEBUG|NOTSET] **Description:** Sets the threshold for the logger of this command invocation. Logging messages less severe than `logging_level` will be ignored. @@ -146,12 +154,12 @@ def logging_level(self, value): try: level = _levelNames[value.upper()] except KeyError: - raise ValueError(f'Unrecognized logging level: {value}') + raise ValueError(f"Unrecognized logging level: {value}") else: try: level = int(value) except ValueError: - raise ValueError(f'Unrecognized logging level: {value}') + raise ValueError(f"Unrecognized logging level: {value}") self._logger.setLevel(level) def add_field(self, current_record, field_name, field_value): @@ -162,19 +170,27 @@ def gen_record(self, **record): self._record_writer.custom_fields |= set(record.keys()) return record - record = Option(doc=''' + record = Option( + doc=""" **Syntax: record= **Description:** When `true`, records the interaction between the command and splunkd. Defaults to `false`. 
- ''', default=False, validate=Boolean()) + """, + default=False, + validate=Boolean(), + ) - show_configuration = Option(doc=''' + show_configuration = Option( + doc=""" **Syntax:** show_configuration= **Description:** When `true`, reports command configuration as an informational message. Defaults to `false`. - ''', default=False, validate=Boolean()) + """, + default=False, + validate=Boolean(), + ) # endregion @@ -182,16 +198,12 @@ def gen_record(self, **record): @property def configuration(self): - """ Returns the configuration settings for this command. - - """ + """Returns the configuration settings for this command.""" return self._configuration @property def fieldnames(self): - """ Returns the fieldnames specified as argument to this command. - - """ + """Returns the fieldnames specified as argument to this command.""" return self._fieldnames @fieldnames.setter @@ -200,20 +212,23 @@ def fieldnames(self, value): @property def input_header(self): - """ Returns the input header for this command. + """Returns the input header for this command. :return: The input header for this command. :rtype: InputHeader """ warn( - 'SearchCommand.input_header is deprecated and will be removed in a future release. ' - 'Please use SearchCommand.metadata instead.', DeprecationWarning, 2) + "SearchCommand.input_header is deprecated and will be removed in a future release. " + "Please use SearchCommand.metadata instead.", + DeprecationWarning, + 2, + ) return self._input_header @property def logger(self): - """ Returns the logger for this command. + """Returns the logger for this command. :return: The logger for this command. :rtype: @@ -227,9 +242,7 @@ def metadata(self): @property def options(self): - """ Returns the options specified as argument to this command. - - """ + """Returns the options specified as argument to this command.""" if self._options is None: self._options = Option.View(self) return self._options @@ -240,7 +253,7 @@ def protocol_version(self): @property def search_results_info(self): - """ Returns the search results info for this command invocation. + """Returns the search results info for this command invocation. The search results info object is created from the search results info file associated with the command invocation. @@ -255,7 +268,7 @@ def search_results_info(self): if self._protocol_version == 1: try: - path = self._input_header['infoPath'] + path = self._input_header["infoPath"] except KeyError: return None else: @@ -266,21 +279,23 @@ def search_results_info(self): except AttributeError: return None - path = os.path.join(dispatch_dir, 'info.csv') + path = os.path.join(dispatch_dir, "info.csv") try: - with io.open(path, 'r') as f: + with io.open(path, "r") as f: reader = csv.reader(f, dialect=CsvDialect) fields = next(reader) values = next(reader) except IOError as error: if error.errno == 2: - self.logger.error(f'Search results info file {json_encode_string(path)} does not exist.') + self.logger.error( + f"Search results info file {json_encode_string(path)} does not exist." 
+ ) return raise def convert_field(field): - return (field[1:] if field[0] == '_' else field).replace('.', '_') + return (field[1:] if field[0] == "_" else field).replace(".", "_") decode = MetadataDecoder().decode @@ -290,16 +305,23 @@ def convert_value(value): except ValueError: return value - info = ObjectView(dict((convert_field(f_v[0]), convert_value(f_v[1])) for f_v in zip(fields, values))) + info = ObjectView( + dict( + (convert_field(f_v[0]), convert_value(f_v[1])) + for f_v in zip(fields, values) + ) + ) try: count_map = info.countMap except AttributeError: pass else: - count_map = count_map.split(';') + count_map = count_map.split(";") n = len(count_map) - info.countMap = dict(list(zip(islice(count_map, 0, n, 2), islice(count_map, 1, n, 2)))) + info.countMap = dict( + list(zip(islice(count_map, 0, n, 2), islice(count_map, 1, n, 2))) + ) try: msg_type = info.msgType @@ -307,7 +329,11 @@ def convert_value(value): except AttributeError: pass else: - messages = [t_m for t_m in zip(msg_type.split('\n'), msg_text.split('\n')) if t_m[0] or t_m[1]] + messages = [ + t_m + for t_m in zip(msg_type.split("\n"), msg_text.split("\n")) + if t_m[0] or t_m[1] + ] info.msg = [Message(message) for message in messages] del info.msgType @@ -321,7 +347,7 @@ def convert_value(value): @property def service(self): - """ Returns a Splunk service object for this command invocation or None. + """Returns a Splunk service object for this command invocation or None. The service object is created from the Splunkd URI and authentication token passed to the command invocation in the search results info file. This data is not passed to a command invocation by default. You must request it by @@ -359,14 +385,20 @@ def service(self): splunkd_uri = searchinfo.splunkd_uri if splunkd_uri is None or splunkd_uri == "" or splunkd_uri == " ": - self.logger.warning(f"Incorrect value for Splunkd URI: {splunkd_uri!r} in metadata") + self.logger.warning( + f"Incorrect value for Splunkd URI: {splunkd_uri!r} in metadata" + ) return None - uri = urlsplit(splunkd_uri, allow_fragments=False) self._service = Service( - scheme=uri.scheme, host=uri.hostname, port=uri.port, app=searchinfo.app, token=searchinfo.session_key) + scheme=uri.scheme, + host=uri.hostname, + port=uri.port, + app=searchinfo.app, + token=searchinfo.session_key, + ) return self._service @@ -376,11 +408,11 @@ def service(self): def error_exit(self, error, message=None): self.write_error(error.message if message is None else message) - self.logger.error('Abnormal exit: %s', error) + self.logger.error("Abnormal exit: %s", error) exit(1) def finish(self): - """ Flushes the output buffer and signals that this command has finished processing data. + """Flushes the output buffer and signals that this command has finished processing data. :return: :const:`None` @@ -388,7 +420,7 @@ def finish(self): self._record_writer.flush(finished=True) def flush(self): - """ Flushes the output buffer. + """Flushes the output buffer. :return: :const:`None` @@ -396,7 +428,7 @@ def flush(self): self._record_writer.flush(finished=False) def prepare(self): - """ Prepare for execution. + """Prepare for execution. This method should be overridden in search command classes that wish to examine and update their configuration or option settings prior to execution. It is called during the getinfo exchange before command metadata is sent @@ -407,8 +439,10 @@ def prepare(self): """ - def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_input=True): - """ Process data. 
+ def process( + self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_input=True + ): + """Process data. :param argv: Command line arguments. :type argv: list or tuple @@ -439,17 +473,20 @@ def _map_input_header(self): searchinfo = metadata.searchinfo self._input_header.update( allowStream=None, - infoPath=os.path.join(searchinfo.dispatch_dir, 'info.csv'), + infoPath=os.path.join(searchinfo.dispatch_dir, "info.csv"), keywords=None, preview=metadata.preview, realtime=searchinfo.earliest_time != 0 and searchinfo.latest_time != 0, search=searchinfo.search, sid=searchinfo.sid, splunkVersion=searchinfo.splunk_version, - truncated=None) + truncated=None, + ) def _map_metadata(self, argv): - source = SearchCommand._MetadataSource(argv, self._input_header, self.search_results_info) + source = SearchCommand._MetadataSource( + argv, self._input_header, self.search_results_info + ) def _map(metadata_map): metadata = {} @@ -472,43 +509,43 @@ def _map(metadata_map): self._metadata = _map(SearchCommand._metadata_map) _metadata_map = { - 'action': - (lambda v: 'getinfo' if v == '__GETINFO__' else 'execute' if v == '__EXECUTE__' else None, - lambda s: s.argv[1]), - 'preview': - (bool, lambda s: s.input_header.get('preview')), - 'searchinfo': { - 'app': - (lambda v: v.ppc_app, lambda s: s.search_results_info), - 'args': - (None, lambda s: s.argv), - 'dispatch_dir': - (os.path.dirname, lambda s: s.input_header.get('infoPath')), - 'earliest_time': - (lambda v: float(v.rt_earliest) if len(v.rt_earliest) > 0 else 0.0, lambda s: s.search_results_info), - 'latest_time': - (lambda v: float(v.rt_latest) if len(v.rt_latest) > 0 else 0.0, lambda s: s.search_results_info), - 'owner': - (None, None), - 'raw_args': - (None, lambda s: s.argv), - 'search': - (unquote, lambda s: s.input_header.get('search')), - 'session_key': - (lambda v: v.auth_token, lambda s: s.search_results_info), - 'sid': - (None, lambda s: s.input_header.get('sid')), - 'splunk_version': - (None, lambda s: s.input_header.get('splunkVersion')), - 'splunkd_uri': - (lambda v: v.splunkd_uri, lambda s: s.search_results_info), - 'username': - (lambda v: v.ppc_user, lambda s: s.search_results_info)}} - - _MetadataSource = namedtuple('Source', ('argv', 'input_header', 'search_results_info')) + "action": ( + lambda v: "getinfo" + if v == "__GETINFO__" + else "execute" + if v == "__EXECUTE__" + else None, + lambda s: s.argv[1], + ), + "preview": (bool, lambda s: s.input_header.get("preview")), + "searchinfo": { + "app": (lambda v: v.ppc_app, lambda s: s.search_results_info), + "args": (None, lambda s: s.argv), + "dispatch_dir": (os.path.dirname, lambda s: s.input_header.get("infoPath")), + "earliest_time": ( + lambda v: float(v.rt_earliest) if len(v.rt_earliest) > 0 else 0.0, + lambda s: s.search_results_info, + ), + "latest_time": ( + lambda v: float(v.rt_latest) if len(v.rt_latest) > 0 else 0.0, + lambda s: s.search_results_info, + ), + "owner": (None, None), + "raw_args": (None, lambda s: s.argv), + "search": (unquote, lambda s: s.input_header.get("search")), + "session_key": (lambda v: v.auth_token, lambda s: s.search_results_info), + "sid": (None, lambda s: s.input_header.get("sid")), + "splunk_version": (None, lambda s: s.input_header.get("splunkVersion")), + "splunkd_uri": (lambda v: v.splunkd_uri, lambda s: s.search_results_info), + "username": (lambda v: v.ppc_user, lambda s: s.search_results_info), + }, + } + + _MetadataSource = namedtuple( + "Source", ("argv", "input_header", "search_results_info") + ) def _prepare_protocol_v1(self, 
     def _prepare_protocol_v1(self, argv, ifile, ofile):
-
         debug = environment.splunklib_logger.debug

         # Provide as much context as possible in advance of parsing the command line and preparing for execution
@@ -517,14 +554,16 @@ def _prepare_protocol_v1(self, argv, ifile, ofile):
         self._protocol_version = 1
         self._map_metadata(argv)

-        debug(' metadata=%r, input_header=%r', self._metadata, self._input_header)
+        debug(" metadata=%r, input_header=%r", self._metadata, self._input_header)

         try:
             tempfile.tempdir = self._metadata.searchinfo.dispatch_dir
         except AttributeError:
-            raise RuntimeError(f'{self.__class__.__name__}.metadata.searchinfo.dispatch_dir is undefined')
+            raise RuntimeError(
+                f"{self.__class__.__name__}.metadata.searchinfo.dispatch_dir is undefined"
+            )

-        debug(' tempfile.tempdir=%r', tempfile.tempdir)
+        debug(" tempfile.tempdir=%r", tempfile.tempdir)

         CommandLineParser.parse(self, argv[2:])
         self.prepare()
@@ -532,92 +571,116 @@ def _prepare_protocol_v1(self, argv, ifile, ofile):
         if self.record:
             self.record = False

-            record_argv = [argv[0], argv[1], str(self._options), ' '.join(self.fieldnames)]
+            record_argv = [
+                argv[0],
+                argv[1],
+                str(self._options),
+                " ".join(self.fieldnames),
+            ]
             ifile, ofile = self._prepare_recording(record_argv, ifile, ofile)
             self._record_writer.ofile = ofile
-            ifile.record(str(self._input_header), '\n\n')
+            ifile.record(str(self._input_header), "\n\n")

         if self.show_configuration:
-            self.write_info(self.name + ' command configuration: ' + str(self._configuration))
+            self.write_info(
+                self.name + " command configuration: " + str(self._configuration)
+            )

         return ifile  # wrapped, if self.record is True

     def _prepare_recording(self, argv, ifile, ofile):
-
-        # Create the recordings directory, if it doesn't already exist
-        recordings = os.path.join(environment.splunk_home, 'var', 'run', 'splunklib.searchcommands', 'recordings')
+        recordings = os.path.join(
+            environment.splunk_home,
+            "var",
+            "run",
+            "splunklib.searchcommands",
+            "recordings",
+        )

         if not os.path.isdir(recordings):
             os.makedirs(recordings)

         # Create input/output recorders from ifile and ofile
-        recording = os.path.join(recordings, self.__class__.__name__ + '-' + repr(time()) + '.' + self._metadata.action)
-        ifile = Recorder(recording + '.input', ifile)
-        ofile = Recorder(recording + '.output', ofile)
+        recording = os.path.join(
+            recordings,
+            self.__class__.__name__ + "-" + repr(time()) + "." + self._metadata.action,
+        )
+        ifile = Recorder(recording + ".input", ifile)
+        ofile = Recorder(recording + ".output", ofile)

         # Archive the dispatch directory--if it exists--so that it can be used as a baseline in mocks
         dispatch_dir = self._metadata.searchinfo.dispatch_dir

-        if dispatch_dir is not None:  # __GETINFO__ action does not include a dispatch_dir
+        if (
+            dispatch_dir is not None
+        ):  # __GETINFO__ action does not include a dispatch_dir
             root_dir, base_dir = os.path.split(dispatch_dir)
-            make_archive(recording + '.dispatch_dir', 'gztar', root_dir, base_dir, logger=self.logger)
+            make_archive(
+                recording + ".dispatch_dir",
+                "gztar",
+                root_dir,
+                base_dir,
+                logger=self.logger,
+            )

         # Save a splunk command line because it is useful for developing tests
-        with open(recording + '.splunk_cmd', 'wb') as f:
-            f.write('splunk cmd python '.encode())
+        with open(recording + ".splunk_cmd", "wb") as f:
+            f.write("splunk cmd python ".encode())
             f.write(os.path.basename(argv[0]).encode())
             for arg in islice(argv, 1, len(argv)):
-                f.write(' '.encode())
+                f.write(" ".encode())
                 f.write(arg.encode())

         return ifile, ofile

     def _process_protocol_v1(self, argv, ifile, ofile):
-
         debug = environment.splunklib_logger.debug
         class_name = self.__class__.__name__

-        debug('%s.process started under protocol_version=1', class_name)
+        debug("%s.process started under protocol_version=1", class_name)
         self._record_writer = RecordWriterV1(ofile)

         # noinspection PyBroadException
         try:
-            if argv[1] == '__GETINFO__':
-
-                debug('Writing configuration settings')
+            if argv[1] == "__GETINFO__":
+                debug("Writing configuration settings")

                 ifile = self._prepare_protocol_v1(argv, ifile, ofile)
-                self._record_writer.write_record(dict(
-                    (n, ','.join(v) if isinstance(v, (list, tuple)) else v) for n, v in
-                    self._configuration.items()))
+                self._record_writer.write_record(
+                    dict(
+                        (n, ",".join(v) if isinstance(v, (list, tuple)) else v)
+                        for n, v in self._configuration.items()
+                    )
+                )
                 self.finish()

-            elif argv[1] == '__EXECUTE__':
-
-                debug('Executing')
+            elif argv[1] == "__EXECUTE__":
+                debug("Executing")

                 ifile = self._prepare_protocol_v1(argv, ifile, ofile)
                 self._records = self._records_protocol_v1
-                self._metadata.action = 'execute'
+                self._metadata.action = "execute"
                 self._execute(ifile, None)

             else:
                 message = (
-                    f'Command {self.name} appears to be statically configured for search command protocol version 1 and static '
-                    'configuration is unsupported by splunklib.searchcommands. Please ensure that '
-                    'default/commands.conf contains this stanza:\n'
-                    f'[{self.name}]\n'
-                    f'filename = {os.path.basename(argv[0])}\n'
-                    'enableheader = true\n'
-                    'outputheader = true\n'
-                    'requires_srinfo = true\n'
-                    'supports_getinfo = true\n'
-                    'supports_multivalues = true\n'
-                    'supports_rawargs = true')
+                    f"Command {self.name} appears to be statically configured for search command protocol version 1 and static "
+                    "configuration is unsupported by splunklib.searchcommands. Please ensure that "
+                    "default/commands.conf contains this stanza:\n"
+                    f"[{self.name}]\n"
+                    f"filename = {os.path.basename(argv[0])}\n"
+                    "enableheader = true\n"
+                    "outputheader = true\n"
+                    "requires_srinfo = true\n"
+                    "supports_getinfo = true\n"
+                    "supports_multivalues = true\n"
+                    "supports_rawargs = true"
+                )
                 raise RuntimeError(message)

        except (SyntaxError, ValueError) as error:
@@ -634,23 +697,23 @@ def _process_protocol_v1(self, argv, ifile, ofile):
             self.flush()
             exit(1)

-        debug('%s.process finished under protocol_version=1', class_name)
+        debug("%s.process finished under protocol_version=1", class_name)

     def _protocol_v2_option_parser(self, arg):
-        """ Determines if an argument is an Option/Value pair, or just a Positional Argument.
-            Method so different search commands can handle parsing of arguments differently.
+        """Determines if an argument is an Option/Value pair, or just a Positional Argument.
+        Method so different search commands can handle parsing of arguments differently.

-        :param arg: A single argument provided to the command from SPL
-        :type arg: str
+        :param arg: A single argument provided to the command from SPL
+        :type arg: str

-        :return: [OptionName, OptionValue] OR [PositionalArgument]
-        :rtype: List[str]
+        :return: [OptionName, OptionValue] OR [PositionalArgument]
+        :rtype: List[str]
         """
-        return arg.split('=', 1)
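+        # Editor's illustrative note (not part of the upstream change):
+        # "count=10" splits to ["count", "10"], while a bare field name such
+        # as "host" stays a single-element list: ["host"].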
+        return arg.split("=", 1)

     def _process_protocol_v2(self, argv, ifile, ofile):
-        """ Processes records on the `input stream optionally writing records to the output stream.
+        """Processes records on the input stream, optionally writing records to the output stream.

         :param ifile: Input file object.
         :type ifile: file or InputType
@@ -664,22 +727,22 @@ def _process_protocol_v2(self, argv, ifile, ofile):
         debug = environment.splunklib_logger.debug
         class_name = self.__class__.__name__

-        debug('%s.process started under protocol_version=2', class_name)
+        debug("%s.process started under protocol_version=2", class_name)
         self._protocol_version = 2

         # Read search command metadata from splunkd
         # noinspection PyBroadException
         try:
-            debug('Reading metadata')
+            debug("Reading metadata")
             metadata, body = self._read_chunk(self._as_binary_stream(ifile))

-            action = getattr(metadata, 'action', None)
+            action = getattr(metadata, "action", None)

-            if action != 'getinfo':
-                raise RuntimeError(f'Expected getinfo action, not {action}')
+            if action != "getinfo":
+                raise RuntimeError(f"Expected getinfo action, not {action}")

             if len(body) > 0:
-                raise RuntimeError('Did not expect data for getinfo action')
+                raise RuntimeError("Did not expect data for getinfo action")

             self._metadata = deepcopy(metadata)
@@ -691,14 +754,16 @@ def _process_protocol_v2(self, argv, ifile, ofile):
             self._map_input_header()

-            debug(' metadata=%r, input_header=%r', self._metadata, self._input_header)
+            debug(" metadata=%r, input_header=%r", self._metadata, self._input_header)

             try:
                 tempfile.tempdir = self._metadata.searchinfo.dispatch_dir
             except AttributeError:
-                raise RuntimeError(f'{class_name}.metadata.searchinfo.dispatch_dir is undefined')
+                raise RuntimeError(
+                    f"{class_name}.metadata.searchinfo.dispatch_dir is undefined"
+                )

-            debug(' tempfile.tempdir=%r', tempfile.tempdir)
+            debug(" tempfile.tempdir=%r", tempfile.tempdir)

         except:
             self._record_writer = RecordWriterV2(ofile)
             self._report_unexpected_error()
@@ -708,14 +773,16 @@ def _process_protocol_v2(self, argv, ifile, ofile):
         # Write search command configuration for consumption by splunkd
         # noinspection PyBroadException
         try:
-            self._record_writer = RecordWriterV2(ofile, getattr(self._metadata.searchinfo, 'maxresultrows', None))
+            self._record_writer = RecordWriterV2(
+                ofile, getattr(self._metadata.searchinfo, "maxresultrows", None)
+            )
             self.fieldnames = []
             self.options.reset()

             args = self.metadata.searchinfo.args
             error_count = 0

-            debug('Parsing arguments')
+            debug("Parsing arguments")

             if args and isinstance(args, list):
                 for arg in args:
@@ -728,13 +795,13 @@ def _process_protocol_v2(self, argv, ifile, ofile):
                     try:
                         option = self.options[name]
                     except KeyError:
-                        self.write_error(f'Unrecognized option: {name}={value}')
+                        self.write_error(f"Unrecognized option: {name}={value}")
                         error_count += 1
                         continue
                     try:
                         option.value = value
                     except ValueError:
-                        self.write_error(f'Illegal value: {name}={value}')
+                        self.write_error(f"Illegal value: {name}={value}")
                         error_count += 1
                         continue
@@ -744,19 +811,20 @@ def _process_protocol_v2(self, argv, ifile, ofile):
             if len(missing) == 1:
                 self.write_error(f'A value for "{missing[0]}" is required')
             else:
-                self.write_error(f'Values for these required options are missing: {", ".join(missing)}')
+                self.write_error(
+                    f"Values for these required options are missing: {', '.join(missing)}"
+                )
             error_count += 1

             if error_count > 0:
                 exit(1)

-            debug(' command: %s', str(self))
+            debug(" command: %s", str(self))

-            debug('Preparing for execution')
+            debug("Preparing for execution")
             self.prepare()

             if self.record:
-
                 ifile, ofile = self._prepare_recording(argv, ifile, ofile)
                 self._record_writer.ofile = ofile
@@ -764,16 +832,26 @@ def _process_protocol_v2(self, argv, ifile, ofile):
                 info = self._metadata.searchinfo

-                for attr in 'args', 'raw_args':
-                    setattr(info, attr, [arg for arg in getattr(info, attr) if not arg.startswith('record=')])
+                for attr in "args", "raw_args":
+                    setattr(
+                        info,
+                        attr,
+                        [
+                            arg
+                            for arg in getattr(info, attr)
+                            if not arg.startswith("record=")
+                        ],
+                    )

                 metadata = MetadataEncoder().encode(self._metadata)
-                ifile.record('chunked 1.0,', str(len(metadata)), ',0\n', metadata)
+                ifile.record("chunked 1.0,", str(len(metadata)), ",0\n", metadata)

             if self.show_configuration:
-                self.write_info(self.name + ' command configuration: ' + str(self._configuration))
+                self.write_info(
+                    self.name + " command configuration: " + str(self._configuration)
+                )

-            debug(' command configuration: %s', self._configuration)
+            debug(" command configuration: %s", self._configuration)

         except SystemExit:
             self._record_writer.write_metadata(self._configuration)
@@ -790,8 +868,8 @@ def _process_protocol_v2(self, argv, ifile, ofile):
         # Execute search command on data passing through the pipeline
         # noinspection PyBroadException
         try:
-            debug('Executing under protocol_version=2')
-            self._metadata.action = 'execute'
+            debug("Executing under protocol_version=2")
+            self._metadata.action = "execute"
             self._execute(ifile, None)
         except SystemExit:
             self.finish()
@@ -801,25 +879,25 @@ def _process_protocol_v2(self, argv, ifile, ofile):
             self.finish()
             exit(1)

-        debug('%s.process completed', class_name)
+        debug("%s.process completed", class_name)

     def write_debug(self, message, *args):
-        self._record_writer.write_message('DEBUG', message, *args)
+        self._record_writer.write_message("DEBUG", message, *args)

     def write_error(self, message, *args):
-        self._record_writer.write_message('ERROR', message, *args)
+        self._record_writer.write_message("ERROR", message, *args)

     def write_fatal(self, message, *args):
-        self._record_writer.write_message('FATAL', message, *args)
+        self._record_writer.write_message("FATAL", message, *args)

     def write_info(self, message, *args):
-        self._record_writer.write_message('INFO', message, *args)
+        self._record_writer.write_message("INFO", message, *args)

     def write_warning(self, message, *args):
-        self._record_writer.write_message('WARN', message, *args)
+        self._record_writer.write_message("WARN", message, *args)

     def write_metric(self, name, value):
-        """ Writes a metric that will be added to the search inspector.
+        """Writes a metric that will be added to the search inspector.

         :param name: Name of the metric.
         :type name: basestring
@@ -843,14 +921,19 @@ def write_metric(self, name, value):

     @staticmethod
     def _decode_list(mv):
-        return [match.replace('$$', '$') for match in SearchCommand._encoded_value.findall(mv)]
+        return [
+            match.replace("$$", "$")
+            for match in SearchCommand._encoded_value.findall(mv)
+        ]

-    _encoded_value = re.compile(r'\$(?P<item>(?:\$\$|[^$])*)\$(?:;|$)')  # matches a single value in an encoded list
+    _encoded_value = re.compile(
+        r"\$(?P<item>(?:\$\$|[^$])*)\$(?:;|$)"
+    )  # matches a single value in an encoded list

     # Note: Subclasses must override this method so that it can be called
     # called as self._execute(ifile, None)
     def _execute(self, ifile, process):
-        """ Default processing loop
+        """Default processing loop

         :param ifile: Input file object.
         :type ifile: file
@@ -878,17 +961,17 @@ def _as_binary_stream(ifile):
         try:
             return ifile.buffer
         except AttributeError as error:
-            raise RuntimeError(f'Failed to get underlying buffer: {error}')
+            raise RuntimeError(f"Failed to get underlying buffer: {error}")

     @staticmethod
     def _read_chunk(istream):
         # noinspection PyBroadException
-        assert isinstance(istream.read(0), bytes), 'Stream must be binary'
+        assert isinstance(istream.read(0), bytes), "Stream must be binary"

         try:
             header = istream.readline()
         except Exception as error:
-            raise RuntimeError(f'Failed to read transport header: {error}')
+            raise RuntimeError(f"Failed to read transport header: {error}")

         if not header:
             return None
@@ -896,7 +979,7 @@ def _read_chunk(istream):
         match = SearchCommand._header.match(ensure_str(header))

         if match is None:
-            raise RuntimeError(f'Failed to parse transport header: {header}')
+            raise RuntimeError(f"Failed to parse transport header: {header}")

         metadata_length, body_length = match.groups()
         metadata_length = int(metadata_length)
@@ -905,14 +988,18 @@ def _read_chunk(istream):
         try:
             metadata = istream.read(metadata_length)
         except Exception as error:
-            raise RuntimeError(f'Failed to read metadata of length {metadata_length}: {error}')
+            raise RuntimeError(
+                f"Failed to read metadata of length {metadata_length}: {error}"
+            )

         decoder = MetadataDecoder()

         try:
             metadata = decoder.decode(ensure_str(metadata))
         except Exception as error:
-            raise RuntimeError(f'Failed to parse metadata of length {metadata_length}: {error}')
+            raise RuntimeError(
+                f"Failed to parse metadata of length {metadata_length}: {error}"
+            )

         # if body_length <= 0:
         #     return metadata, ''
@@ -922,11 +1009,11 @@ def _read_chunk(istream):
             if body_length > 0:
                 body = istream.read(body_length)
         except Exception as error:
-            raise RuntimeError(f'Failed to read body of length {body_length}: {error}')
+            raise RuntimeError(f"Failed to read body of length {body_length}: {error}")

-        return metadata, ensure_str(body,errors="replace")
+        return metadata, ensure_str(body, errors="replace")

-    _header = re.compile(r'chunked\s+1.0\s*,\s*(\d+)\s*,\s*(\d+)\s*\n')
+    _header = re.compile(r"chunked\s+1.0\s*,\s*(\d+)\s*,\s*(\d+)\s*\n")
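+    # Editor's illustrative note (not part of the upstream change): a chunk on
+    # the wire starts with a header such as b"chunked 1.0,27,0\n", followed by
+    # 27 bytes of JSON metadata and then 0 bytes of payload body.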
     def _records_protocol_v1(self, ifile):
         return self._read_csv_records(ifile)
@@ -939,7 +1026,11 @@ def _read_csv_records(self, ifile):
         except StopIteration:
             return

-        mv_fieldnames = dict((name, name[len('__mv_'):]) for name in fieldnames if name.startswith('__mv_'))
+        mv_fieldnames = dict(
+            (name, name[len("__mv_") :])
+            for name in fieldnames
+            if name.startswith("__mv_")
+        )
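+        # Editor's illustrative note (not part of the upstream change): a CSV
+        # column "__mv_color" holding "$red$;$blue$" decodes to
+        # record["color"] == ["red", "blue"]; "$$" escapes a literal "$".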
         if len(mv_fieldnames) == 0:
             for values in reader:
@@ -949,7 +1040,7 @@ def _read_csv_records(self, ifile):
         for values in reader:
             record = OrderedDict()
             for fieldname, value in zip(fieldnames, values):
-                if fieldname.startswith('__mv_'):
+                if fieldname.startswith("__mv_"):
                     if len(value) > 0:
                         record[mv_fieldnames[fieldname]] = self._decode_list(value)
                 elif fieldname not in record:
@@ -966,11 +1057,11 @@ def _execute_v2(self, ifile, process):
                 return
             metadata, body = result

-            action = getattr(metadata, 'action', None)
-            if action != 'execute':
-                raise RuntimeError(f'Expected execute action, not {action}')
+            action = getattr(metadata, "action", None)
+            if action != "execute":
+                raise RuntimeError(f"Expected execute action, not {action}")

-            self._finished = getattr(metadata, 'finished', False)
+            self._finished = getattr(metadata, "finished", False)
             self._record_writer.is_flushed = False
             self._metadata.update(metadata)
             self._execute_chunk_v2(process, result)
@@ -983,13 +1074,13 @@ def _execute_chunk_v2(self, process, chunk):
             if len(body) <= 0 and not self._allow_empty_input:
                 raise ValueError(
                     "No records found to process. Set allow_empty_input=True in dispatch function to move forward "
-                    "with empty records.")
+                    "with empty records."
+                )

             records = self._read_csv_records(StringIO(body))
             self._record_writer.write_records(process(records))

     def _report_unexpected_error(self):
-
         error_type, error, tb = sys.exc_info()
         origin = tb
@@ -1000,7 +1091,9 @@ def _report_unexpected_error(self):
         lineno = origin.tb_lineno

         message = f'{error_type.__name__} at "{filename}", line {str(lineno)} : {error}'

-        environment.splunklib_logger.error(message + '\nTraceback:\n' + ''.join(traceback.format_tb(tb)))
+        environment.splunklib_logger.error(
+            message + "\nTraceback:\n" + "".join(traceback.format_tb(tb))
+        )
         self.write_error(message)

     # endregion
@@ -1008,15 +1101,13 @@ def _report_unexpected_error(self):
     # region Types

     class ConfigurationSettings:
-        """ Represents the configuration settings common to all :class:`SearchCommand` classes.
-
-        """
+        """Represents the configuration settings common to all :class:`SearchCommand` classes."""

         def __init__(self, command):
             self.command = command

         def __repr__(self):
-            """ Converts the value of this instance to its string representation.
+            """Converts the value of this instance to its string representation.

             The value of this ConfigurationSettings instance is represented as a string of comma-separated
             :code:`(name, value)` pairs.
@@ -1025,12 +1116,16 @@ def __repr__(self):
             """
             definitions = type(self).configuration_setting_definitions
-            settings = [repr((setting.name, setting.__get__(self), setting.supporting_protocols)) for setting in
-                        definitions]
-            return '[' + ', '.join(settings) + ']'
+            settings = [
+                repr(
+                    (setting.name, setting.__get__(self), setting.supporting_protocols)
+                )
+                for setting in definitions
+            ]
+            return "[" + ", ".join(settings) + "]"

         def __str__(self):
-            """ Converts the value of this instance to its string representation.
+            """Converts the value of this instance to its string representation.

             The value of this ConfigurationSettings instance is represented as a string of comma-separated
             :code:`name=value` pairs. Items with values of :const:`None` are filtered from the list.
@@ -1039,14 +1134,19 @@ def __str__(self):
             """
             # text = ', '.join(imap(lambda (name, value): name + '=' + json_encode_string(unicode(value)), self.iteritems()))
-            text = ', '.join([f'{name}={json_encode_string(str(value))}' for (name, value) in self.items()])
+            text = ", ".join(
+                [
+                    f"{name}={json_encode_string(str(value))}"
+                    for (name, value) in self.items()
+                ]
+            )
             return text

         # region Methods

         @classmethod
         def fix_up(cls, command_class):
-            """ Adjusts and checks this class and its search command class.
+            """Adjusts and checks this class and its search command class.

             Derived classes typically override this method. It is used by the :decorator:`Configuration` decorator to
             fix up the :class:`SearchCommand` class it adorns. This method is overridden by :class:`EventingCommand`,
@@ -1063,10 +1163,18 @@ def fix_up(cls, command_class):
         def iteritems(self):
             definitions = type(self).configuration_setting_definitions
             version = self.command.protocol_version
-            return [name_value1 for name_value1 in [(setting.name, setting.__get__(self)) for setting in
-                                                    [setting for setting in definitions if
-                                                     setting.is_supported_by_protocol(version)]] if
-                    name_value1[1] is not None]
+            return [
+                name_value1
+                for name_value1 in [
+                    (setting.name, setting.__get__(self))
+                    for setting in [
+                        setting
+                        for setting in definitions
+                        if setting.is_supported_by_protocol(version)
+                    ]
+                ]
+                if name_value1[1] is not None
+            ]

         # N.B.: Does not use Python 3 dict view semantics
@@ -1077,12 +1185,21 @@ def iteritems(self):

     # endregion


-SearchMetric = namedtuple('SearchMetric', ('elapsed_seconds', 'invocation_count', 'input_count', 'output_count'))
+SearchMetric = namedtuple(
+    "SearchMetric",
+    ("elapsed_seconds", "invocation_count", "input_count", "output_count"),
+)


-def dispatch(command_class, argv=sys.argv, input_file=sys.stdin, output_file=sys.stdout, module_name=None,
-             allow_empty_input=True):
-    """ Instantiates and executes a search command class
+def dispatch(
+    command_class,
+    argv=sys.argv,
+    input_file=sys.stdin,
+    output_file=sys.stdout,
+    module_name=None,
+    allow_empty_input=True,
+):
+    """Instantiates and executes a search command class

     This function implements a `conditional script stanza `_ based on the value of :code:`module_name`::
@@ -1142,5 +1259,5 @@ def stream(records):
     """
     assert issubclass(command_class, SearchCommand)

-    if module_name is None or module_name == '__main__':
+    if module_name is None or module_name == "__main__":
         command_class().process(argv, input_file, output_file, allow_empty_input)
diff --git a/splunklib/searchcommands/streaming_command.py b/splunklib/searchcommands/streaming_command.py
index e2a3a4077..4a2548d37 100644
--- a/splunklib/searchcommands/streaming_command.py
+++ b/splunklib/searchcommands/streaming_command.py
@@ -20,7 +20,7 @@

 class StreamingCommand(SearchCommand):
-    """ Applies a transformation to search results as they travel through the streams pipeline.
+    """Applies a transformation to search results as they travel through the streams pipeline.

     Streaming commands typically filter, augment, or update search result records. Splunk will send them in batches of
     up to 50,000 records. Hence, a search command must be prepared to be invoked many times during the course of
@@ -37,15 +37,16 @@ class StreamingCommand(SearchCommand):
     Splunk 6.3 or later.

     """
+
     # region Methods

     def stream(self, records):
-        """ Generator function that processes and yields event records to the Splunk stream pipeline.
+        """Generator function that processes and yields event records to the Splunk stream pipeline.

         You must override this method.

         """
-        raise NotImplementedError('StreamingCommand.stream(self, records)')
+        raise NotImplementedError("StreamingCommand.stream(self, records)")

     def _execute(self, ifile, process):
         SearchCommand._execute(self, ifile, self.stream)
@@ -53,12 +54,12 @@ def _execute(self, ifile, process):
     # endregion

     class ConfigurationSettings(SearchCommand.ConfigurationSettings):
-        """ Represents the configuration settings that apply to a :class:`StreamingCommand`.
+        """Represents the configuration settings that apply to a :class:`StreamingCommand`."""

-        """
         # region SCP v1/v2 properties

-        required_fields = ConfigurationSetting(doc='''
+        required_fields = ConfigurationSetting(
+            doc="""
             List of required fields for this search which back-propagates to the generating search.

             Setting this value enables selected fields mode under SCP 2. Under SCP 1 you must also specify
@@ -69,13 +70,15 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings):

             Supported by: SCP 1, SCP 2

-            ''')
+            """
+        )

         # endregion

         # region SCP v1 properties

-        clear_required_fields = ConfigurationSetting(doc='''
+        clear_required_fields = ConfigurationSetting(
+            doc="""
             :const:`True`, if required_fields represent the *only* fields required.

             If :const:`False`, required_fields are additive to any fields that may be required by subsequent commands.
@@ -85,40 +88,51 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings):

             Supported by: SCP 1

-            ''')
+            """
+        )

-        local = ConfigurationSetting(doc='''
+        local = ConfigurationSetting(
+            doc="""
             :const:`True`, if the command should run locally on the search head.

             Default: :const:`False`

             Supported by: SCP 1

-            ''')
+            """
+        )

-        overrides_timeorder = ConfigurationSetting(doc='''
+        overrides_timeorder = ConfigurationSetting(
+            doc="""
             :const:`True`, if the command changes the order of events with respect to time.

             Default: :const:`False`

             Supported by: SCP 1

-            ''')
+            """
+        )

-        streaming = ConfigurationSetting(readonly=True, value=True, doc='''
+        streaming = ConfigurationSetting(
+            readonly=True,
+            value=True,
+            doc="""
             Specifies that the command is streamable.

             Fixed: :const:`True`

             Supported by: SCP 1

-            ''')
+            """,
+        )

         # endregion

         # region SCP v2 Properties

-        distributed = ConfigurationSetting(value=True, doc='''
+        distributed = ConfigurationSetting(
+            value=True,
+            doc="""
             :const:`True`, if this command should be distributed to indexers.

             Under SCP 1 you must either specify `local = False` or include this line in commands.conf_, if this command
@@ -133,9 +147,11 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings):

             .. commands.conf_: http://docs.splunk.com/Documentation/Splunk/latest/Admin/Commandsconf

-            ''')
+            """,
+        )

-        maxinputs = ConfigurationSetting(doc='''
+        maxinputs = ConfigurationSetting(
+            doc="""
             Specifies the maximum number of events that can be passed to the command for each invocation.

             This limit cannot exceed the value of `maxresultrows` in limits.conf. Under SCP 1 you must specify this
@@ -145,16 +161,21 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings):

             Supported by: SCP 2

-            ''')
+            """
+        )

-        type = ConfigurationSetting(readonly=True, value='streaming', doc='''
+        type = ConfigurationSetting(
+            readonly=True,
+            value="streaming",
+            doc="""
             Command type name.

             Fixed: :const:`'streaming'`

             Supported by: SCP 2

-            ''')
+            """,
+        )

         # endregion
@@ -162,11 +183,9 @@ class ConfigurationSettings(SearchCommand.ConfigurationSettings):

         @classmethod
         def fix_up(cls, command):
-            """ Verifies :code:`command` class structure.
-
-            """
+            """Verifies :code:`command` class structure."""
             if command.stream == StreamingCommand.stream:
-                raise AttributeError('No StreamingCommand.stream override')
+                raise AttributeError("No StreamingCommand.stream override")

         # TODO: Stop looking like a dictionary because we don't obey the semantics
         # N.B.: Does not use Python 2 dict copy semantics
@@ -175,11 +194,24 @@ def iteritems(self):
             version = self.command.protocol_version
             if version == 1:
                 if self.required_fields is None:
-                    iteritems = [name_value for name_value in iteritems if name_value[0] != 'clear_required_fields']
+                    iteritems = [
+                        name_value
+                        for name_value in iteritems
+                        if name_value[0] != "clear_required_fields"
+                    ]
             else:
-                iteritems = [name_value2 for name_value2 in iteritems if name_value2[0] != 'distributed']
+                iteritems = [
+                    name_value2
+                    for name_value2 in iteritems
+                    if name_value2[0] != "distributed"
+                ]
                 if not self.distributed:
-                    iteritems = [(name_value1[0], 'stateful') if name_value1[0] == 'type' else (name_value1[0], name_value1[1]) for name_value1 in iteritems]
+                    iteritems = [
+                        (name_value1[0], "stateful")
+                        if name_value1[0] == "type"
+                        else (name_value1[0], name_value1[1])
+                        for name_value1 in iteritems
+                    ]
             return iteritems

         # N.B.: Does not use Python 3 dict view semantics
diff --git a/splunklib/searchcommands/validators.py b/splunklib/searchcommands/validators.py
index ccaebca0a..17cae428e 100644
--- a/splunklib/searchcommands/validators.py
+++ b/splunklib/searchcommands/validators.py
@@ -23,9 +23,8 @@

 from collections import namedtuple

-
 class Validator:
-    """ Base class for validators that check and format search command options.
+    """Base class for validators that check and format search command options.

     You must inherit from this class and override :code:`Validator.__call__` and
     :code:`Validator.format`. :code:`Validator.__call__` should convert the
@@ -36,6 +35,7 @@ class Validator:
     it receives as argument the same way :code:`str` does.

     """
+
     def __call__(self, value):
         raise NotImplementedError()

@@ -44,40 +44,45 @@ def format(self, value):


 class Boolean(Validator):
-    """ Validates Boolean option values.
+    """Validates Boolean option values."""

-    """
     truth_values = {
-        '1': True, '0': False,
-        't': True, 'f': False,
-        'true': True, 'false': False,
-        'y': True, 'n': False,
-        'yes': True, 'no': False
+        "1": True,
+        "0": False,
+        "t": True,
+        "f": False,
+        "true": True,
+        "false": False,
+        "y": True,
+        "n": False,
+        "yes": True,
+        "no": False,
     }

     def __call__(self, value):
         if not (value is None or isinstance(value, bool)):
             value = str(value).lower()
             if value not in Boolean.truth_values:
-                raise ValueError(f'Unrecognized truth value: {value}')
+                raise ValueError(f"Unrecognized truth value: {value}")
             value = Boolean.truth_values[value]
         return value

     def format(self, value):
         if value is None:
             return None
-        return 't' if value else 'f'
+        return "t" if value else "f"
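+    # Editor's illustrative note (not part of the upstream change):
+    #   Boolean()("Yes") -> True, Boolean()("0") -> False, and an input such
+    #   as "maybe" raises ValueError("Unrecognized truth value: maybe").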
""" - def __init__(self, mode='eval'): + + def __init__(self, mode="eval"): """ :param mode: Specifies what kind of code must be compiled; it can be :const:`'exec'`, if source consists of a sequence of statements, :const:`'eval'`, if it consists of a single expression, or :const:`'single'` if it @@ -92,7 +97,7 @@ def __call__(self, value): if value is None: return None try: - return Code.object(compile(value, 'string', self._mode), str(value)) + return Code.object(compile(value, "string", self._mode), str(value)) except (SyntaxError, TypeError) as error: message = str(error) @@ -101,20 +106,19 @@ def __call__(self, value): def format(self, value): return None if value is None else value.source - object = namedtuple('Code', ('object', 'source')) + object = namedtuple("Code", ("object", "source")) class Fieldname(Validator): - """ Validates field name option values. + """Validates field name option values.""" - """ - pattern = re.compile(r'''[_.a-zA-Z-][_.a-zA-Z0-9-]*$''') + pattern = re.compile(r"""[_.a-zA-Z-][_.a-zA-Z0-9-]*$""") def __call__(self, value): if value is not None: value = str(value) if Fieldname.pattern.match(value) is None: - raise ValueError(f'Illegal characters in fieldname: {value}') + raise ValueError(f"Illegal characters in fieldname: {value}") return value def format(self, value): @@ -122,16 +126,14 @@ def format(self, value): class File(Validator): - """ Validates file option values. + """Validates file option values.""" - """ - def __init__(self, mode='rt', buffering=None, directory=None): + def __init__(self, mode="rt", buffering=None, directory=None): self.mode = mode self.buffering = buffering self.directory = File._var_run_splunk if directory is None else directory def __call__(self, value): - if value is None: return value @@ -141,9 +143,15 @@ def __call__(self, value): path = os.path.join(self.directory, path) try: - value = open(path, self.mode) if self.buffering is None else open(path, self.mode, self.buffering) + value = ( + open(path, self.mode) + if self.buffering is None + else open(path, self.mode, self.buffering) + ) except IOError as error: - raise ValueError(f'Cannot open {value} with mode={self.mode} and buffering={self.buffering}: {error}') + raise ValueError( + f"Cannot open {value} with mode={self.mode} and buffering={self.buffering}: {error}" + ) return value @@ -151,42 +159,54 @@ def format(self, value): return None if value is None else value.name _var_run_splunk = os.path.join( - os.environ['SPLUNK_HOME'] if 'SPLUNK_HOME' in os.environ else getcwd(), 'var', 'run', 'splunk') + os.environ["SPLUNK_HOME"] if "SPLUNK_HOME" in os.environ else getcwd(), + "var", + "run", + "splunk", + ) class Integer(Validator): - """ Validates integer option values. 
+ """Validates integer option values.""" - """ def __init__(self, minimum=None, maximum=None): if minimum is not None and maximum is not None: + def check_range(value): if not minimum <= value <= maximum: - raise ValueError(f'Expected integer in the range [{minimum},{maximum}], not {value}') + raise ValueError( + f"Expected integer in the range [{minimum},{maximum}], not {value}" + ) elif minimum is not None: + def check_range(value): if value < minimum: - raise ValueError(f'Expected integer in the range [{minimum},+∞], not {value}') + raise ValueError( + f"Expected integer in the range [{minimum},+∞], not {value}" + ) elif maximum is not None: + def check_range(value): if value > maximum: - raise ValueError(f'Expected integer in the range [-∞,{maximum}], not {value}') + raise ValueError( + f"Expected integer in the range [-∞,{maximum}], not {value}" + ) else: + def check_range(value): return self.check_range = check_range - def __call__(self, value): if value is None: return None try: value = int(value) except ValueError: - raise ValueError(f'Expected integer value, not {json_encode_string(value)}') + raise ValueError(f"Expected integer value, not {json_encode_string(value)}") self.check_range(value) return value @@ -196,27 +216,36 @@ def format(self, value): class Float(Validator): - """ Validates float option values. + """Validates float option values.""" - """ def __init__(self, minimum=None, maximum=None): if minimum is not None and maximum is not None: + def check_range(value): if not minimum <= value <= maximum: - raise ValueError(f'Expected float in the range [{minimum},{maximum}], not {value}') + raise ValueError( + f"Expected float in the range [{minimum},{maximum}], not {value}" + ) elif minimum is not None: + def check_range(value): if value < minimum: - raise ValueError(f'Expected float in the range [{minimum},+∞], not {value}') + raise ValueError( + f"Expected float in the range [{minimum},+∞], not {value}" + ) elif maximum is not None: + def check_range(value): if value > maximum: - raise ValueError(f'Expected float in the range [-∞,{maximum}], not {value}') + raise ValueError( + f"Expected float in the range [-∞,{maximum}], not {value}" + ) else: + def check_range(value): return - self.check_range = check_range + self.check_range = check_range def __call__(self, value): if value is None: @@ -224,7 +253,7 @@ def __call__(self, value): try: value = float(value) except ValueError: - raise ValueError(f'Expected float value, not {json_encode_string(value)}') + raise ValueError(f"Expected float value, not {json_encode_string(value)}") self.check_range(value) return value @@ -234,15 +263,13 @@ def format(self, value): class Duration(Validator): - """ Validates duration option values. 
+ """Validates duration option values.""" - """ def __call__(self, value): - if value is None: return None - p = value.split(':', 2) + p = value.split(":", 2) result = None _60 = Duration._60 _unsigned = Duration._unsigned @@ -255,12 +282,11 @@ def __call__(self, value): if len(p) == 3: result = 3600 * _unsigned(p[0]) + 60 * _60(p[1]) + _60(p[2]) except ValueError: - raise ValueError(f'Invalid duration value: {value}') + raise ValueError(f"Invalid duration value: {value}") return result def format(self, value): - if value is None: return None @@ -270,33 +296,34 @@ def format(self, value): m = value // 60 % 60 h = value // (60 * 60) - return '{0:02d}:{1:02d}:{2:02d}'.format(h, m, s) + return "{0:02d}:{1:02d}:{2:02d}".format(h, m, s) _60 = Integer(0, 59) _unsigned = Integer(0) class List(Validator): - """ Validates a list of strings + """Validates a list of strings""" - """ class Dialect(csv.Dialect): - """ Describes the properties of list option values. """ + """Describes the properties of list option values.""" + strict = True - delimiter = str(',') + delimiter = str(",") quotechar = str('"') doublequote = True - lineterminator = str('\n') + lineterminator = str("\n") skipinitialspace = True quoting = csv.QUOTE_MINIMAL def __init__(self, validator=None): if not (validator is None or isinstance(validator, Validator)): - raise ValueError(f'Expected a Validator instance or None for validator, not {repr(validator)}') + raise ValueError( + f"Expected a Validator instance or None for validator, not {repr(validator)}" + ) self._validator = validator def __call__(self, value): - if value is None or isinstance(value, list): return value @@ -312,7 +339,7 @@ def __call__(self, value): for index, item in enumerate(value): value[index] = self._validator(item) except ValueError as error: - raise ValueError(f'Could not convert item {index}: {error}') + raise ValueError(f"Could not convert item {index}: {error}") return value @@ -325,32 +352,35 @@ def format(self, value): class Map(Validator): - """ Validates map option values. + """Validates map option values.""" - """ def __init__(self, **kwargs): self.membership = kwargs def __call__(self, value): - if value is None: return None value = str(value) if value not in self.membership: - raise ValueError(f'Unrecognized value: {value}') + raise ValueError(f"Unrecognized value: {value}") return self.membership[value] def format(self, value): - return None if value is None else list(self.membership.keys())[list(self.membership.values()).index(value)] + return ( + None + if value is None + else list(self.membership.keys())[ + list(self.membership.values()).index(value) + ] + ) class Match(Validator): - """ Validates that a value matches a regular expression pattern. + """Validates that a value matches a regular expression pattern.""" - """ def __init__(self, name, pattern, flags=0): self.name = str(name) self.pattern = re.compile(pattern, flags) @@ -360,7 +390,7 @@ def __call__(self, value): return None value = str(value) if self.pattern.match(value) is None: - raise ValueError(f'Expected {self.name}, not {json_encode_string(value)}') + raise ValueError(f"Expected {self.name}, not {json_encode_string(value)}") return value def format(self, value): @@ -368,16 +398,15 @@ def format(self, value): class OptionName(Validator): - """ Validates option names. 
+ """Validates option names.""" - """ - pattern = re.compile(r'''(?=\w)[^\d]\w*$''', re.UNICODE) + pattern = re.compile(r"""(?=\w)[^\d]\w*$""", re.UNICODE) def __call__(self, value): if value is not None: value = str(value) if OptionName.pattern.match(value) is None: - raise ValueError(f'Illegal characters in option name: {value}') + raise ValueError(f"Illegal characters in option name: {value}") return value def format(self, value): @@ -385,16 +414,15 @@ def format(self, value): class RegularExpression(Validator): - """ Validates regular expression option values. + """Validates regular expression option values.""" - """ def __call__(self, value): if value is None: return None try: value = re.compile(str(value)) except re.error as error: - raise ValueError(f'{str(error).capitalize()}: {value}') + raise ValueError(f"{str(error).capitalize()}: {value}") return value def format(self, value): @@ -402,9 +430,8 @@ def format(self, value): class Set(Validator): - """ Validates set option values. + """Validates set option values.""" - """ def __init__(self, *args): self.membership = set(args) @@ -413,11 +440,22 @@ def __call__(self, value): return None value = str(value) if value not in self.membership: - raise ValueError(f'Unrecognized value: {value}') + raise ValueError(f"Unrecognized value: {value}") return value def format(self, value): return self.__call__(value) -__all__ = ['Boolean', 'Code', 'Duration', 'File', 'Integer', 'Float', 'List', 'Map', 'RegularExpression', 'Set'] +__all__ = [ + "Boolean", + "Code", + "Duration", + "File", + "Integer", + "Float", + "List", + "Map", + "RegularExpression", + "Set", +] diff --git a/splunklib/six.py b/splunklib/six.py index d13e50c93..4d9448111 100644 --- a/splunklib/six.py +++ b/splunklib/six.py @@ -38,15 +38,15 @@ PY34 = sys.version_info[0:2] >= (3, 4) if PY3: - string_types = str, - integer_types = int, - class_types = type, + string_types = (str,) + integer_types = (int,) + class_types = (type,) text_type = str binary_type = bytes MAXSIZE = sys.maxsize else: - string_types = basestring, + string_types = (basestring,) integer_types = (int, long) class_types = (type, types.ClassType) text_type = unicode @@ -58,9 +58,9 @@ else: # It's possible to have sizeof(long) != sizeof(Py_ssize_t). class X(object): - def __len__(self): return 1 << 31 + try: len(X()) except OverflowError: @@ -84,7 +84,6 @@ def _import_module(name): class _LazyDescr(object): - def __init__(self, name): self.name = name @@ -101,7 +100,6 @@ def __get__(self, obj, tp): class MovedModule(_LazyDescr): - def __init__(self, name, old, new=None): super(MovedModule, self).__init__(name) if PY3: @@ -122,7 +120,6 @@ def __getattr__(self, attr): class _LazyModule(types.ModuleType): - def __init__(self, name): super(_LazyModule, self).__init__(name) self.__doc__ = self.__class__.__doc__ @@ -137,7 +134,6 @@ def __dir__(self): class MovedAttribute(_LazyDescr): - def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): super(MovedAttribute, self).__init__(name) if PY3: @@ -162,7 +158,6 @@ def _resolve(self): class _SixMetaPathImporter(object): - """ A meta path importer to import six.moves and its submodules. 
@@ -221,21 +216,25 @@ def get_code(self, fullname): Required, if is_package is implemented""" self.__get_module(fullname) # eventually raises ImportError return None + get_source = get_code # same as get_code + _importer = _SixMetaPathImporter(__name__) class _MovedItems(_LazyModule): - """Lazy loading of moved objects""" + __path__ = [] # mark as package _moved_attributes = [ MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), - MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), + MovedAttribute( + "filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse" + ), MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), MovedAttribute("intern", "__builtin__", "sys"), MovedAttribute("map", "itertools", "builtins", "imap", "map"), @@ -243,7 +242,9 @@ class _MovedItems(_LazyModule): MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), MovedAttribute("getoutput", "commands", "subprocess"), MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), + MovedAttribute( + "reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload" + ), MovedAttribute("reduce", "__builtin__", "functools"), MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), MovedAttribute("StringIO", "StringIO", "io"), @@ -252,14 +253,24 @@ class _MovedItems(_LazyModule): MovedAttribute("UserString", "UserString", "collections"), MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), + MovedAttribute( + "zip_longest", "itertools", "itertools", "izip_longest", "zip_longest" + ), MovedModule("builtins", "__builtin__"), MovedModule("configparser", "ConfigParser"), - MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"), + MovedModule( + "collections_abc", + "collections", + "collections.abc" if sys.version_info >= (3, 3) else "collections", + ), MovedModule("copyreg", "copy_reg"), MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"), + MovedModule( + "_dummy_thread", + "dummy_thread", + "_dummy_thread" if sys.version_info < (3, 9) else "_thread", + ), MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), MovedModule("http_cookies", "Cookie", "http.cookies"), MovedModule("html_entities", "htmlentitydefs", "html.entities"), @@ -268,7 +279,9 @@ class _MovedItems(_LazyModule): MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), - MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), + MovedModule( + "email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart" + ), MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), @@ -287,15 +300,12 @@ class _MovedItems(_LazyModule): MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), 
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), - MovedModule("tkinter_colorchooser", "tkColorChooser", - "tkinter.colorchooser"), - MovedModule("tkinter_commondialog", "tkCommonDialog", - "tkinter.commondialog"), + MovedModule("tkinter_colorchooser", "tkColorChooser", "tkinter.colorchooser"), + MovedModule("tkinter_commondialog", "tkCommonDialog", "tkinter.commondialog"), MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), MovedModule("tkinter_font", "tkFont", "tkinter.font"), MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), - MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", - "tkinter.simpledialog"), + MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", "tkinter.simpledialog"), MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), @@ -322,7 +332,6 @@ class _MovedItems(_LazyModule): class Module_six_moves_urllib_parse(_LazyModule): - """Lazy loading of moved objects in six.moves.urllib_parse""" @@ -341,7 +350,9 @@ class Module_six_moves_urllib_parse(_LazyModule): MovedAttribute("quote_plus", "urllib", "urllib.parse"), MovedAttribute("unquote", "urllib", "urllib.parse"), MovedAttribute("unquote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), + MovedAttribute( + "unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes" + ), MovedAttribute("urlencode", "urllib", "urllib.parse"), MovedAttribute("splitquery", "urllib", "urllib.parse"), MovedAttribute("splittag", "urllib", "urllib.parse"), @@ -359,12 +370,14 @@ class Module_six_moves_urllib_parse(_LazyModule): Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes -_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), - "moves.urllib_parse", "moves.urllib.parse") +_importer._add_module( + Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), + "moves.urllib_parse", + "moves.urllib.parse", +) class Module_six_moves_urllib_error(_LazyModule): - """Lazy loading of moved objects in six.moves.urllib_error""" @@ -379,12 +392,14 @@ class Module_six_moves_urllib_error(_LazyModule): Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes -_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), - "moves.urllib_error", "moves.urllib.error") +_importer._add_module( + Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), + "moves.urllib_error", + "moves.urllib.error", +) class Module_six_moves_urllib_request(_LazyModule): - """Lazy loading of moved objects in six.moves.urllib_request""" @@ -431,12 +446,14 @@ class Module_six_moves_urllib_request(_LazyModule): Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes -_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), - "moves.urllib_request", "moves.urllib.request") +_importer._add_module( + Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), + "moves.urllib_request", + "moves.urllib.request", +) class Module_six_moves_urllib_response(_LazyModule): - """Lazy loading of moved objects in six.moves.urllib_response""" @@ -452,12 +469,14 @@ class 
Module_six_moves_urllib_response(_LazyModule): Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes -_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), - "moves.urllib_response", "moves.urllib.response") +_importer._add_module( + Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), + "moves.urllib_response", + "moves.urllib.response", +) class Module_six_moves_urllib_robotparser(_LazyModule): - """Lazy loading of moved objects in six.moves.urllib_robotparser""" @@ -468,15 +487,20 @@ class Module_six_moves_urllib_robotparser(_LazyModule): setattr(Module_six_moves_urllib_robotparser, attr.name, attr) del attr -Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes +Module_six_moves_urllib_robotparser._moved_attributes = ( + _urllib_robotparser_moved_attributes +) -_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), - "moves.urllib_robotparser", "moves.urllib.robotparser") +_importer._add_module( + Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), + "moves.urllib_robotparser", + "moves.urllib.robotparser", +) class Module_six_moves_urllib(types.ModuleType): - """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" + __path__ = [] # mark as package parse = _importer._get_module("moves.urllib_parse") error = _importer._get_module("moves.urllib_error") @@ -485,10 +509,12 @@ class Module_six_moves_urllib(types.ModuleType): robotparser = _importer._get_module("moves.urllib_robotparser") def __dir__(self): - return ['parse', 'error', 'request', 'response', 'robotparser'] + return ["parse", "error", "request", "response", "robotparser"] + -_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), - "moves.urllib") +_importer._add_module( + Module_six_moves_urllib(__name__ + ".moves.urllib"), "moves.urllib" +) def add_move(move): @@ -528,19 +554,24 @@ def remove_move(name): try: advance_iterator = next except NameError: + def advance_iterator(it): return it.next() + + next = advance_iterator try: callable = callable except NameError: + def callable(obj): return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) if PY3: + def get_unbound_function(unbound): return unbound @@ -551,6 +582,7 @@ def create_unbound_method(func, cls): Iterator = object else: + def get_unbound_function(unbound): return unbound.im_func @@ -561,13 +593,13 @@ def create_unbound_method(func, cls): return types.MethodType(func, None, cls) class Iterator(object): - def next(self): return type(self).__next__(self) callable = callable -_add_doc(get_unbound_function, - """Get the function out of a possibly unbound function""") +_add_doc( + get_unbound_function, """Get the function out of a possibly unbound function""" +) get_method_function = operator.attrgetter(_meth_func) @@ -579,6 +611,7 @@ def next(self): if PY3: + def iterkeys(d, **kw): return iter(d.keys(**kw)) @@ -597,6 +630,7 @@ def iterlists(d, **kw): viewitems = operator.methodcaller("items") else: + def iterkeys(d, **kw): return d.iterkeys(**kw) @@ -617,26 +651,30 @@ def iterlists(d, **kw): _add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") _add_doc(itervalues, "Return an iterator over the values of a dictionary.") -_add_doc(iteritems, - "Return an iterator over the (key, value) pairs of a dictionary.") -_add_doc(iterlists, - "Return an iterator over the (key, [values]) pairs of a dictionary.") 
+_add_doc(iteritems, "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc( + iterlists, "Return an iterator over the (key, [values]) pairs of a dictionary." +) if PY3: + def b(s): return s.encode("latin-1") def u(s): return s + unichr = chr import struct + int2byte = struct.Struct(">B").pack del struct byte2int = operator.itemgetter(0) indexbytes = operator.getitem iterbytes = iter import io + StringIO = io.StringIO BytesIO = io.BytesIO del io @@ -650,12 +688,15 @@ def u(s): _assertRegex = "assertRegex" _assertNotRegex = "assertNotRegex" else: + def b(s): return s + # Workaround for standalone backslash def u(s): - return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") + return unicode(s.replace(r"\\", r"\\\\"), "unicode_escape") + unichr = unichr int2byte = chr @@ -664,8 +705,10 @@ def byte2int(bs): def indexbytes(buf, i): return ord(buf[i]) + iterbytes = functools.partial(itertools.imap, ord) import StringIO + StringIO = BytesIO = StringIO.StringIO _assertCountEqual = "assertItemsEqual" _assertRaisesRegex = "assertRaisesRegexp" @@ -706,6 +749,7 @@ def reraise(tp, value, tb=None): tb = None else: + def exec_(_code_, _globs_=None, _locs_=None): """Execute code in a namespace.""" if _globs_ is None: @@ -734,12 +778,14 @@ def exec_(_code_, _globs_=None, _locs_=None): value = None """) else: + def raise_from(value, from_value): raise value print_ = getattr(moves.builtins, "print", None) if print_ is None: + def print_(*args, **kwargs): """The new-style print function for Python 2.4 and 2.5.""" fp = kwargs.pop("file", sys.stdout) @@ -750,14 +796,17 @@ def write(data): if not isinstance(data, basestring): data = str(data) # If the file has an encoding, encode unicode with it. - if (isinstance(fp, file) and - isinstance(data, unicode) and - fp.encoding is not None): + if ( + isinstance(fp, file) + and isinstance(data, unicode) + and fp.encoding is not None + ): errors = getattr(fp, "errors", None) if errors is None: errors = "strict" data = data.encode(fp.encoding, errors) fp.write(data) + want_unicode = False sep = kwargs.pop("sep", None) if sep is not None: @@ -793,6 +842,8 @@ def write(data): write(sep) write(arg) write(end) + + if sys.version_info[:2] < (3, 3): _print = print_ @@ -803,6 +854,7 @@ def print_(*args, **kwargs): if flush and fp is not None: fp.flush() + _add_doc(reraise, """Reraise an exception.""") if sys.version_info[0:2] < (3, 4): @@ -811,9 +863,12 @@ def print_(*args, **kwargs): # attribute on ``wrapper`` object and it doesn't raise an error if any of # the attributes mentioned in ``assigned`` and ``updated`` are missing on # ``wrapped`` object. 
- def _update_wrapper(wrapper, wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): + def _update_wrapper( + wrapper, + wrapped, + assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES, + ): for attr in assigned: try: value = getattr(wrapped, attr) @@ -825,12 +880,18 @@ def _update_wrapper(wrapper, wrapped, getattr(wrapper, attr).update(getattr(wrapped, attr, {})) wrapper.__wrapped__ = wrapped return wrapper + _update_wrapper.__doc__ = functools.update_wrapper.__doc__ - def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - return functools.partial(_update_wrapper, wrapped=wrapped, - assigned=assigned, updated=updated) + def wraps( + wrapped, + assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES, + ): + return functools.partial( + _update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated + ) + wraps.__doc__ = functools.wraps.__doc__ else: @@ -839,18 +900,18 @@ def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, def with_metaclass(meta, *bases): """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy # metaclass for one level of class instantiation that replaces itself with # the actual metaclass. class metaclass(type): - def __new__(cls, name, this_bases, d): if sys.version_info[:2] >= (3, 7): # This version introduced PEP 560 that requires a bit # of extra care (we mimic what is done by __build_class__). resolved_bases = types.resolve_bases(bases) if resolved_bases is not bases: - d['__orig_bases__'] = bases + d["__orig_bases__"] = bases else: resolved_bases = bases return meta(name, resolved_bases, d) @@ -858,28 +919,31 @@ def __new__(cls, name, this_bases, d): @classmethod def __prepare__(cls, name, this_bases): return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) + + return type.__new__(metaclass, "temporary_class", (), {}) def add_metaclass(metaclass): """Class decorator for creating a class with a metaclass.""" + def wrapper(cls): orig_vars = cls.__dict__.copy() - slots = orig_vars.get('__slots__') + slots = orig_vars.get("__slots__") if slots is not None: if isinstance(slots, str): slots = [slots] for slots_var in slots: orig_vars.pop(slots_var) - orig_vars.pop('__dict__', None) - orig_vars.pop('__weakref__', None) - if hasattr(cls, '__qualname__'): - orig_vars['__qualname__'] = cls.__qualname__ + orig_vars.pop("__dict__", None) + orig_vars.pop("__weakref__", None) + if hasattr(cls, "__qualname__"): + orig_vars["__qualname__"] = cls.__qualname__ return metaclass(cls.__name__, cls.__bases__, orig_vars) + return wrapper -def ensure_binary(s, encoding='utf-8', errors='strict'): +def ensure_binary(s, encoding="utf-8", errors="strict"): """Coerce **s** to six.binary_type. For Python 2: @@ -898,7 +962,7 @@ def ensure_binary(s, encoding='utf-8', errors='strict'): raise TypeError("not expecting type '%s'" % type(s)) -def ensure_str(s, encoding='utf-8', errors='strict'): +def ensure_str(s, encoding="utf-8", errors="strict"): """Coerce *s* to `str`. For Python 2: @@ -918,7 +982,7 @@ def ensure_str(s, encoding='utf-8', errors='strict'): return s -def ensure_text(s, encoding='utf-8', errors='strict'): +def ensure_text(s, encoding="utf-8", errors="strict"): """Coerce *s* to six.text_type. For Python 2: @@ -946,12 +1010,13 @@ def python_2_unicode_compatible(klass): returning text and apply this decorator to the class. 
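+# Editor's illustrative note (not part of the upstream change): add_metaclass
+# is applied as a decorator, e.g.
+#
+#     @add_metaclass(Meta)
+#     class MyClass(object):
+#         pass
+#
+# which works unchanged on both Python 2 and Python 3.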
""" if PY2: - if '__str__' not in klass.__dict__: - raise ValueError("@python_2_unicode_compatible cannot be applied " - "to %s because it doesn't define __str__()." % - klass.__name__) + if "__str__" not in klass.__dict__: + raise ValueError( + "@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % klass.__name__ + ) klass.__unicode__ = klass.__str__ - klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + klass.__str__ = lambda self: self.__unicode__().encode("utf-8") return klass @@ -971,8 +1036,10 @@ def python_2_unicode_compatible(klass): # be floating around. Therefore, we can't use isinstance() to check for # the six meta path importer, since the other six instance will have # inserted an importer with different class. - if (type(importer).__name__ == "_SixMetaPathImporter" and - importer.name == __name__): + if ( + type(importer).__name__ == "_SixMetaPathImporter" + and importer.name == __name__ + ): del sys.meta_path[i] break del i, importer @@ -981,13 +1048,18 @@ def python_2_unicode_compatible(klass): import warnings + def deprecated(message): - def deprecated_decorator(func): - def deprecated_func(*args, **kwargs): - warnings.warn("{} is a deprecated function. {}".format(func.__name__, message), - category=DeprecationWarning, - stacklevel=2) - warnings.simplefilter('default', DeprecationWarning) - return func(*args, **kwargs) - return deprecated_func - return deprecated_decorator \ No newline at end of file + def deprecated_decorator(func): + def deprecated_func(*args, **kwargs): + warnings.warn( + "{} is a deprecated function. {}".format(func.__name__, message), + category=DeprecationWarning, + stacklevel=2, + ) + warnings.simplefilter("default", DeprecationWarning) + return func(*args, **kwargs) + + return deprecated_func + + return deprecated_decorator diff --git a/splunklib/utils.py b/splunklib/utils.py index db9c31267..9b1631dea 100644 --- a/splunklib/utils.py +++ b/splunklib/utils.py @@ -12,14 +12,13 @@ # License for the specific language governing permissions and limitations # under the License. -"""The **splunklib.utils** File for utility functions. 
-""" +"""The **splunklib.utils** File for utility functions.""" -def ensure_binary(s, encoding='utf-8', errors='strict'): +def ensure_binary(s, encoding="utf-8", errors="strict"): """ - - `str` -> encoded to `bytes` - - `bytes` -> `bytes` + - `str` -> encoded to `bytes` + - `bytes` -> `bytes` """ if isinstance(s, str): return s.encode(encoding, errors) @@ -30,10 +29,10 @@ def ensure_binary(s, encoding='utf-8', errors='strict'): raise TypeError(f"not expecting type '{type(s)}'") -def ensure_str(s, encoding='utf-8', errors='strict'): +def ensure_str(s, encoding="utf-8", errors="strict"): """ - - `str` -> `str` - - `bytes` -> decoded to `str` + - `str` -> `str` + - `bytes` -> decoded to `str` """ if isinstance(s, bytes): return s.decode(encoding, errors) diff --git a/tests/modularinput/modularinput_testlib.py b/tests/modularinput/modularinput_testlib.py index 4bf0df13f..5abc1edde 100644 --- a/tests/modularinput/modularinput_testlib.py +++ b/tests/modularinput/modularinput_testlib.py @@ -19,9 +19,12 @@ import sys import unittest -sys.path.insert(0, os.path.join('../../splunklib', '..')) +sys.path.insert(0, os.path.join("../../splunklib", "..")) from splunklib.modularinput.utils import xml_compare, parse_xml_data, parse_parameters + def data_open(filepath): - return io.open(os.path.join(os.path.dirname(os.path.abspath(__file__)), filepath), 'rb') + return io.open( + os.path.join(os.path.dirname(os.path.abspath(__file__)), filepath), "rb" + ) diff --git a/tests/modularinput/test_event.py b/tests/modularinput/test_event.py index 35e9c09cd..ec98fa12c 100644 --- a/tests/modularinput/test_event.py +++ b/tests/modularinput/test_event.py @@ -32,6 +32,7 @@ def test_event_without_enough_fields_fails(capsys): event = Event() event.write_to(sys.stdout) + def test_xml_of_event_with_minimal_configuration(capsys): """Generate XML from an event object with a small number of fields, and see if it matches what we expect.""" @@ -39,7 +40,7 @@ def test_xml_of_event_with_minimal_configuration(capsys): event = Event( data="This is a test of the emergency broadcast system.", stanza="fubar", - time="%.3f" % 1372187084.000 + time="%.3f" % 1372187084.000, ) event.write_to(sys.stdout) @@ -51,6 +52,7 @@ def test_xml_of_event_with_minimal_configuration(capsys): assert xml_compare(expected, constructed) + def test_xml_of_event_with_more_configuration(capsys): """Generate XML from an even object with all fields set, see if it matches what we expect""" @@ -64,7 +66,7 @@ def test_xml_of_event_with_more_configuration(capsys): source="hilda", sourcetype="misc", done=True, - unbroken=True + unbroken=True, ) event.write_to(sys.stdout) @@ -76,6 +78,7 @@ def test_xml_of_event_with_more_configuration(capsys): assert xml_compare(expected, constructed) + def test_writing_events_on_event_writer(capsys): """Write a pair of events with an EventWriter, and ensure that they are being encoded immediately and correctly onto the output stream""" @@ -91,7 +94,7 @@ def test_writing_events_on_event_writer(capsys): source="hilda", sourcetype="misc", done=True, - unbroken=True + unbroken=True, ) ew.write_event(e) @@ -116,6 +119,7 @@ def test_writing_events_on_event_writer(capsys): assert xml_compare(expected, found) + def test_error_in_event_writer(): """An event which cannot write itself onto an output stream (such as because it doesn't have a data field set) @@ -125,7 +129,11 @@ def test_error_in_event_writer(): e = Event() with pytest.raises(ValueError) as excinfo: ew.write_event(e) - assert str(excinfo.value) == "Events must have at least the 
data field set to be written to XML." + assert ( + str(excinfo.value) + == "Events must have at least the data field set to be written to XML." + ) + def test_logging_errors_with_event_writer(capsys): """Check that the log method on EventWriter produces the @@ -138,6 +146,7 @@ def test_logging_errors_with_event_writer(capsys): captured = capsys.readouterr() assert captured.err == "ERROR Something happened!\n" + def test_write_xml_is_sane(capsys): """Check that EventWriter.write_xml_document writes sensible XML to the output stream.""" @@ -169,12 +178,12 @@ def test_log_exception(): # Remove paths and line err = re.sub(r'File "[^"]+', 'File "...', err.getvalue()) - err = re.sub(r'line \d+', 'line 123', err) + err = re.sub(r"line \d+", "line 123", err) # One line assert err == ( - 'ERROR ex1 - Traceback (most recent call last): ' + "ERROR ex1 - Traceback (most recent call last): " ' File "...", line 123, in test_log_exception ' - ' raise exc ' - 'Exception: Something happened! ' + " raise exc " + "Exception: Something happened! " ) diff --git a/tests/modularinput/test_input_definition.py b/tests/modularinput/test_input_definition.py index 93601b352..a84cfa92a 100644 --- a/tests/modularinput/test_input_definition.py +++ b/tests/modularinput/test_input_definition.py @@ -19,7 +19,6 @@ class InputDefinitionTestCase(unittest.TestCase): - def test_parse_inputdef_with_zero_inputs(self): """Check parsing of XML that contains only metadata""" @@ -30,7 +29,7 @@ def test_parse_inputdef_with_zero_inputs(self): "server_host": "tiny", "server_uri": "https://127.0.0.1:8089", "checkpoint_dir": "/some/dir", - "session_key": "123102983109283019283" + "session_key": "123102983109283019283", } self.assertEqual(found, expectedDefinition) @@ -45,14 +44,14 @@ def test_parse_inputdef_with_two_inputs(self): "server_host": "tiny", "server_uri": "https://127.0.0.1:8089", "checkpoint_dir": "/some/dir", - "session_key": "123102983109283019283" + "session_key": "123102983109283019283", } expectedDefinition.inputs["foobar://aaa"] = { "__app": "search", "param1": "value1", "param2": "value2", "disabled": "0", - "index": "default" + "index": "default", } expectedDefinition.inputs["foobar://bbb"] = { "__app": "my_app", @@ -61,7 +60,7 @@ def test_parse_inputdef_with_two_inputs(self): "disabled": "0", "index": "default", "multiValue": ["value1", "value2"], - "multiValue2": ["value3", "value4"] + "multiValue2": ["value3", "value4"], } self.assertEqual(expectedDefinition, found) diff --git a/tests/modularinput/test_scheme.py b/tests/modularinput/test_scheme.py index 7e7ddcc9c..51edb12bf 100644 --- a/tests/modularinput/test_scheme.py +++ b/tests/modularinput/test_scheme.py @@ -50,7 +50,7 @@ def test_generate_xml_from_scheme(self): validation="is_pos_int('some_name')", data_type=Argument.data_type_number, required_on_edit=True, - required_on_create=True + required_on_create=True, ) scheme.add_argument(arg2) @@ -80,13 +80,15 @@ def test_generate_xml_from_scheme_with_arg_title(self): data_type=Argument.data_type_number, required_on_edit=True, required_on_create=True, - title="Argument for ``test_scheme``" + title="Argument for ``test_scheme``", ) scheme.add_argument(arg2) constructed = scheme.to_xml() - expected = ET.parse(data_open("data/scheme_without_defaults_and_argument_title.xml")).getroot() + expected = ET.parse( + data_open("data/scheme_without_defaults_and_argument_title.xml") + ).getroot() self.assertEqual("Argument for ``test_scheme``", arg2.title) self.assertTrue(xml_compare(expected, constructed)) @@ -113,7 +115,7 @@ 
def test_generate_xml_from_argument(self): validation="is_pos_int('some_name')", data_type=Argument.data_type_boolean, required_on_edit="true", - required_on_create="true" + required_on_create="true", ) root = ET.Element("") @@ -123,5 +125,6 @@ def test_generate_xml_from_argument(self): self.assertTrue(xml_compare(expected, constructed)) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/modularinput/test_script.py b/tests/modularinput/test_script.py index 49c259725..1911caf21 100644 --- a/tests/modularinput/test_script.py +++ b/tests/modularinput/test_script.py @@ -49,7 +49,7 @@ def test_scheme_properly_generated_by_script(capsys): class NewScript(Script): def get_scheme(self): scheme = Scheme("abcd") - scheme.description = "\uC3BC and \uC3B6 and <&> f\u00FCr" + scheme.description = "\uc3bc and \uc3b6 and <&> f\u00fcr" scheme.streaming_mode = scheme.streaming_mode_simple scheme.use_external_validation = False scheme.use_single_instance = True @@ -58,7 +58,7 @@ def get_scheme(self): scheme.add_argument(arg1) arg2 = Argument("arg2") - arg2.description = "\uC3BC and \uC3B6 and <&> f\u00FCr" + arg2.description = "\uc3bc and \uc3b6 and <&> f\u00fcr" arg2.data_type = Argument.data_type_number arg2.required_on_create = True arg2.required_on_edit = True @@ -173,7 +173,7 @@ def stream_events(self, inputs, ew): source="hilda", sourcetype="misc", done=True, - unbroken=True + unbroken=True, ) ew.write_event(event) @@ -198,7 +198,7 @@ def stream_events(self, inputs, ew): def test_service_property(capsys): - """ Check that Script.service returns a valid Service instance as soon + """Check that Script.service returns a valid Service instance as soon as the stream_events method is called, but not before. 
""" @@ -213,7 +213,7 @@ def get_scheme(self): return None def stream_events(self, inputs, ew): - self.authority_uri = inputs.metadata['server_uri'] + self.authority_uri = inputs.metadata["server_uri"] script = NewScript() with data_open("data/conf_with_2_inputs.xml") as input_configuration: @@ -221,8 +221,7 @@ def stream_events(self, inputs, ew): assert script.service is None - return_value = script.run_script( - [TEST_SCRIPT_PATH], ew, input_configuration) + return_value = script.run_script([TEST_SCRIPT_PATH], ew, input_configuration) output = capsys.readouterr() assert return_value == 0 @@ -251,16 +250,16 @@ def stream_events(self, inputs, ew): # Remove paths and line numbers err = re.sub(r'File "[^"]+', 'File "...', err.getvalue()) - err = re.sub(r'line \d+', 'line 123', err) - err = re.sub(r' +~+\^+', '', err) + err = re.sub(r"line \d+", "line 123", err) + err = re.sub(r" +~+\^+", "", err) assert out.getvalue() == "" assert err == ( - 'ERROR Some error - ' - 'Traceback (most recent call last): ' + "ERROR Some error - " + "Traceback (most recent call last): " ' File "...", line 123, in run_script ' - ' self.stream_events(self._input_definition, event_writer) ' + " self.stream_events(self._input_definition, event_writer) " ' File "...", line 123, in stream_events ' ' raise RuntimeError("Some error") ' - 'RuntimeError: Some error ' + "RuntimeError: Some error " ) diff --git a/tests/modularinput/test_validation_definition.py b/tests/modularinput/test_validation_definition.py index 1b71a2206..a267902d8 100644 --- a/tests/modularinput/test_validation_definition.py +++ b/tests/modularinput/test_validation_definition.py @@ -30,7 +30,7 @@ def test_validation_definition_parse(self): "server_uri": "https://127.0.0.1:8089", "checkpoint_dir": "/opt/splunk/var/lib/splunk/modinputs", "session_key": "123102983109283019283", - "name": "aaa" + "name": "aaa", } expected.parameters = { "param1": "value1", @@ -38,7 +38,7 @@ def test_validation_definition_parse(self): "disabled": "0", "index": "default", "multiValue": ["value1", "value2"], - "multiValue2": ["value3", "value4"] + "multiValue2": ["value3", "value4"], } self.assertEqual(expected, found) diff --git a/tests/searchcommands/__init__.py b/tests/searchcommands/__init__.py index 41d8d0668..1cbd2bb8f 100644 --- a/tests/searchcommands/__init__.py +++ b/tests/searchcommands/__init__.py @@ -26,11 +26,13 @@ def rebase_environment(name): - environment.app_root = path.join(package_directory, 'apps', name) + environment.app_root = path.join(package_directory, "apps", name) logging.Logger.manager.loggerDict.clear() del logging.root.handlers[:] - environment.splunklib_logger, environment.logging_configuration = environment.configure_logging('splunklib') + environment.splunklib_logger, environment.logging_configuration = ( + environment.configure_logging("splunklib") + ) searchcommands.logging_configuration = environment.logging_configuration searchcommands.splunklib_logger = environment.splunklib_logger searchcommands.app_root = environment.app_root diff --git a/tests/searchcommands/chunked_data_stream.py b/tests/searchcommands/chunked_data_stream.py index 02d890af1..3deb440e3 100644 --- a/tests/searchcommands/chunked_data_stream.py +++ b/tests/searchcommands/chunked_data_stream.py @@ -12,8 +12,7 @@ def __init__(self, version, meta, data): self.version = ensure_str(version) self.meta = json.loads(meta) dialect = splunklib.searchcommands.internals.CsvDialect - self.data = csv.DictReader(io.StringIO(data.decode("utf-8")), - dialect=dialect) + self.data = 
csv.DictReader(io.StringIO(data.decode("utf-8")), dialect=dialect) class ChunkedDataStreamIter(collections.abc.Iterator): @@ -42,12 +41,12 @@ def __init__(self, stream): def read_chunk(self): header = self.stream.readline() - while len(header) > 0 and header.strip() == b'': + while len(header) > 0 and header.strip() == b"": header = self.stream.readline() # Skip empty lines if len(header) == 0: raise EOFError - version, meta, data = header.rstrip().split(b',') + version, meta, data = header.rstrip().split(b",") metabytes = self.stream.read(int(meta)) databytes = self.stream.read(int(data)) return Chunk(version, metabytes, databytes) @@ -56,35 +55,39 @@ def read_chunk(self): def build_chunk(keyval, data=None): metadata = ensure_binary(json.dumps(keyval)) data_output = _build_data_csv(data) - return b"chunked 1.0,%d,%d\n%s%s" % (len(metadata), len(data_output), metadata, data_output) + return b"chunked 1.0,%d,%d\n%s%s" % ( + len(metadata), + len(data_output), + metadata, + data_output, + ) def build_empty_searchinfo(): return { - 'earliest_time': 0, - 'latest_time': 0, - 'search': "", - 'dispatch_dir': "", - 'sid': "", - 'args': [], - 'splunk_version': "42.3.4", + "earliest_time": 0, + "latest_time": 0, + "search": "", + "dispatch_dir": "", + "sid": "", + "args": [], + "splunk_version": "42.3.4", } def build_getinfo_chunk(): - return build_chunk({ - 'action': 'getinfo', - 'preview': False, - 'searchinfo': build_empty_searchinfo()}) + return build_chunk( + {"action": "getinfo", "preview": False, "searchinfo": build_empty_searchinfo()} + ) def build_data_chunk(data, finished=True): - return build_chunk({'action': 'execute', 'finished': finished}, data) + return build_chunk({"action": "execute", "finished": finished}, data) def _build_data_csv(data): if data is None: - return b'' + return b"" if isinstance(data, bytes): return data csvout = io.StringIO() @@ -92,8 +95,9 @@ def _build_data_csv(data): headers = set() for datum in data: headers.update(datum.keys()) - writer = csv.DictWriter(csvout, headers, - dialect=splunklib.searchcommands.internals.CsvDialect) + writer = csv.DictWriter( + csvout, headers, dialect=splunklib.searchcommands.internals.CsvDialect + ) writer.writeheader() for datum in data: writer.writerow(datum) diff --git a/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py b/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py index fafbe46f9..9f43d2581 100644 --- a/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py +++ b/tests/searchcommands/test_apps/eventing_app/bin/eventingcsc.py @@ -15,10 +15,16 @@ # License for the specific language governing permissions and limitations # under the License. 
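For readers of the chunked_data_stream helpers above: build_chunk frames Splunk's search command protocol v2, where every chunk is a header line of the form chunked 1.0,<metadata length>,<payload length> followed by JSON metadata and an optional CSV payload. A minimal sketch of that framing, using only the behavior visible in build_chunk:

    import json

    # Sketch of the v2 chunk framing implemented by build_chunk() above.
    def frame_chunk(meta, payload=b""):
        metadata = json.dumps(meta).encode("utf-8")
        # The header carries the byte lengths of the JSON metadata and the payload.
        return b"chunked 1.0,%d,%d\n%s%s" % (len(metadata), len(payload), metadata, payload)

    frame = frame_chunk({"action": "getinfo", "preview": False})
    # frame == b'chunked 1.0,39,0\n{"action": "getinfo", "preview": false}'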
-import os,sys
+import os, sys
 
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "lib"))
-from splunklib.searchcommands import dispatch, EventingCommand, Configuration, Option, validators
+from splunklib.searchcommands import (
+    dispatch,
+    EventingCommand,
+    Configuration,
+    Option,
+    validators,
+)
 
 
 @Configuration()
@@ -35,9 +41,10 @@ class EventingCSC(EventingCommand):
     """
 
     status = Option(
-        doc='''**Syntax:** **status=***<value>*
-        **Description:** record having same status value will be returned.''',
-        require=True)
+        doc="""**Syntax:** **status=***<value>*
+        **Description:** records having the same status value will be returned.""",
+        require=True,
+    )
 
     def transform(self, records):
         for record in records:
diff --git a/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py b/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py
index 4fe3e765a..42d5aff77 100644
--- a/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py
+++ b/tests/searchcommands/test_apps/generating_app/bin/generatingcsc.py
@@ -19,7 +19,13 @@
 import time
 
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "lib"))
-from splunklib.searchcommands import dispatch, GeneratingCommand, Configuration, Option, validators
+from splunklib.searchcommands import (
+    dispatch,
+    GeneratingCommand,
+    Configuration,
+    Option,
+    validators,
+)
 
 
 @Configuration()
@@ -39,8 +45,8 @@ class GeneratingCSC(GeneratingCommand):
     def generate(self):
         self.logger.debug("Generating %s events" % self.count)
         for i in range(1, self.count + 1):
-            text = f'Test Event {i}'
-            yield {'_time': time.time(), 'event_no': i, '_raw': text}
+            text = f"Test Event {i}"
+            yield {"_time": time.time(), "event_no": i, "_raw": text}
 
 
 dispatch(GeneratingCSC, sys.argv, sys.stdin, sys.stdout, __name__)
diff --git a/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py b/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py
index 477f5fe20..145df1b13 100644
--- a/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py
+++ b/tests/searchcommands/test_apps/reporting_app/bin/reportingcsc.py
@@ -15,10 +15,16 @@
 # License for the specific language governing permissions and limitations
 # under the License.
 
-import os,sys
+import os, sys
 
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "lib"))
-from splunklib.searchcommands import dispatch, ReportingCommand, Configuration, Option, validators
+from splunklib.searchcommands import (
+    dispatch,
+    ReportingCommand,
+    Configuration,
+    Option,
+    validators,
+)
 
 
 @Configuration(requires_preop=True)
diff --git a/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py b/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py
index 8ee2c91eb..aa92cd456 100644
--- a/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py
+++ b/tests/searchcommands/test_apps/streaming_app/bin/streamingcsc.py
@@ -15,10 +15,16 @@
 # License for the specific language governing permissions and limitations
 # under the License.
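The four example apps being reformatted here share one shape: subclass a command base class, declare Option fields, and hand control to dispatch. A condensed sketch in the style of generatingcsc.py above (the command class name is illustrative; the APIs are exactly the ones imported in these files):

    import sys
    import time

    from splunklib.searchcommands import (
        dispatch,
        GeneratingCommand,
        Configuration,
        Option,
        validators,
    )


    @Configuration()
    class HelloCSC(GeneratingCommand):
        # Illustrative command mirroring GeneratingCSC above: emits `count` events.
        count = Option(require=True, validate=validators.Integer(0))

        def generate(self):
            for i in range(1, self.count + 1):
                yield {"_time": time.time(), "event_no": i, "_raw": f"Test Event {i}"}


    dispatch(HelloCSC, sys.argv, sys.stdin, sys.stdout, __name__)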
-import os,sys +import os, sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "lib")) -from splunklib.searchcommands import dispatch, StreamingCommand, Configuration, Option, validators +from splunklib.searchcommands import ( + dispatch, + StreamingCommand, + Configuration, + Option, + validators, +) @Configuration() diff --git a/tests/searchcommands/test_builtin_options.py b/tests/searchcommands/test_builtin_options.py index 174baed07..849f4b79b 100644 --- a/tests/searchcommands/test_builtin_options.py +++ b/tests/searchcommands/test_builtin_options.py @@ -31,13 +31,16 @@ from tests.searchcommands import rebase_environment, package_directory + # portable log level names # https://stackoverflow.com/a/49724281 def level_names(): - return [logging.getLevelName(v) for v in - sorted(getattr(logging, '_levelToName', None) - or logging._levelNames) - if getattr(v, "real", 0)] + return [ + logging.getLevelName(v) + for v in sorted(getattr(logging, "_levelToName", None) or logging._levelNames) + if getattr(v, "real", 0) + ] + @Configuration() class StubbedSearchCommand(SearchCommand): @@ -49,86 +52,104 @@ def fix_up(cls, command_class): @pytest.mark.smoke class TestBuiltinOptions(TestCase): - def setUp(self): TestCase.setUp(self) def test_logging_configuration(self): - # Test that logging is properly initialized when there is no logging configuration file - rebase_environment('app_without_logging_configuration') + rebase_environment("app_without_logging_configuration") self.assertIsNone(environment.logging_configuration) - self.assertTrue(any(isinstance(h, logging.StreamHandler) for h in logging.root.handlers)) - self.assertTrue('splunklib' in logging.Logger.manager.loggerDict) - self.assertEqual(environment.splunklib_logger, logging.Logger.manager.loggerDict['splunklib']) + self.assertTrue( + any(isinstance(h, logging.StreamHandler) for h in logging.root.handlers) + ) + self.assertTrue("splunklib" in logging.Logger.manager.loggerDict) + self.assertEqual( + environment.splunklib_logger, logging.Logger.manager.loggerDict["splunklib"] + ) self.assertIsInstance(environment.splunklib_logger, logging.Logger) command = StubbedSearchCommand() - self.assertIs(command.logger, logging.getLogger('StubbedSearchCommand')) + self.assertIs(command.logger, logging.getLogger("StubbedSearchCommand")) self.assertEqual(len(command.logger.handlers), 0) self.assertIsNone(command.logging_configuration) self.assertIs(command.logger.root, logging.root) - root_handler = next(h for h in logging.root.handlers if isinstance(h, logging.StreamHandler)) + root_handler = next( + h for h in logging.root.handlers if isinstance(h, logging.StreamHandler) + ) self.assertIsInstance(root_handler, logging.StreamHandler) self.assertEqual(root_handler.stream, sys.stderr) - self.assertEqual(command.logging_level, logging.getLevelName(logging.root.level)) + self.assertEqual( + command.logging_level, logging.getLevelName(logging.root.level) + ) root_handler.stream = StringIO() - message = 'Test that output is directed to stderr without formatting' + message = "Test that output is directed to stderr without formatting" command.logger.warning(message) - self.assertEqual(root_handler.stream.getvalue(), message + '\n') + self.assertEqual(root_handler.stream.getvalue(), message + "\n") # A search command loads {local,default}/logging.conf when it is available - rebase_environment('app_with_logging_configuration') + rebase_environment("app_with_logging_configuration") command = StubbedSearchCommand() - 
self.assertEqual(command.logging_configuration, os.path.join(environment.app_root, 'default', 'logging.conf')) - self.assertIs(command.logger, logging.getLogger('StubbedSearchCommand')) + self.assertEqual( + command.logging_configuration, + os.path.join(environment.app_root, "default", "logging.conf"), + ) + self.assertIs(command.logger, logging.getLogger("StubbedSearchCommand")) # Setting logging_configuration loads a new logging configuration file relative to the app root - command.logging_configuration = 'alternative-logging.conf' + command.logging_configuration = "alternative-logging.conf" self.assertEqual( - command.logging_configuration, os.path.join(environment.app_root, 'default', 'alternative-logging.conf')) - self.assertIs(command.logger, logging.getLogger('StubbedSearchCommand')) + command.logging_configuration, + os.path.join(environment.app_root, "default", "alternative-logging.conf"), + ) + self.assertIs(command.logger, logging.getLogger("StubbedSearchCommand")) # Setting logging_configuration loads a new logging configuration file on an absolute path - app_root_logging_configuration = os.path.join(environment.app_root, 'logging.conf') + app_root_logging_configuration = os.path.join( + environment.app_root, "logging.conf" + ) command.logging_configuration = app_root_logging_configuration self.assertEqual(command.logging_configuration, app_root_logging_configuration) - self.assertIs(command.logger, logging.getLogger('StubbedSearchCommand')) + self.assertIs(command.logger, logging.getLogger("StubbedSearchCommand")) # logging_configuration raises a value error, if a non-existent logging configuration file is provided try: - command.logging_configuration = 'foo' + command.logging_configuration = "foo" except ValueError: pass except BaseException as e: - self.fail(f'Expected ValueError, but {type(e)} was raised') + self.fail(f"Expected ValueError, but {type(e)} was raised") else: - self.fail(f'Expected ValueError, but logging_configuration={command.logging_configuration}') + self.fail( + f"Expected ValueError, but logging_configuration={command.logging_configuration}" + ) try: - command.logging_configuration = os.path.join(package_directory, 'non-existent.logging.conf') + command.logging_configuration = os.path.join( + package_directory, "non-existent.logging.conf" + ) except ValueError: pass except BaseException as e: - self.fail(f'Expected ValueError, but {type(e)} was raised') + self.fail(f"Expected ValueError, but {type(e)} was raised") else: - self.fail(f'Expected ValueError, but logging_configuration={command.logging_configuration}') + self.fail( + f"Expected ValueError, but logging_configuration={command.logging_configuration}" + ) def test_logging_level(self): - - rebase_environment('app_without_logging_configuration') + rebase_environment("app_without_logging_configuration") command = StubbedSearchCommand() warning = logging.getLevelName(logging.WARNING) @@ -145,31 +166,37 @@ def test_logging_level(self): if isinstance(level, int): command.logging_level = level level_name = logging.getLevelName(level) - self.assertEqual(command.logging_level, warning if level_name == notset else level_name) + self.assertEqual( + command.logging_level, + warning if level_name == notset else level_name, + ) else: level_name = logging.getLevelName(logging.getLevelName(level)) for variant in level, level.lower(), level.capitalize(): command.logging_level = variant - self.assertEqual(command.logging_level, warning if level_name == notset else level_name) + self.assertEqual( + 
command.logging_level, + warning if level_name == notset else level_name, + ) # logging_level accepts any numeric value for level in 999, 999.999: command.logging_level = level - self.assertEqual(command.logging_level, 'Level 999') + self.assertEqual(command.logging_level, "Level 999") # logging_level raises a value error for unknown logging level names current_value = command.logging_level try: - command.logging_level = 'foo' + command.logging_level = "foo" except ValueError: pass except BaseException as e: - self.fail(f'Expected ValueError, but {type(e)} was raised') + self.fail(f"Expected ValueError, but {type(e)} was raised") else: - self.fail(f'Expected ValueError, but logging_level={command.logging_level}') + self.fail(f"Expected ValueError, but logging_level={command.logging_level}") self.assertEqual(command.logging_level, current_value) @@ -180,18 +207,23 @@ def test_show_configuration(self): self._test_boolean_option(StubbedSearchCommand.show_configuration) def _test_boolean_option(self, option): - - rebase_environment('app_without_logging_configuration') + rebase_environment("app_without_logging_configuration") command = StubbedSearchCommand() # show_configuration accepts Splunk boolean values boolean_values = { - '0': False, '1': True, - 'f': False, 't': True, - 'n': False, 'y': True, - 'no': False, 'yes': True, - 'false': False, 'true': True} + "0": False, + "1": True, + "f": False, + "t": True, + "n": False, + "y": True, + "no": False, + "yes": True, + "false": False, + "true": True, + } for value in boolean_values: for variant in value, value.capitalize(), value.upper(): @@ -201,15 +233,19 @@ def _test_boolean_option(self, option): option.fset(command, None) self.assertEqual(option.fget(command), None) - for value in 13, b'bytes', 'string', object(): + for value in 13, b"bytes", "string", object(): try: option.fset(command, value) except ValueError: pass except BaseException as error: - self.fail(f'Expected ValueError when setting {option.name}={repr(value)}, but {type(error)} was raised') + self.fail( + f"Expected ValueError when setting {option.name}={repr(value)}, but {type(error)} was raised" + ) else: - self.fail(f'Expected ValueError, but {option.name}={repr(option.fget(command))} was accepted.') + self.fail( + f"Expected ValueError, but {option.name}={repr(option.fget(command))} was accepted." 
+ ) if __name__ == "__main__": diff --git a/tests/searchcommands/test_configuration_settings.py b/tests/searchcommands/test_configuration_settings.py index 171a36166..9c4f4170f 100644 --- a/tests/searchcommands/test_configuration_settings.py +++ b/tests/searchcommands/test_configuration_settings.py @@ -30,12 +30,9 @@ from splunklib.searchcommands.decorators import Configuration - @pytest.mark.smoke class TestConfigurationSettings(TestCase): - def test_generating_command(self): - from splunklib.searchcommands import GeneratingCommand @Configuration() @@ -46,9 +43,7 @@ def generate(self): command = TestCommand() command._protocol_version = 1 - self.assertTrue( - list(command.configuration.items()), - [('generating', True)]) + self.assertTrue(list(command.configuration.items()), [("generating", True)]) self.assertIs(command.configuration.generates_timeorder, None) self.assertIs(command.configuration.generating, True) @@ -65,25 +60,32 @@ def generate(self): except AttributeError: pass except Exception as error: - self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') + self.fail(f"Expected AttributeError, not {type(error).__name__}: {error}") else: - self.fail('Expected AttributeError') + self.fail("Expected AttributeError") self.assertEqual( list(command.configuration.items()), - [('generates_timeorder', True), ('generating', True), ('local', True), ('retainsevents', True), - ('streaming', True)]) + [ + ("generates_timeorder", True), + ("generating", True), + ("local", True), + ("retainsevents", True), + ("streaming", True), + ], + ) command = TestCommand() command._protocol_version = 2 self.assertEqual( list(command.configuration.items()), - [('generating', True), ('type', 'stateful')]) + [("generating", True), ("type", "stateful")], + ) self.assertIs(command.configuration.distributed, False) self.assertIs(command.configuration.generating, True) - self.assertEqual(command.configuration.type, 'streaming') + self.assertEqual(command.configuration.type, "streaming") command.configuration.distributed = True @@ -92,16 +94,16 @@ def generate(self): except AttributeError: pass except Exception as error: - self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') + self.fail(f"Expected AttributeError, not {type(error).__name__}: {error}") else: - self.fail('Expected AttributeError') + self.fail("Expected AttributeError") self.assertEqual( list(command.configuration.items()), - [('generating', True), ('type', 'streaming')]) + [("generating", True), ("type", "streaming")], + ) def test_streaming_command(self): - from splunklib.searchcommands import StreamingCommand @Configuration() @@ -113,9 +115,7 @@ def stream(self, records): command._protocol_version = 1 - self.assertEqual( - list(command.configuration.items()), - [('streaming', True)]) + self.assertEqual(list(command.configuration.items()), [("streaming", True)]) self.assertIs(command.configuration.clear_required_fields, None) self.assertIs(command.configuration.local, None) @@ -126,47 +126,55 @@ def stream(self, records): command.configuration.clear_required_fields = True command.configuration.local = True command.configuration.overrides_timeorder = True - command.configuration.required_fields = ['field_1', 'field_2', 'field_3'] + command.configuration.required_fields = ["field_1", "field_2", "field_3"] try: command.configuration.streaming = False except AttributeError: pass except Exception as error: - self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') + self.fail(f"Expected 
AttributeError, not {type(error).__name__}: {error}") else: - self.fail('Expected AttributeError') + self.fail("Expected AttributeError") self.assertEqual( list(command.configuration.items()), - [('clear_required_fields', True), ('local', True), ('overrides_timeorder', True), - ('required_fields', ['field_1', 'field_2', 'field_3']), ('streaming', True)]) + [ + ("clear_required_fields", True), + ("local", True), + ("overrides_timeorder", True), + ("required_fields", ["field_1", "field_2", "field_3"]), + ("streaming", True), + ], + ) command = TestCommand() command._protocol_version = 2 - self.assertEqual( - list(command.configuration.items()), - [('type', 'streaming')]) + self.assertEqual(list(command.configuration.items()), [("type", "streaming")]) self.assertIs(command.configuration.distributed, True) - self.assertEqual(command.configuration.type, 'streaming') + self.assertEqual(command.configuration.type, "streaming") command.configuration.distributed = False - command.configuration.required_fields = ['field_1', 'field_2', 'field_3'] + command.configuration.required_fields = ["field_1", "field_2", "field_3"] try: - command.configuration.type = 'events' + command.configuration.type = "events" except AttributeError: pass except Exception as error: - self.fail(f'Expected AttributeError, not {type(error).__name__}: {error}') + self.fail(f"Expected AttributeError, not {type(error).__name__}: {error}") else: - self.fail('Expected AttributeError') + self.fail("Expected AttributeError") self.assertEqual( list(command.configuration.items()), - [('required_fields', ['field_1', 'field_2', 'field_3']), ('type', 'stateful')]) + [ + ("required_fields", ["field_1", "field_2", "field_3"]), + ("type", "stateful"), + ], + ) if __name__ == "__main__": diff --git a/tests/searchcommands/test_csc_apps.py b/tests/searchcommands/test_csc_apps.py index 64b03dc3e..b3eb18b47 100755 --- a/tests/searchcommands/test_csc_apps.py +++ b/tests/searchcommands/test_csc_apps.py @@ -20,13 +20,15 @@ from tests import testlib from splunklib import results + @pytest.mark.smoke class TestCSC(testlib.SDKTestCase): - def test_eventing_app(self): app_name = "eventing_app" - self.assertTrue(app_name in self.service.apps, msg="%s is not installed." % app_name) + self.assertTrue( + app_name in self.service.apps, msg="%s is not installed." % app_name + ) # Fetch the app app = self.service.apps[app_name] @@ -48,8 +50,8 @@ def test_eventing_app(self): self.assertEqual(access.modifiable, "1") self.assertEqual(access.owner, "nobody") self.assertEqual(access.sharing, "app") - self.assertEqual(access.perms.read, ['*']) - self.assertEqual(access.perms.write, ['admin', 'power']) + self.assertEqual(access.perms.read, ["*"]) + self.assertEqual(access.perms.write, ["admin", "power"]) self.assertEqual(access.removable, "0") self.assertEqual(content.author, "Splunk") @@ -62,8 +64,10 @@ def test_eventing_app(self): self.assertEqual(state.title, "eventing_app") jobs = self.service.jobs - stream = jobs.oneshot('search index="_internal" | head 4000 | eventingcsc status=200 | head 10', - output_mode='json') + stream = jobs.oneshot( + 'search index="_internal" | head 4000 | eventingcsc status=200 | head 10', + output_mode="json", + ) result = results.JSONResultsReader(stream) ds = list(result) @@ -75,7 +79,9 @@ def test_eventing_app(self): def test_generating_app(self): app_name = "generating_app" - self.assertTrue(app_name in self.service.apps, msg="%s is not installed." 
% app_name) + self.assertTrue( + app_name in self.service.apps, msg="%s is not installed." % app_name + ) # Fetch the app app = self.service.apps[app_name] @@ -97,13 +103,15 @@ def test_generating_app(self): self.assertEqual(access.modifiable, "1") self.assertEqual(access.owner, "nobody") self.assertEqual(access.sharing, "app") - self.assertEqual(access.perms.read, ['*']) - self.assertEqual(access.perms.write, ['admin', 'power']) + self.assertEqual(access.perms.read, ["*"]) + self.assertEqual(access.perms.write, ["admin", "power"]) self.assertEqual(access.removable, "0") self.assertEqual(content.author, "Splunk") self.assertEqual(content.configured, "0") - self.assertEqual(content.description, "Generating custom search commands example") + self.assertEqual( + content.description, "Generating custom search commands example" + ) self.assertEqual(content.label, "Generating App") self.assertEqual(content.version, "1.0.0") self.assertEqual(content.visible, "1") @@ -111,7 +119,7 @@ def test_generating_app(self): self.assertEqual(state.title, "generating_app") jobs = self.service.jobs - stream = jobs.oneshot('| generatingcsc count=4', output_mode='json') + stream = jobs.oneshot("| generatingcsc count=4", output_mode="json") result = results.JSONResultsReader(stream) ds = list(result) self.assertTrue(len(ds) == 4) @@ -119,7 +127,9 @@ def test_generating_app(self): def test_reporting_app(self): app_name = "reporting_app" - self.assertTrue(app_name in self.service.apps, msg="%s is not installed." % app_name) + self.assertTrue( + app_name in self.service.apps, msg="%s is not installed." % app_name + ) # Fetch the app app = self.service.apps[app_name] @@ -141,13 +151,15 @@ def test_reporting_app(self): self.assertEqual(access.modifiable, "1") self.assertEqual(access.owner, "nobody") self.assertEqual(access.sharing, "app") - self.assertEqual(access.perms.read, ['*']) - self.assertEqual(access.perms.write, ['admin', 'power']) + self.assertEqual(access.perms.read, ["*"]) + self.assertEqual(access.perms.write, ["admin", "power"]) self.assertEqual(access.removable, "0") self.assertEqual(content.author, "Splunk") self.assertEqual(content.configured, "0") - self.assertEqual(content.description, "Reporting custom search commands example") + self.assertEqual( + content.description, "Reporting custom search commands example" + ) self.assertEqual(content.label, "Reporting App") self.assertEqual(content.version, "1.0.0") self.assertEqual(content.visible, "1") @@ -158,8 +170,9 @@ def test_reporting_app(self): # All above 150 stream = jobs.oneshot( - '| makeresults count=10 | eval math=100, eng=100, cs=100 | reportingcsc cutoff=150 math eng cs', - output_mode='json') + "| makeresults count=10 | eval math=100, eng=100, cs=100 | reportingcsc cutoff=150 math eng cs", + output_mode="json", + ) result = results.JSONResultsReader(stream) ds = list(result) @@ -172,8 +185,9 @@ def test_reporting_app(self): # All below 150 stream = jobs.oneshot( - '| makeresults count=10 | eval math=45, eng=45, cs=45 | reportingcsc cutoff=150 math eng cs', - output_mode='json') + "| makeresults count=10 | eval math=45, eng=45, cs=45 | reportingcsc cutoff=150 math eng cs", + output_mode="json", + ) result = results.JSONResultsReader(stream) ds = list(result) @@ -187,7 +201,9 @@ def test_reporting_app(self): def test_streaming_app(self): app_name = "streaming_app" - self.assertTrue(app_name in self.service.apps, msg="%s is not installed." % app_name) + self.assertTrue( + app_name in self.service.apps, msg="%s is not installed." 
% app_name + ) # Fetch the app app = self.service.apps[app_name] @@ -209,13 +225,15 @@ def test_streaming_app(self): self.assertEqual(access.modifiable, "1") self.assertEqual(access.owner, "nobody") self.assertEqual(access.sharing, "app") - self.assertEqual(access.perms.read, ['*']) - self.assertEqual(access.perms.write, ['admin', 'power']) + self.assertEqual(access.perms.read, ["*"]) + self.assertEqual(access.perms.write, ["admin", "power"]) self.assertEqual(access.removable, "0") self.assertEqual(content.author, "Splunk") self.assertEqual(content.configured, "0") - self.assertEqual(content.description, "Streaming custom search commands example") + self.assertEqual( + content.description, "Streaming custom search commands example" + ) self.assertEqual(content.label, "Streaming App") self.assertEqual(content.version, "1.0.0") self.assertEqual(content.visible, "1") @@ -224,16 +242,19 @@ def test_streaming_app(self): jobs = self.service.jobs - stream = jobs.oneshot('| makeresults count=5 | eval celsius = 35 | streamingcsc', output_mode='json') + stream = jobs.oneshot( + "| makeresults count=5 | eval celsius = 35 | streamingcsc", + output_mode="json", + ) result = results.JSONResultsReader(stream) ds = list(result) self.assertTrue(len(ds) == 5) - self.assertTrue('_time' in ds[0]) - self.assertTrue('celsius' in ds[0]) - self.assertTrue('fahrenheit' in ds[0]) - self.assertTrue(ds[0]['celsius'] == '35') - self.assertTrue(ds[0]['fahrenheit'] == '95.0') + self.assertTrue("_time" in ds[0]) + self.assertTrue("celsius" in ds[0]) + self.assertTrue("fahrenheit" in ds[0]) + self.assertTrue(ds[0]["celsius"] == "35") + self.assertTrue(ds[0]["fahrenheit"] == "95.0") self.assertTrue(len(ds) == 5) diff --git a/tests/searchcommands/test_decorators.py b/tests/searchcommands/test_decorators.py index 3cc571dd9..533150ff5 100755 --- a/tests/searchcommands/test_decorators.py +++ b/tests/searchcommands/test_decorators.py @@ -33,154 +33,193 @@ @Configuration() class TestSearchCommand(SearchCommand): boolean = Option( - doc=''' + doc=""" **Syntax:** **boolean=**** - **Description:** A boolean value''', - validate=validators.Boolean()) + **Description:** A boolean value""", + validate=validators.Boolean(), + ) required_boolean = Option( - doc=''' + doc=""" **Syntax:** **boolean=**** - **Description:** A boolean value''', - require=True, validate=validators.Boolean()) + **Description:** A boolean value""", + require=True, + validate=validators.Boolean(), + ) aliased_required_boolean = Option( - doc=''' + doc=""" **Syntax:** **boolean=**** - **Description:** A boolean value''', - name='foo', require=True, validate=validators.Boolean()) + **Description:** A boolean value""", + name="foo", + require=True, + validate=validators.Boolean(), + ) code = Option( doc=''' **Syntax:** **code=**** **Description:** A Python expression, if mode == "eval", or statement, if mode == "exec"''', - validate=validators.Code()) + validate=validators.Code(), + ) required_code = Option( doc=''' **Syntax:** **code=**** **Description:** A Python expression, if mode == "eval", or statement, if mode == "exec"''', - require=True, validate=validators.Code()) + require=True, + validate=validators.Code(), + ) duration = Option( - doc=''' + doc=""" **Syntax:** **duration=**** - **Description:** A length of time''', - validate=validators.Duration()) + **Description:** A length of time""", + validate=validators.Duration(), + ) required_duration = Option( - doc=''' + doc=""" **Syntax:** **duration=**** - **Description:** A length of time''', - 
require=True, validate=validators.Duration()) + **Description:** A length of time""", + require=True, + validate=validators.Duration(), + ) fieldname = Option( - doc=''' + doc=""" **Syntax:** **fieldname=**** - **Description:** Name of a field''', - validate=validators.Fieldname()) + **Description:** Name of a field""", + validate=validators.Fieldname(), + ) required_fieldname = Option( - doc=''' + doc=""" **Syntax:** **fieldname=**** - **Description:** Name of a field''', - require=True, validate=validators.Fieldname()) + **Description:** Name of a field""", + require=True, + validate=validators.Fieldname(), + ) file = Option( - doc=''' + doc=""" **Syntax:** **file=**** - **Description:** Name of a file''', - validate=validators.File()) + **Description:** Name of a file""", + validate=validators.File(), + ) required_file = Option( - doc=''' + doc=""" **Syntax:** **file=**** - **Description:** Name of a file''', - require=True, validate=validators.File()) + **Description:** Name of a file""", + require=True, + validate=validators.File(), + ) integer = Option( - doc=''' + doc=""" **Syntax:** **integer=**** - **Description:** An integer value''', - validate=validators.Integer()) + **Description:** An integer value""", + validate=validators.Integer(), + ) required_integer = Option( - doc=''' + doc=""" **Syntax:** **integer=**** - **Description:** An integer value''', - require=True, validate=validators.Integer()) + **Description:** An integer value""", + require=True, + validate=validators.Integer(), + ) float = Option( - doc=''' + doc=""" **Syntax:** **float=**** - **Description:** An float value''', - validate=validators.Float()) + **Description:** An float value""", + validate=validators.Float(), + ) required_float = Option( - doc=''' + doc=""" **Syntax:** **float=**** - **Description:** An float value''', - require=True, validate=validators.Float()) + **Description:** An float value""", + require=True, + validate=validators.Float(), + ) map = Option( - doc=''' + doc=""" **Syntax:** **map=**** - **Description:** A mapping from one value to another''', - validate=validators.Map(foo=1, bar=2, test=3)) + **Description:** A mapping from one value to another""", + validate=validators.Map(foo=1, bar=2, test=3), + ) required_map = Option( - doc=''' + doc=""" **Syntax:** **map=**** - **Description:** A mapping from one value to another''', - require=True, validate=validators.Map(foo=1, bar=2, test=3)) + **Description:** A mapping from one value to another""", + require=True, + validate=validators.Map(foo=1, bar=2, test=3), + ) match = Option( - doc=''' + doc=""" **Syntax:** **match=**** - **Description:** A value that matches a regular expression pattern''', - validate=validators.Match('social security number', r'\d{3}-\d{2}-\d{4}')) + **Description:** A value that matches a regular expression pattern""", + validate=validators.Match("social security number", r"\d{3}-\d{2}-\d{4}"), + ) required_match = Option( - doc=''' + doc=""" **Syntax:** **required_match=**** - **Description:** A value that matches a regular expression pattern''', - require=True, validate=validators.Match('social security number', r'\d{3}-\d{2}-\d{4}')) + **Description:** A value that matches a regular expression pattern""", + require=True, + validate=validators.Match("social security number", r"\d{3}-\d{2}-\d{4}"), + ) optionname = Option( - doc=''' + doc=""" **Syntax:** **optionname=**** - **Description:** The name of an option (used internally)''', - validate=validators.OptionName()) + **Description:** The name of an option 
(used internally)""", + validate=validators.OptionName(), + ) required_optionname = Option( - doc=''' + doc=""" **Syntax:** **optionname=**** - **Description:** The name of an option (used internally)''', - require=True, validate=validators.OptionName()) + **Description:** The name of an option (used internally)""", + require=True, + validate=validators.OptionName(), + ) regularexpression = Option( - doc=''' + doc=""" **Syntax:** **regularexpression=**** - **Description:** Regular expression pattern to match''', - validate=validators.RegularExpression()) + **Description:** Regular expression pattern to match""", + validate=validators.RegularExpression(), + ) required_regularexpression = Option( - doc=''' + doc=""" **Syntax:** **regularexpression=**** - **Description:** Regular expression pattern to match''', - require=True, validate=validators.RegularExpression()) + **Description:** Regular expression pattern to match""", + require=True, + validate=validators.RegularExpression(), + ) set = Option( - doc=''' + doc=""" **Syntax:** **set=**** - **Description:** A member of a set''', - validate=validators.Set('foo', 'bar', 'test')) + **Description:** A member of a set""", + validate=validators.Set("foo", "bar", "test"), + ) required_set = Option( - doc=''' + doc=""" **Syntax:** **set=**** - **Description:** A member of a set''', - require=True, validate=validators.Set('foo', 'bar', 'test')) + **Description:** A member of a set""", + require=True, + validate=validators.Set("foo", "bar", "test"), + ) class ConfigurationSettings(SearchCommand.ConfigurationSettings): @classmethod @@ -190,15 +229,14 @@ def fix_up(cls, command_class): @pytest.mark.smoke class TestDecorators(TestCase): - def setUp(self): TestCase.setUp(self) def test_configuration(self): - def new_configuration_settings_class(setting_name=None, setting_value=None): - - @Configuration(**{} if setting_name is None else {setting_name: setting_value}) + @Configuration( + **{} if setting_name is None else {setting_name: setting_value} + ) class ConfiguredSearchCommand(SearchCommand): class ConfigurationSettings(SearchCommand.ConfigurationSettings): clear_required_fields = ConfigurationSetting() @@ -222,47 +260,53 @@ def fix_up(cls, command_class): return ConfiguredSearchCommand.ConfigurationSettings for name, values, error_values in ( - ('clear_required_fields', - (True, False), - (None, 'anything other than a bool')), - ('distributed', - (True, False), - (None, 'anything other than a bool')), - ('generates_timeorder', - (True, False), - (None, 'anything other than a bool')), - ('generating', - (True, False), - (None, 'anything other than a bool')), - ('maxinputs', - (0, 50000, sys.maxsize), - (None, -1, sys.maxsize + 1, 'anything other than an int')), - ('overrides_timeorder', - (True, False), - (None, 'anything other than a bool')), - ('required_fields', - (['field_1', 'field_2'], set(['field_1', 'field_2']), ('field_1', 'field_2')), - (None, 0xdead, {'foo': 1, 'bar': 2})), - ('requires_preop', - (True, False), - (None, 'anything other than a bool')), - ('retainsevents', - (True, False), - (None, 'anything other than a bool')), - ('run_in_preview', - (True, False), - (None, 'anything other than a bool')), - ('streaming', - (True, False), - (None, 'anything other than a bool')), - ('streaming_preop', - ('some unicode string', b'some byte string'), - (None, 0xdead)), - ('type', - # TODO: Do we need to validate byte versions of these strings? 
- ('events', 'reporting', 'streaming'), - ('eventing', 0xdead))): - + ( + "clear_required_fields", + (True, False), + (None, "anything other than a bool"), + ), + ("distributed", (True, False), (None, "anything other than a bool")), + ( + "generates_timeorder", + (True, False), + (None, "anything other than a bool"), + ), + ("generating", (True, False), (None, "anything other than a bool")), + ( + "maxinputs", + (0, 50000, sys.maxsize), + (None, -1, sys.maxsize + 1, "anything other than an int"), + ), + ( + "overrides_timeorder", + (True, False), + (None, "anything other than a bool"), + ), + ( + "required_fields", + ( + ["field_1", "field_2"], + set(["field_1", "field_2"]), + ("field_1", "field_2"), + ), + (None, 0xDEAD, {"foo": 1, "bar": 2}), + ), + ("requires_preop", (True, False), (None, "anything other than a bool")), + ("retainsevents", (True, False), (None, "anything other than a bool")), + ("run_in_preview", (True, False), (None, "anything other than a bool")), + ("streaming", (True, False), (None, "anything other than a bool")), + ( + "streaming_preop", + ("some unicode string", b"some byte string"), + (None, 0xDEAD), + ), + ( + "type", + # TODO: Do we need to validate byte versions of these strings? + ("events", "reporting", "streaming"), + ("eventing", 0xDEAD), + ), + ): for value in values: settings_class = new_configuration_settings_class(name, value) @@ -270,7 +314,7 @@ def fix_up(cls, command_class): self.assertIsInstance(getattr(settings_class, name), property) # Backing field exists on the settings class and it holds the correct value - backing_field_name = '_' + name + backing_field_name = "_" + name self.assertEqual(getattr(settings_class, backing_field_name), value) settings_instance = settings_class(command=None) @@ -291,21 +335,25 @@ def fix_up(cls, command_class): try: new_configuration_settings_class(name, value) except Exception as error: - self.assertIsInstance(error, ValueError, - f'Expected ValueError, not {type(error).__name__}({error}) for {name}={repr(value)}') + self.assertIsInstance( + error, + ValueError, + f"Expected ValueError, not {type(error).__name__}({error}) for {name}={repr(value)}", + ) else: - self.fail(f'Expected ValueError, not success for {name}={repr(value)}') + self.fail( + f"Expected ValueError, not success for {name}={repr(value)}" + ) settings_class = new_configuration_settings_class() settings_instance = settings_class(command=None) self.assertRaises(ValueError, setattr, settings_instance, name, value) def test_new_configuration_setting(self): - class Test: generating = ConfigurationSetting() - @ConfigurationSetting(name='required_fields') + @ConfigurationSetting(name="required_fields") def some_name__other_than_required_fields(self): pass @@ -324,8 +372,8 @@ def streaming_preop(self, value): ConfigurationSetting.fix_up(Test, {}) test = Test() - self.assertFalse(hasattr(Test, '_generating')) - self.assertFalse(hasattr(test, '_generating')) + self.assertFalse(hasattr(Test, "_generating")) + self.assertFalse(hasattr(test, "_generating")) self.assertIsNone(test.generating) Test._generating = True @@ -336,55 +384,73 @@ def streaming_preop(self, value): self.assertIs(Test._generating, True) self.assertIs(test._generating, False) - self.assertRaises(ValueError, Test.generating.fset, test, 'any type other than bool') + self.assertRaises( + ValueError, Test.generating.fset, test, "any type other than bool" + ) def test_option(self): - - rebase_environment('app_with_logging_configuration') + rebase_environment("app_with_logging_configuration") 
presets = [ - 'logging_configuration=' + json_encode_string(environment.logging_configuration), + "logging_configuration=" + + json_encode_string(environment.logging_configuration), 'logging_level="WARNING"', 'record="f"', - 'show_configuration="f"'] + 'show_configuration="f"', + ] command = TestSearchCommand() options = command.options options.reset() missing = options.get_missing() - self.assertListEqual(missing, [option.name for option in options.values() if option.is_required]) - self.assertListEqual(presets, [str(option) for option in options.values() if option.value is not None]) - self.assertListEqual(presets, [str(option) for option in options.values() if str(option) != option.name + '=None']) + self.assertListEqual( + missing, [option.name for option in options.values() if option.is_required] + ) + self.assertListEqual( + presets, + [str(option) for option in options.values() if option.value is not None], + ) + self.assertListEqual( + presets, + [ + str(option) + for option in options.values() + if str(option) != option.name + "=None" + ], + ) test_option_values = { - validators.Boolean: ('0', 'non-boolean value'), - validators.Code: ('foo == "bar"', 'bad code'), - validators.Duration: ('24:59:59', 'non-duration value'), - validators.Fieldname: ('some.field_name', 'non-fieldname value'), - validators.File: (__file__, 'non-existent file'), - validators.Integer: ('100', 'non-integer value'), - validators.Float: ('99.9', 'non-float value'), - validators.List: ('a,b,c', '"non-list value'), - validators.Map: ('foo', 'non-existent map entry'), - validators.Match: ('123-45-6789', 'not a social security number'), - validators.OptionName: ('some_option_name', 'non-option name value'), - validators.RegularExpression: ('\\s+', '(poorly formed regular expression'), - validators.Set: ('bar', 'non-existent set entry')} + validators.Boolean: ("0", "non-boolean value"), + validators.Code: ('foo == "bar"', "bad code"), + validators.Duration: ("24:59:59", "non-duration value"), + validators.Fieldname: ("some.field_name", "non-fieldname value"), + validators.File: (__file__, "non-existent file"), + validators.Integer: ("100", "non-integer value"), + validators.Float: ("99.9", "non-float value"), + validators.List: ("a,b,c", '"non-list value'), + validators.Map: ("foo", "non-existent map entry"), + validators.Match: ("123-45-6789", "not a social security number"), + validators.OptionName: ("some_option_name", "non-option name value"), + validators.RegularExpression: ("\\s+", "(poorly formed regular expression"), + validators.Set: ("bar", "non-existent set entry"), + } for option in options.values(): validator = option.validator if validator is None: - self.assertIn(option.name, ['logging_configuration', 'logging_level']) + self.assertIn(option.name, ["logging_configuration", "logging_level"]) continue legal_value, illegal_value = test_option_values[type(validator)] option.value = legal_value self.assertEqual( - validator.format(option.value), validator.format(validator.__call__(legal_value)), - f"{option.name}={legal_value}") + validator.format(option.value), + validator.format(validator.__call__(legal_value)), + f"{option.name}={legal_value}", + ) try: option.value = illegal_value @@ -392,40 +458,43 @@ def test_option(self): pass except BaseException as error: self.assertFalse( - f'Expected ValueError for {option.name}={illegal_value}, not this {type(error).__name__}: {error}') + f"Expected ValueError for {option.name}={illegal_value}, not this {type(error).__name__}: {error}" + ) else: - 
self.assertFalse(f'Expected ValueError for {option.name}={illegal_value}, not a pass.') + self.assertFalse( + f"Expected ValueError for {option.name}={illegal_value}, not a pass." + ) expected = { - 'foo': False, - 'boolean': False, - 'code': 'foo == \"bar\"', - 'duration': 89999, - 'fieldname': 'some.field_name', - 'file': str(repr(__file__)), - 'integer': 100, - 'float': 99.9, - 'logging_configuration': environment.logging_configuration, - 'logging_level': 'WARNING', - 'map': 'foo', - 'match': '123-45-6789', - 'optionname': 'some_option_name', - 'record': False, - 'regularexpression': '\\s+', - 'required_boolean': False, - 'required_code': 'foo == \"bar\"', - 'required_duration': 89999, - 'required_fieldname': 'some.field_name', - 'required_file': str(repr(__file__)), - 'required_integer': 100, - 'required_float': 99.9, - 'required_map': 'foo', - 'required_match': '123-45-6789', - 'required_optionname': 'some_option_name', - 'required_regularexpression': '\\s+', - 'required_set': 'bar', - 'set': 'bar', - 'show_configuration': False, + "foo": False, + "boolean": False, + "code": 'foo == "bar"', + "duration": 89999, + "fieldname": "some.field_name", + "file": str(repr(__file__)), + "integer": 100, + "float": 99.9, + "logging_configuration": environment.logging_configuration, + "logging_level": "WARNING", + "map": "foo", + "match": "123-45-6789", + "optionname": "some_option_name", + "record": False, + "regularexpression": "\\s+", + "required_boolean": False, + "required_code": 'foo == "bar"', + "required_duration": 89999, + "required_fieldname": "some.field_name", + "required_file": str(repr(__file__)), + "required_integer": 100, + "required_float": 99.9, + "required_map": "foo", + "required_match": "123-45-6789", + "required_optionname": "some_option_name", + "required_regularexpression": "\\s+", + "required_set": "bar", + "set": "bar", + "show_configuration": False, } self.maxDiff = None @@ -435,27 +504,36 @@ def test_option(self): for x in command.options.values(): # isinstance doesn't work for some reason - if type(x.value).__name__ == 'Code': + if type(x.value).__name__ == "Code": self.assertEqual(expected[x.name], x.value.source) - elif type(x.validator).__name__ == 'Map': - self.assertEqual(expected[x.name], invert(x.validator.membership)[x.value]) - elif type(x.validator).__name__ == 'RegularExpression': + elif type(x.validator).__name__ == "Map": + self.assertEqual( + expected[x.name], invert(x.validator.membership)[x.value] + ) + elif type(x.validator).__name__ == "RegularExpression": self.assertEqual(expected[x.name], x.value.pattern) elif isinstance(x.value, TextIOWrapper): self.assertEqual(expected[x.name], f"'{x.value.name}'") - elif not isinstance(x.value, (bool,) + (float,) + (str,) + (bytes,) + tuplewrap(int)): + elif not isinstance( + x.value, (bool,) + (float,) + (str,) + (bytes,) + tuplewrap(int) + ): self.assertEqual(expected[x.name], repr(x.value)) else: self.assertEqual(expected[x.name], x.value) expected = ( 'foo="f" boolean="f" code="foo == \\"bar\\"" duration="24:59:59" fieldname="some.field_name" ' - 'file=' + json_encode_string(__file__) + ' float="99.9" integer="100" map="foo" match="123-45-6789" ' + "file=" + + json_encode_string(__file__) + + ' float="99.9" integer="100" map="foo" match="123-45-6789" ' 'optionname="some_option_name" record="f" regularexpression="\\\\s+" required_boolean="f" ' 'required_code="foo == \\"bar\\"" required_duration="24:59:59" required_fieldname="some.field_name" ' - 'required_file=' + json_encode_string(__file__) + ' 
required_float="99.9" required_integer="100" required_map="foo" ' + "required_file=" + + json_encode_string(__file__) + + ' required_float="99.9" required_integer="100" required_map="foo" ' 'required_match="123-45-6789" required_optionname="some_option_name" required_regularexpression="\\\\s+" ' - 'required_set="bar" set="bar" show_configuration="f"') + 'required_set="bar" set="bar" show_configuration="f"' + ) observed = str(command.options) diff --git a/tests/searchcommands/test_generator_command.py b/tests/searchcommands/test_generator_command.py index af103977a..c2b5621b1 100644 --- a/tests/searchcommands/test_generator_command.py +++ b/tests/searchcommands/test_generator_command.py @@ -10,12 +10,12 @@ def test_simple_generator(): class GeneratorTest(GeneratingCommand): def generate(self): for num in range(1, 10): - yield {'_time': time.time(), 'event_index': num} + yield {"_time": time.time(), "event_index": num} generator = GeneratorTest() in_stream = io.BytesIO() in_stream.write(chunky.build_getinfo_chunk()) - in_stream.write(chunky.build_chunk({'action': 'execute'})) + in_stream.write(chunky.build_chunk({"action": "execute"})) in_stream.seek(0) out_stream = io.BytesIO() generator._process_protocol_v2([], in_stream, out_stream) @@ -75,14 +75,14 @@ def generate(self): generator = GeneratorTest() in_stream = io.BytesIO() in_stream.write(chunky.build_getinfo_chunk()) - in_stream.write(chunky.build_chunk({'action': 'execute'})) + in_stream.write(chunky.build_chunk({"action": "execute"})) in_stream.seek(0) out_stream = io.BytesIO() generator._process_protocol_v2([], in_stream, out_stream) out_stream.seek(0) ds = chunky.ChunkedDataStream(out_stream) - fieldnames_expected = {'_time', 'one', 'two', 'three', 'four', 'five'} + fieldnames_expected = {"_time", "one", "two", "three", "four", "five"} fieldnames_actual = set() for chunk in ds: for row in chunk.data: diff --git a/tests/searchcommands/test_internals_v1.py b/tests/searchcommands/test_internals_v1.py index 6e41844ff..7ac8e50f8 100755 --- a/tests/searchcommands/test_internals_v1.py +++ b/tests/searchcommands/test_internals_v1.py @@ -21,7 +21,11 @@ from functools import reduce import pytest -from splunklib.searchcommands.internals import CommandLineParser, InputHeader, RecordWriterV1 +from splunklib.searchcommands.internals import ( + CommandLineParser, + InputHeader, + RecordWriterV1, +) from splunklib.searchcommands.decorators import Configuration, Option from splunklib.searchcommands.validators import Boolean @@ -34,27 +38,30 @@ def setUp(self): TestCase.setUp(self) def test_command_line_parser(self): - @Configuration() class TestCommandLineParserCommand(SearchCommand): - required_option = Option(validate=Boolean(), require=True) unnecessary_option = Option(validate=Boolean(), default=True, require=False) class ConfigurationSettings(SearchCommand.ConfigurationSettings): - @classmethod - def fix_up(cls, command_class): pass + def fix_up(cls, command_class): + pass # Command line without fieldnames - options = ['required_option=true', 'unnecessary_option=false'] + options = ["required_option=true", "unnecessary_option=false"] command = TestCommandLineParserCommand() CommandLineParser.parse(command, options) for option in command.options.values(): - if option.name in ['logging_configuration', 'logging_level', 'record', 'show_configuration']: + if option.name in [ + "logging_configuration", + "logging_level", + "record", + "show_configuration", + ]: self.assertFalse(option.is_set) continue self.assertTrue(option.is_set) @@ -65,13 +72,18 
@@ def fix_up(cls, command_class): pass # Command line with fieldnames - fieldnames = ['field_1', 'field_2', 'field_3'] + fieldnames = ["field_1", "field_2", "field_3"] command = TestCommandLineParserCommand() CommandLineParser.parse(command, options + fieldnames) for option in command.options.values(): - if option.name in ['logging_configuration', 'logging_level', 'record', 'show_configuration']: + if option.name in [ + "logging_configuration", + "logging_level", + "record", + "show_configuration", + ]: self.assertFalse(option.is_set) continue self.assertTrue(option.is_set) @@ -83,11 +95,16 @@ def fix_up(cls, command_class): pass # Command line without any unnecessary options command = TestCommandLineParserCommand() - CommandLineParser.parse(command, ['required_option=true'] + fieldnames) + CommandLineParser.parse(command, ["required_option=true"] + fieldnames) for option in command.options.values(): - if option.name in ['unnecessary_option', 'logging_configuration', 'logging_level', 'record', - 'show_configuration']: + if option.name in [ + "unnecessary_option", + "logging_configuration", + "logging_level", + "record", + "show_configuration", + ]: self.assertFalse(option.is_set) continue self.assertTrue(option.is_set) @@ -98,27 +115,32 @@ def fix_up(cls, command_class): pass # Command line with missing required options, with or without fieldnames or unnecessary options - options = ['unnecessary_option=true'] - self.assertRaises(ValueError, CommandLineParser.parse, command, options + fieldnames) + options = ["unnecessary_option=true"] + self.assertRaises( + ValueError, CommandLineParser.parse, command, options + fieldnames + ) self.assertRaises(ValueError, CommandLineParser.parse, command, options) self.assertRaises(ValueError, CommandLineParser.parse, command, []) # Command line with unrecognized options - self.assertRaises(ValueError, CommandLineParser.parse, command, - ['unrecognized_option_1=foo', 'unrecognized_option_2=bar']) + self.assertRaises( + ValueError, + CommandLineParser.parse, + command, + ["unrecognized_option_1=foo", "unrecognized_option_2=bar"], + ) # Command line with a variety of quoted/escaped text options @Configuration() class TestCommandLineParserCommand(SearchCommand): - text = Option() class ConfigurationSettings(SearchCommand.ConfigurationSettings): - @classmethod - def fix_up(cls, command_class): pass + def fix_up(cls, command_class): + pass strings = [ r'"foo bar"', @@ -126,22 +148,23 @@ def fix_up(cls, command_class): pass r'"foo\\bar"', r'"""foo bar"""', r'"\"foo bar\""', - r'Hello\ World!', - r'\"Hello\ World!\"'] + r"Hello\ World!", + r"\"Hello\ World!\"", + ] expected_values = [ - r'foo bar', - r'foo/bar', - r'foo\bar', + r"foo bar", + r"foo/bar", + r"foo\bar", r'"foo bar"', r'"foo bar"', - r'Hello World!', - r'"Hello World!"' + r"Hello World!", + r'"Hello World!"', ] for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() - argv = ['text', '=', string] + argv = ["text", "=", string] CommandLineParser.parse(command, argv) self.assertEqual(command.text, expected_value) @@ -153,17 +176,12 @@ def fix_up(cls, command_class): pass for string, expected_value in zip(strings, expected_values): command = TestCommandLineParserCommand() - argv = ['text', '=', string] + strings + argv = ["text", "=", string] + strings CommandLineParser.parse(command, argv) self.assertEqual(command.text, expected_value) self.assertEqual(command.fieldnames, expected_values) - strings = [ - 'some\\ string\\', - r'some\ string"', - r'"some 
string', - r'some"string' - ] + strings = ["some\\ string\\", r'some\ string"', r'"some string', r'some"string'] for string in strings: command = TestCommandLineParserCommand() @@ -174,8 +192,8 @@ def test_command_line_parser_unquote(self): parser = CommandLineParser options = [ - r'foo', # unquoted string with no escaped characters - r'fo\o\ b\"a\\r', # unquoted string with some escaped characters + r"foo", # unquoted string with no escaped characters + r"fo\o\ b\"a\\r", # unquoted string with some escaped characters r'"foo"', # quoted string with no special characters r'"""foobar1"""', # quoted string with quotes escaped like this: "" r'"\"foobar2\""', # quoted string with quotes escaped like this: \" @@ -184,24 +202,26 @@ def test_command_line_parser_unquote(self): r'"\\foobar"', # quoted string with an escaped backslash r'"foo \\ bar"', # quoted string with an escaped backslash r'"foobar\\"', # quoted string with an escaped backslash - r'foo\\\bar', # quoted string with an escaped backslash and an escaped 'b' + r"foo\\\bar", # quoted string with an escaped backslash and an escaped 'b' r'""', # pair of quotes - r''] # empty string + r"", + ] # empty string expected = [ - r'foo', + r"foo", r'foo b"a\r', - r'foo', + r"foo", r'"foobar1"', r'"foobar2"', r'foo "x" bar', r'foo "x" bar', - '\\foobar', - r'foo \ bar', - 'foobar\\', - r'foo\bar', - r'', - r''] + "\\foobar", + r"foo \ bar", + "foobar\\", + r"foo\bar", + r"", + r"", + ] # Command line with an assortment of string values @@ -213,15 +233,14 @@ def test_command_line_parser_unquote(self): self.assertRaises(SyntaxError, parser.unquote, '"') self.assertRaises(SyntaxError, parser.unquote, '"foo') self.assertRaises(SyntaxError, parser.unquote, 'foo"') - self.assertRaises(SyntaxError, parser.unquote, 'foo\\') + self.assertRaises(SyntaxError, parser.unquote, "foo\\") def test_input_header(self): - # No items input_header = InputHeader() - with closing(StringIO('\r\n')) as input_file: + with closing(StringIO("\r\n")) as input_file: input_header.read(input_file) self.assertEqual(len(input_header), 0) @@ -230,14 +249,18 @@ def test_input_header(self): input_header = InputHeader() - with closing(StringIO('this%20is%20an%20unnamed%20single-line%20item\n\n')) as input_file: + with closing( + StringIO("this%20is%20an%20unnamed%20single-line%20item\n\n") + ) as input_file: input_header.read(input_file) self.assertEqual(len(input_header), 0) input_header = InputHeader() - with closing(StringIO('this%20is%20an%20unnamed\nmulti-\nline%20item\n\n')) as input_file: + with closing( + StringIO("this%20is%20an%20unnamed\nmulti-\nline%20item\n\n") + ) as input_file: input_header.read(input_file) self.assertEqual(len(input_header), 0) @@ -246,41 +269,51 @@ def test_input_header(self): input_header = InputHeader() - with closing(StringIO('Foo:this%20is%20a%20single-line%20item\n\n')) as input_file: + with closing( + StringIO("Foo:this%20is%20a%20single-line%20item\n\n") + ) as input_file: input_header.read(input_file) self.assertEqual(len(input_header), 1) - self.assertEqual(input_header['Foo'], 'this is a single-line item') + self.assertEqual(input_header["Foo"], "this is a single-line item") input_header = InputHeader() - with closing(StringIO('Bar:this is a\nmulti-\nline item\n\n')) as input_file: + with closing(StringIO("Bar:this is a\nmulti-\nline item\n\n")) as input_file: input_header.read(input_file) self.assertEqual(len(input_header), 1) - self.assertEqual(input_header['Bar'], 'this is a\nmulti-\nline item') + self.assertEqual(input_header["Bar"], 
"this is a\nmulti-\nline item") # The infoPath item (which is the path to a file that we open for reads) input_header = InputHeader() - with closing(StringIO('infoPath:non-existent.csv\n\n')) as input_file: + with closing(StringIO("infoPath:non-existent.csv\n\n")) as input_file: input_header.read(input_file) self.assertEqual(len(input_header), 1) - self.assertEqual(input_header['infoPath'], 'non-existent.csv') + self.assertEqual(input_header["infoPath"], "non-existent.csv") # Set of named items collection = { - 'word_list': 'hello\nworld\n!', - 'word_1': 'hello', - 'word_2': 'world', - 'word_3': '!', - 'sentence': 'hello world!'} + "word_list": "hello\nworld\n!", + "word_1": "hello", + "word_2": "world", + "word_3": "!", + "sentence": "hello world!", + } input_header = InputHeader() - text = reduce(lambda value, item: value + f'{item[0]}:{item[1]}\n', collection.items(), '') + '\n' + text = ( + reduce( + lambda value, item: value + f"{item[0]}:{item[1]}\n", + collection.items(), + "", + ) + + "\n" + ) with closing(StringIO(text)) as input_file: input_header.read(input_file) @@ -289,7 +322,7 @@ def test_input_header(self): # Set of named items with an unnamed item at the beginning (the only place that an unnamed item can appear) - with closing(StringIO('unnamed item\n' + text)) as input_file: + with closing(StringIO("unnamed item\n" + text)) as input_file: input_header.read(input_file) self.assertDictEqual(input_header, collection) @@ -301,13 +334,12 @@ def test_input_header(self): self.assertEqual(sorted(input_header.values()), sorted(collection.values())) def test_messages_header(self): - @Configuration() class TestMessagesHeaderCommand(SearchCommand): class ConfigurationSettings(SearchCommand.ConfigurationSettings): - @classmethod - def fix_up(cls, command_class): pass + def fix_up(cls, command_class): + pass command = TestMessagesHeaderCommand() command._protocol_version = 1 @@ -315,11 +347,12 @@ def fix_up(cls, command_class): pass command._record_writer = RecordWriterV1(output_buffer) messages = [ - (command.write_debug, 'debug_message'), - (command.write_error, 'error_message'), - (command.write_fatal, 'fatal_message'), - (command.write_info, 'info_message'), - (command.write_warning, 'warning_message')] + (command.write_debug, "debug_message"), + (command.write_error, "error_message"), + (command.write_fatal, "fatal_message"), + (command.write_info, "info_message"), + (command.write_warning, "warning_message"), + ] for write, message in messages: write(message) @@ -327,14 +360,15 @@ def fix_up(cls, command_class): pass command.finish() expected = ( - 'debug_message=debug_message\r\n' - 'error_message=error_message\r\n' - 'error_message=fatal_message\r\n' - 'info_message=info_message\r\n' - 'warn_message=warning_message\r\n' - '\r\n') - - self.assertEqual(output_buffer.getvalue().decode('utf-8'), expected) + "debug_message=debug_message\r\n" + "error_message=error_message\r\n" + "error_message=fatal_message\r\n" + "info_message=info_message\r\n" + "warn_message=warning_message\r\n" + "\r\n" + ) + + self.assertEqual(output_buffer.getvalue().decode("utf-8"), expected) _package_path = os.path.dirname(__file__) diff --git a/tests/searchcommands/test_internals_v2.py b/tests/searchcommands/test_internals_v2.py index 722aaae24..0a935b9e4 100755 --- a/tests/searchcommands/test_internals_v2.py +++ b/tests/searchcommands/test_internals_v2.py @@ -34,7 +34,12 @@ from collections import OrderedDict from collections import namedtuple, deque -from splunklib.searchcommands.internals import 
MetadataDecoder, MetadataEncoder, Recorder, RecordWriterV2 +from splunklib.searchcommands.internals import ( + MetadataDecoder, + MetadataEncoder, + Recorder, + RecordWriterV2, +) from splunklib.searchcommands import SearchMetric from io import BytesIO import pickle @@ -58,7 +63,6 @@ def random_bytes(): def random_dict(): - # We do not call random_bytes because the JSONDecoder raises this UnicodeDecodeError when it encounters # bytes outside the UTF-8 character set: # @@ -69,7 +73,13 @@ def random_dict(): # contain utf-8 encoded byte strings or--better still--unicode strings. This is because the json package # converts all byte strings to unicode strings before serializing them. - return OrderedDict((('a', random_float()), ('b', random_unicode()), ('福 酒吧', OrderedDict((('fu', random_float()), ('bar', random_float())))))) + return OrderedDict( + ( + ("a", random_float()), + ("b", random_unicode()), + ("福 酒吧", OrderedDict((("fu", random_float()), ("bar", random_float())))), + ) + ) def random_float(): @@ -89,19 +99,25 @@ def random_list(population, *args): def random_unicode(): - return ''.join([str(x) for x in random.sample(list(range(MAX_NARROW_UNICODE)), random.randint(0, max_length))]) + return "".join( + [ + str(x) + for x in random.sample( + list(range(MAX_NARROW_UNICODE)), random.randint(0, max_length) + ) + ] + ) + # endregion @pytest.mark.smoke class TestInternals(TestCase): - def setUp(self): TestCase.setUp(self) def test_object_view(self): - decoder = MetadataDecoder() view = decoder.decode(self._json_input) @@ -111,51 +127,67 @@ def test_object_view(self): self.assertEqual(self._json_input, json_output) def test_record_writer_with_random_data(self, save_recording=False): - # Confirmed: [minint, maxint) covers the full range of values that xrange allows # RecordWriter writes output in units of maxresultrows records. Default: 50,000. # Partial results are written when the record count reaches maxresultrows.
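The comment above states the batching contract the assertions below depend on: records accumulate in a buffer, a partial result is flushed whenever the pending count reaches maxresultrows, and a final flush commits the remainder. A rough sketch of that contract, as a hypothetical writer rather than the real RecordWriterV2:

    class BatchingWriter:
        # Hypothetical writer illustrating the contract; not splunklib's RecordWriterV2.
        def __init__(self, ofile, maxresultrows=50000):
            self._ofile = ofile
            self._maxresultrows = maxresultrows
            self._pending = []
            self.committed_record_count = 0

        def write_record(self, record):
            self._pending.append(record)
            if len(self._pending) >= self._maxresultrows:
                self.flush(partial=True)

        def flush(self, finished=False, partial=False):
            assert isinstance(finished, bool) and isinstance(partial, bool)
            assert finished or partial         # flush() with neither flag is an error
            assert not (finished and partial)  # and the two flags are mutually exclusive
            # ... serialize self._pending to self._ofile here ...
            self.committed_record_count += len(self._pending)
            self._pending.clear()

With maxresultrows=10, the 31 records written in this test commit as three partial flushes of 10 plus a final flush(finished=True) of 1, which is why the test expects committed_record_count == 31 and an empty buffer afterward.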
- writer = RecordWriterV2(BytesIO(), maxresultrows=10) # small for the purposes of this unit test + writer = RecordWriterV2( + BytesIO(), maxresultrows=10 + ) # small for the purposes of this unit test test_data = OrderedDict() - fieldnames = ['_serial', '_time', 'random_bytes', 'random_dict', 'random_integers', 'random_unicode'] - test_data['fieldnames'] = fieldnames - test_data['values'] = [] + fieldnames = [ + "_serial", + "_time", + "random_bytes", + "random_dict", + "random_integers", + "random_unicode", + ] + test_data["fieldnames"] = fieldnames + test_data["values"] = [] write_record = writer.write_record for serial_number in range(0, 31): - values = [serial_number, time(), random_bytes(), random_dict(), random_integers(), random_unicode()] + values = [ + serial_number, + time(), + random_bytes(), + random_dict(), + random_integers(), + random_unicode(), + ] record = OrderedDict(list(zip(fieldnames, values))) - #try: + # try: write_record(record) - #except Exception as error: + # except Exception as error: # self.fail(error) - test_data['values'].append(values) + test_data["values"].append(values) # RecordWriter accumulates inspector messages and metrics until maxresultrows are written, a partial result # is produced or we're finished messages = [ - ('debug', random_unicode()), - ('error', random_unicode()), - ('fatal', random_unicode()), - ('info', random_unicode()), - ('warn', random_unicode())] + ("debug", random_unicode()), + ("error", random_unicode()), + ("fatal", random_unicode()), + ("info", random_unicode()), + ("warn", random_unicode()), + ] - test_data['messages'] = messages + test_data["messages"] = messages for message_type, message_text in messages: - writer.write_message(message_type, '{}', message_text) + writer.write_message(message_type, "{}", message_text) metrics = { - 'metric-1': SearchMetric(1, 2, 3, 4), - 'metric-2': SearchMetric(5, 6, 7, 8) + "metric-1": SearchMetric(1, 2, 3, 4), + "metric-2": SearchMetric(5, 6, 7, 8), } - test_data['metrics'] = metrics + test_data["metrics"] = metrics for name, metric in metrics.items(): writer.write_metric(name, metric) @@ -169,11 +201,14 @@ def test_record_writer_with_random_data(self, save_recording=False): fieldnames.sort() writer._fieldnames.sort() self.assertListEqual(writer._fieldnames, fieldnames) - self.assertListEqual(writer._inspector['messages'], messages) + self.assertListEqual(writer._inspector["messages"], messages) self.assertDictEqual( - dict(k_v for k_v in writer._inspector.items() if k_v[0].startswith('metric.')), - dict(('metric.' + k_v1[0], k_v1[1]) for k_v1 in metrics.items())) + dict( + k_v for k_v in writer._inspector.items() if k_v[0].startswith("metric.") + ), + dict(("metric." 
+ k_v1[0], k_v1[1]) for k_v1 in metrics.items()), + ) writer.flush(finished=True) @@ -181,13 +216,13 @@ def test_record_writer_with_random_data(self, save_recording=False): self.assertEqual(writer._record_count, 0) self.assertEqual(writer.pending_record_count, 0) self.assertEqual(writer._buffer.tell(), 0) - self.assertEqual(writer._buffer.getvalue(), '') + self.assertEqual(writer._buffer.getvalue(), "") self.assertEqual(writer._total_record_count, 31) self.assertEqual(writer.committed_record_count, 31) self.assertRaises(AssertionError, writer.flush, finished=True, partial=True) - self.assertRaises(AssertionError, writer.flush, finished='non-boolean') - self.assertRaises(AssertionError, writer.flush, partial='non-boolean') + self.assertRaises(AssertionError, writer.flush, finished="non-boolean") + self.assertRaises(AssertionError, writer.flush, partial="non-boolean") self.assertRaises(AssertionError, writer.flush) # P2 [ ] TODO: For SCPv2 we should follow the finish negotiation protocol. @@ -208,21 +243,28 @@ def _compare_chunks(self, chunks_1, chunks_2): n = 0 for chunk_1, chunk_2 in zip(chunks_1, chunks_2): self.assertDictEqual( - chunk_1.metadata, chunk_2.metadata, - 'Chunk {0}: metadata error: "{1}" != "{2}"'.format(n, chunk_1.metadata, chunk_2.metadata)) - self.assertMultiLineEqual(chunk_1.body, chunk_2.body, 'Chunk {0}: data error'.format(n)) + chunk_1.metadata, + chunk_2.metadata, + 'Chunk {0}: metadata error: "{1}" != "{2}"'.format( + n, chunk_1.metadata, chunk_2.metadata + ), + ) + self.assertMultiLineEqual( + chunk_1.body, chunk_2.body, "Chunk {0}: data error".format(n) + ) n += 1 def _load_chunks(self, ifile): import re - pattern = re.compile(r'chunked 1.0,(?P<metadata_length>\d+),(?P<body_length>\d+)\n') + pattern = re.compile( + r"chunked 1.0,(?P<metadata_length>\d+),(?P<body_length>\d+)\n" + ) decoder = json.JSONDecoder() chunks = [] while True: - line = ifile.readline() if len(line) == 0: @@ -231,67 +273,60 @@ def _load_chunks(self, ifile): match = pattern.match(line) self.assertIsNotNone(match) - metadata_length = int(match.group('metadata_length')) + metadata_length = int(match.group("metadata_length")) metadata = ifile.read(metadata_length) metadata = decoder.decode(metadata) - body_length = int(match.group('body_length')) - body = ifile.read(body_length) if body_length > 0 else '' chunks.append(TestInternals._Chunk(metadata, body)) return chunks - _Chunk = namedtuple('Chunk', ('metadata', 'body')) _dictionary = { - 'a': 1, - 'b': 2, - 'c': { - 'd': 3, - 'e': 4, - 'f': { - 'g': 5, - 'h': 6, - 'i': 7 - }, - 'j': 8, - 'k': 9 - }, - 'l': 10, - 'm': 11, - 'n': 12 + "a": 1, + "b": 2, + "c": {"d": 3, "e": 4, "f": {"g": 5, "h": 6, "i": 7}, "j": 8, "k": 9}, + "l": 10, + "m": 11, + "n": 12, } - _json_input = str(json.dumps(_dictionary, separators=(',', ':'))) + _json_input = str(json.dumps(_dictionary, separators=(",", ":"))) _package_path = os.path.dirname(os.path.abspath(__file__)) class TestRecorder: - def __init__(self, test_case): - self._test_case = test_case self._output = None self._recording = None self._recording_part = None def _not_implemented(self): - raise NotImplementedError('class {} is not in playback or record mode'.format(self.__class__.__name__)) + raise NotImplementedError( + "class {} is not in playback or record mode".format( + self.__class__.__name__ + ) + ) - self.get = self.next_part = self.stop = MethodType(_not_implemented, self, self.__class__) +
self.get = self.next_part = self.stop = MethodType( + _not_implemented, self, self.__class__ + ) @property def output(self): return self._output def playback(self, path): - - with open(path, 'rb') as f: + with open(path, "rb") as f: test_data = pickle.load(f) self._output = BytesIO() - self._recording = test_data['inputs'] + self._recording = test_data["inputs"] self._recording_part = self._recording.popleft() def get(self, method, *args, **kwargs): @@ -305,12 +340,11 @@ def next_part(self): self.next_part = MethodType(next_part, self, self.__class__) def stop(self): - self._test_case.assertEqual(test_data['results'], self._output.getvalue()) + self._test_case.assertEqual(test_data["results"], self._output.getvalue()) self.stop = MethodType(stop, self, self.__class__) def record(self, path): - self._output = BytesIO() self._recording = deque() self._recording_part = OrderedDict() @@ -337,15 +371,16 @@ def next_part(self): self.next_part = MethodType(next_part, self, self.__class__) def stop(self): - with io.open(path, 'wb') as f: - test = OrderedDict((('inputs', self._recording), ('results', self._output.getvalue()))) + with io.open(path, "wb") as f: + test = OrderedDict( + (("inputs", self._recording), ("results", self._output.getvalue())) + ) pickle.dump(test, f) self.stop = MethodType(stop, self, self.__class__) def recorded(method): - @wraps(method) def _record(*args, **kwargs): return args[0].recorder.get(method, *args, **kwargs) @@ -354,12 +389,12 @@ def _record(*args, **kwargs): class Test: - def __init__(self, fieldnames, data_generators): - TestCase.__init__(self) - self._data_generators = list(chain((lambda: self._serial_number, time), data_generators)) - self._fieldnames = list(chain(('_serial', '_time'), fieldnames)) + self._data_generators = list( + chain((lambda: self._serial_number, time), data_generators) + ) + self._fieldnames = list(chain(("_serial", "_time"), fieldnames)) self._recorder = TestRecorder(self) self._serial_number = None @@ -382,12 +417,16 @@ def serial_number(self): return self._serial_number def playback(self): - self.recorder.playback(os.path.join(TestInternals._package_path, 'TestRecorder.recording')) + self.recorder.playback( + os.path.join(TestInternals._package_path, "TestRecorder.recording") + ) self._run() self.recorder.stop() def record(self): - self.recorder.record(os.path.join(TestInternals._package_path, 'TestRecorder.recording')) + self.recorder.record( + os.path.join(TestInternals._package_path, "TestRecorder.recording") + ) self._run() self.recorder.stop() @@ -395,7 +434,6 @@ def runTest(self): pass # We'll adopt the new test recording mechanism a little later def _run(self): - writer = RecordWriterV2(self.recorder.output, maxresultrows=10) write_record = writer.write_record names = self.fieldnames diff --git a/tests/searchcommands/test_multibyte_processing.py b/tests/searchcommands/test_multibyte_processing.py index 1d021eed7..55f7b4b86 100644 --- a/tests/searchcommands/test_multibyte_processing.py +++ b/tests/searchcommands/test_multibyte_processing.py @@ -19,7 +19,8 @@ def stream(self, records): def get_input_file(name): return path.join( - path.dirname(path.dirname(__file__)), 'data', 'custom_search', name + '.gz') + path.dirname(path.dirname(__file__)), "data", "custom_search", name + ".gz" + ) def test_multibyte_chunked(): diff --git a/tests/searchcommands/test_reporting_command.py b/tests/searchcommands/test_reporting_command.py index dbda9cd8c..b91d0d96f 100644 --- a/tests/searchcommands/test_reporting_command.py +++ 
b/tests/searchcommands/test_reporting_command.py @@ -11,7 +11,7 @@ def reduce(self, records): value = 0 for record in records: value += int(record["value"]) - yield {'sum': value} + yield {"sum": value} cmd = TestReportingCommand() ifile = io.BytesIO() @@ -26,12 +26,12 @@ def reduce(self, records): ofile.seek(0) chunk_stream = chunky.ChunkedDataStream(ofile) getinfo_response = chunk_stream.read_chunk() - assert getinfo_response.meta['type'] == 'reporting' + assert getinfo_response.meta["type"] == "reporting" data_chunk = chunk_stream.read_chunk() - assert data_chunk.meta['finished'] is True # Should only be one row + assert data_chunk.meta["finished"] is True # Should only be one row data = list(data_chunk.data) assert len(data) == 1 - assert int(data[0]['sum']) == sum(range(0, 10)) + assert int(data[0]["sum"]) == sum(range(0, 10)) def test_simple_reporting_command_with_map(): @@ -66,7 +66,7 @@ def reduce(self, records): chunk_stream = chunky.ChunkedDataStream(ofile) chunk_stream.read_chunk() data_chunk = chunk_stream.read_chunk() - assert data_chunk.meta['finished'] is True + assert data_chunk.meta["finished"] is True result = list(data_chunk.data) expected_sum = sum(i * 2 for i in range(5)) diff --git a/tests/searchcommands/test_search_command.py b/tests/searchcommands/test_search_command.py index 849a8888b..7e1542107 100755 --- a/tests/searchcommands/test_search_command.py +++ b/tests/searchcommands/test_search_command.py @@ -41,8 +41,10 @@ def build_command_input(getinfo_metadata, execute_metadata, execute_body): - input = (f'chunked 1.0,{len(ensure_binary(getinfo_metadata))},0\n{getinfo_metadata}' + - f'chunked 1.0,{len(ensure_binary(execute_metadata))},{len(ensure_binary(execute_body))}\n{execute_metadata}{execute_body}') + input = ( + f"chunked 1.0,{len(ensure_binary(getinfo_metadata))},0\n{getinfo_metadata}" + + f"chunked 1.0,{len(ensure_binary(execute_metadata))},{len(ensure_binary(execute_body))}\n{execute_metadata}{execute_body}" + ) ifile = BytesIO(ensure_binary(input)) @@ -58,7 +60,7 @@ class TestCommand(SearchCommand): def echo(self, records): for record in records: - if record.get('action') == 'raise_exception': + if record.get("action") == "raise_exception": raise Exception(self) yield record @@ -66,7 +68,6 @@ def _execute(self, ifile, process): SearchCommand._execute(self, ifile, self.echo) class ConfigurationSettings(SearchCommand.ConfigurationSettings): - # region SCP v1/v2 properties generating = ConfigurationSetting() @@ -101,11 +102,15 @@ class TestStreamingCommand(StreamingCommand): def stream(self, records): serial_number = 0 for record in records: - action = record['action'] - if action == 'raise_error': - raise RuntimeError('Testing') - value = self.search_results_info if action == 'get_search_results_info' else None - yield {'_serial': serial_number, 'data': value} + action = record["action"] + if action == "raise_error": + raise RuntimeError("Testing") + value = ( + self.search_results_info + if action == "get_search_results_info" + else None + ) + yield {"_serial": serial_number, "data": value} serial_number += 1 @@ -115,13 +120,12 @@ def setUp(self): TestCase.setUp(self) def test_process_scpv2(self): - # SearchCommand.process should # 1. 
Recognize all standard options: metadata = ( - '{{' + "{{" '"action": "getinfo", "preview": false, "searchinfo": {{' '"latest_time": "0",' '"splunk_version": "20150522",' @@ -134,7 +138,7 @@ def test_process_scpv2(self): '"show_configuration={show_configuration}",' '"required_option_1=value_1",' '"required_option_2=value_2"' - '],' + "]," '"search": "A%7C%20inputlookup%20tweets%20%7C%20countmatches%20fieldname%3Dword_count%20pattern%3D%22%5Cw%2B%22%20text%20record%3Dt%20%7C%20export%20add_timestamp%3Df%20add_offset%3Dt%20format%3Dcsv%20segmentation%3Draw",' '"earliest_time": "0",' '"session_key": "0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^",' @@ -149,16 +153,19 @@ def test_process_scpv2(self): '"show_configuration={show_configuration}",' '"required_option_1=value_1",' '"required_option_2=value_2"' - '],' + "]," '"maxresultrows": 10,' '"command": "countmatches"' - '}}' - '}}') + "}}" + "}}" + ) basedir = self._package_directory - logging_configuration = os.path.join(basedir, 'apps', 'app_with_logging_configuration', 'logging.conf') - logging_level = 'ERROR' + logging_configuration = os.path.join( + basedir, "apps", "app_with_logging_configuration", "logging.conf" + ) + logging_level = "ERROR" record = False show_configuration = True @@ -166,19 +173,20 @@ def test_process_scpv2(self): dispatch_dir=encode_string(""), logging_configuration=encode_string(logging_configuration)[1:-1], logging_level=logging_level, - record=('true' if record is True else 'false'), - show_configuration=('true' if show_configuration is True else 'false')) + record=("true" if record is True else "false"), + show_configuration=("true" if show_configuration is True else "false"), + ) execute_metadata = '{"action":"execute","finished":true}' - execute_body = 'test\r\ndata\r\n测试\r\n' + execute_body = "test\r\ndata\r\n测试\r\n" ifile = build_command_input(getinfo_metadata, execute_metadata, execute_body) command = TestCommand() result = BytesIO() - argv = ['some-external-search-command.py'] + argv = ["some-external-search-command.py"] - self.assertEqual(command.logging_level, 'WARNING') + self.assertEqual(command.logging_level, "WARNING") self.assertIs(command.record, None) self.assertIs(command.show_configuration, None) @@ -186,27 +194,27 @@ def test_process_scpv2(self): # noinspection PyTypeChecker command.process(argv, ifile, ofile=result) except SystemExit as error: - self.fail('Unexpected exception: {}: {}'.format(type(error).__name__, error)) + self.fail( + "Unexpected exception: {}: {}".format(type(error).__name__, error) + ) self.assertEqual(command.logging_configuration, logging_configuration) - self.assertEqual(command.logging_level, 'ERROR') + self.assertEqual(command.logging_level, "ERROR") self.assertEqual(command.record, record) self.assertEqual(command.show_configuration, show_configuration) - self.assertEqual(command.required_option_1, 'value_1') - self.assertEqual(command.required_option_2, 'value_2') + self.assertEqual(command.required_option_1, "value_1") + self.assertEqual(command.required_option_2, "value_2") expected = ( - 'chunked 1.0,68,0\n' + "chunked 1.0,68,0\n" '{"inspector":{"messages":[["INFO","test command configuration: "]]}}' - 'chunked 1.0,17,32\n' + "chunked 1.0,17,32\n" '{"finished":true}test,__mv_test\r\n' - 'data,\r\n' - '测试,\r\n' + "data,\r\n" + "测试,\r\n" ) - self.assertEqual( - expected, - result.getvalue().decode('utf-8')) + self.assertEqual(expected, result.getvalue().decode("utf-8")) 
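The metadata template and the expected byte stream above follow the SCPv2 chunked transport that build_command_input assembles: each chunk is a header line of the form chunked 1.0,<metadata_length>,<body_length> followed by exactly metadata_length bytes of JSON metadata and body_length bytes of payload. A minimal sketch of framing one chunk, assuming UTF-8 payloads (this mirrors the test helpers, not the SDK's internal writer):

    import json

    def build_chunk(metadata, body=""):
        # Frame one SCPv2 chunk: header line, then metadata JSON, then the body.
        meta = json.dumps(metadata).encode("utf-8")
        data = body.encode("utf-8")
        header = f"chunked 1.0,{len(meta)},{len(data)}\n".encode("utf-8")
        return header + meta + data

    # The execute chunk fed to the command above would be built roughly as:
    chunk = build_chunk({"action": "execute", "finished": True}, "test\r\ndata\r\n测试\r\n")

Reading a stream reverses the steps: parse the two lengths out of the header line, then read exactly that many bytes of metadata and body, as _load_chunks in test_internals_v2.py does.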
self.assertEqual(command.protocol_version, 2) @@ -222,43 +230,68 @@ def test_process_scpv2(self): command_metadata = command.metadata input_header = command.input_header - self.assertIsNone(input_header['allowStream']) - self.assertEqual(input_header['infoPath'], os.path.join(command_metadata.searchinfo.dispatch_dir, 'info.csv')) - self.assertIsNone(input_header['keywords']) - self.assertEqual(input_header['preview'], command_metadata.preview) - self.assertIs(input_header['realtime'], False) - self.assertEqual(input_header['search'], command_metadata.searchinfo.search) - self.assertEqual(input_header['sid'], command_metadata.searchinfo.sid) - self.assertEqual(input_header['splunkVersion'], command_metadata.searchinfo.splunk_version) - self.assertIsNone(input_header['truncated']) - - self.assertEqual(command_metadata.preview, input_header['preview']) - self.assertEqual(command_metadata.searchinfo.app, 'searchcommands_app') - self.assertEqual(command_metadata.searchinfo.args, - ['logging_configuration=' + logging_configuration, 'logging_level=ERROR', 'record=false', - 'show_configuration=true', 'required_option_1=value_1', 'required_option_2=value_2']) - self.assertEqual(command_metadata.searchinfo.dispatch_dir, os.path.dirname(input_header['infoPath'])) + self.assertIsNone(input_header["allowStream"]) + self.assertEqual( + input_header["infoPath"], + os.path.join(command_metadata.searchinfo.dispatch_dir, "info.csv"), + ) + self.assertIsNone(input_header["keywords"]) + self.assertEqual(input_header["preview"], command_metadata.preview) + self.assertIs(input_header["realtime"], False) + self.assertEqual(input_header["search"], command_metadata.searchinfo.search) + self.assertEqual(input_header["sid"], command_metadata.searchinfo.sid) + self.assertEqual( + input_header["splunkVersion"], command_metadata.searchinfo.splunk_version + ) + self.assertIsNone(input_header["truncated"]) + + self.assertEqual(command_metadata.preview, input_header["preview"]) + self.assertEqual(command_metadata.searchinfo.app, "searchcommands_app") + self.assertEqual( + command_metadata.searchinfo.args, + [ + "logging_configuration=" + logging_configuration, + "logging_level=ERROR", + "record=false", + "show_configuration=true", + "required_option_1=value_1", + "required_option_2=value_2", + ], + ) + self.assertEqual( + command_metadata.searchinfo.dispatch_dir, + os.path.dirname(input_header["infoPath"]), + ) self.assertEqual(command_metadata.searchinfo.earliest_time, 0.0) self.assertEqual(command_metadata.searchinfo.latest_time, 0.0) - self.assertEqual(command_metadata.searchinfo.owner, 'admin') - self.assertEqual(command_metadata.searchinfo.raw_args, command_metadata.searchinfo.args) - self.assertEqual(command_metadata.searchinfo.search, - 'A| inputlookup tweets | countmatches fieldname=word_count pattern="\\w+" text record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw') - self.assertEqual(command_metadata.searchinfo.session_key, - '0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^') - self.assertEqual(command_metadata.searchinfo.sid, '1433261372.158') - self.assertEqual(command_metadata.searchinfo.splunk_version, '20150522') - self.assertEqual(command_metadata.searchinfo.splunkd_uri, 'https://127.0.0.1:8089') - self.assertEqual(command_metadata.searchinfo.username, 'admin') + self.assertEqual(command_metadata.searchinfo.owner, "admin") + self.assertEqual( + command_metadata.searchinfo.raw_args, 
command_metadata.searchinfo.args + ) + self.assertEqual( + command_metadata.searchinfo.search, + 'A| inputlookup tweets | countmatches fieldname=word_count pattern="\\w+" text record=t | export add_timestamp=f add_offset=t format=csv segmentation=raw', + ) + self.assertEqual( + command_metadata.searchinfo.session_key, + "0JbG1fJEvXrL6iYZw9y7tmvd6nHjTKj7ggaE7a4Jv5R0UIbeYJ65kThn^3hiNeoqzMT_LOtLpVR3Y8TIJyr5bkHUElMijYZ8l14wU0L4n^Oa5QxepsZNUIIQCBm^", + ) + self.assertEqual(command_metadata.searchinfo.sid, "1433261372.158") + self.assertEqual(command_metadata.searchinfo.splunk_version, "20150522") + self.assertEqual( + command_metadata.searchinfo.splunkd_uri, "https://127.0.0.1:8089" + ) + self.assertEqual(command_metadata.searchinfo.username, "admin") self.assertEqual(command_metadata.searchinfo.maxresultrows, 10) - self.assertEqual(command_metadata.searchinfo.command, 'countmatches') - + self.assertEqual(command_metadata.searchinfo.command, "countmatches") self.maxDiff = None self.assertIsInstance(command.service, Service) - self.assertEqual(command.service.authority, command_metadata.searchinfo.splunkd_uri) + self.assertEqual( + command.service.authority, command_metadata.searchinfo.splunkd_uri + ) self.assertEqual(command.service.token, command_metadata.searchinfo.session_key) self.assertEqual(command.service.namespace.app, command.metadata.searchinfo.app) self.assertIsNone(command.service.namespace.owner) @@ -268,6 +301,7 @@ def test_process_scpv2(self): _package_directory = os.path.dirname(os.path.abspath(__file__)) + class TestSearchCommandService(TestCase): def setUp(self): TestCase.setUp(self) @@ -284,36 +318,54 @@ def test_service_not_exists(self): self.assertIsNone(self.command.service) def test_missing_metadata(self): - with self.assertLogs(self.command.logger, level='WARNING') as log: + with self.assertLogs(self.command.logger, level="WARNING") as log: service = self.command.service self.assertIsNone(service) - self.assertTrue(any("Missing metadata for service creation." in message for message in log.output)) + self.assertTrue( + any( + "Missing metadata for service creation." in message + for message in log.output + ) + ) def test_missing_searchinfo(self): - with self.assertLogs(self.command.logger, level='WARNING') as log: + with self.assertLogs(self.command.logger, level="WARNING") as log: self.command._metadata = ObjectView({}) self.assertIsNone(self.command.service) - self.assertTrue(any("Missing searchinfo in metadata for service creation." in message for message in log.output)) - + self.assertTrue( + any( + "Missing searchinfo in metadata for service creation." 
in message + for message in log.output + ) + ) def test_missing_splunkd_uri(self): - with self.assertLogs(self.command.logger, level='WARNING') as log: + with self.assertLogs(self.command.logger, level="WARNING") as log: metadata = ObjectView({"searchinfo": ObjectView({"splunkd_uri": ""})}) self.command._metadata = metadata self.assertIsNone(self.command.service) - self.assertTrue(any("Incorrect value for Splunkd URI: '' in metadata" in message for message in log.output)) - - + self.assertTrue( + any( + "Incorrect value for Splunkd URI: '' in metadata" in message + for message in log.output + ) + ) def test_service_returns_valid_service_object(self): - metadata = ObjectView({"searchinfo":ObjectView({"splunkd_uri":"https://127.0.0.1:8089", - "session_key":"mock_session_key", - "app":"search", - })}) + metadata = ObjectView( + { + "searchinfo": ObjectView( + { + "splunkd_uri": "https://127.0.0.1:8089", + "session_key": "mock_session_key", + "app": "search", + } + ) + } + ) self.command._metadata = metadata self.assertIsInstance(self.command.service, Service) - if __name__ == "__main__": main() diff --git a/tests/searchcommands/test_streaming_command.py b/tests/searchcommands/test_streaming_command.py index afb2e8caa..e732d3be8 100644 --- a/tests/searchcommands/test_streaming_command.py +++ b/tests/searchcommands/test_streaming_command.py @@ -7,7 +7,6 @@ def test_simple_streaming_command(): @Configuration() class TestStreamingCommand(StreamingCommand): - def stream(self, records): for record in records: record["out_index"] = record["in_index"] @@ -32,7 +31,6 @@ def stream(self, records): def test_field_preservation_negative(): @Configuration() class TestStreamingCommand(StreamingCommand): - def stream(self, records): for index, record in enumerate(records): if index % 2 != 0: @@ -66,7 +64,6 @@ def stream(self, records): def test_field_preservation_positive(): @Configuration() class TestStreamingCommand(StreamingCommand): - def stream(self, records): for index, record in enumerate(records): if index % 2 != 0: diff --git a/tests/searchcommands/test_validators.py b/tests/searchcommands/test_validators.py index 80149aa64..62e6fcc93 100755 --- a/tests/searchcommands/test_validators.py +++ b/tests/searchcommands/test_validators.py @@ -27,20 +27,24 @@ # P2 [ ] TODO: Verify that all format methods produce 'None' when value is None + @pytest.mark.smoke class TestValidators(TestCase): - def setUp(self): TestCase.setUp(self) def test_boolean(self): - truth_values = { - '1': True, '0': False, - 't': True, 'f': False, - 'true': True, 'false': False, - 'y': True, 'n': False, - 'yes': True, 'no': False + "1": True, + "0": False, + "t": True, + "f": False, + "true": True, + "false": False, + "y": True, + "n": False, + "yes": True, + "no": False, } validator = validators.Boolean() @@ -51,10 +55,9 @@ def test_boolean(self): self.assertEqual(validator.__call__(s), truth_values[value]) self.assertIsNone(validator.__call__(None)) - self.assertRaises(ValueError, validator.__call__, 'anything-else') + self.assertRaises(ValueError, validator.__call__, "anything-else") def test_duration(self): - # Duration validator should parse and format time intervals of the form # HH:MM:SS @@ -64,53 +67,52 @@ def test_duration(self): value = str(seconds) self.assertEqual(validator(value), seconds) self.assertEqual(validator(validator.format(seconds)), seconds) - value = '%d:%02d' % (seconds / 60, seconds % 60) + value = "%d:%02d" % (seconds / 60, seconds % 60) self.assertEqual(validator(value), seconds) 
self.assertEqual(validator(validator.format(seconds)), seconds) - value = '%d:%02d:%02d' % (seconds / 3600, (seconds / 60) % 60, seconds % 60) + value = "%d:%02d:%02d" % (seconds / 3600, (seconds / 60) % 60, seconds % 60) self.assertEqual(validator(value), seconds) self.assertEqual(validator(validator.format(seconds)), seconds) - self.assertEqual(validator('230:00:00'), 230 * 60 * 60) - self.assertEqual(validator('23:00:00'), 23 * 60 * 60) - self.assertEqual(validator('00:59:00'), 59 * 60) - self.assertEqual(validator('00:00:59'), 59) - - self.assertEqual(validator.format(230 * 60 * 60), '230:00:00') - self.assertEqual(validator.format(23 * 60 * 60), '23:00:00') - self.assertEqual(validator.format(59 * 60), '00:59:00') - self.assertEqual(validator.format(59), '00:00:59') - - self.assertRaises(ValueError, validator, '-1') - self.assertRaises(ValueError, validator, '00:-1') - self.assertRaises(ValueError, validator, '-1:00') - self.assertRaises(ValueError, validator, '00:00:-1') - self.assertRaises(ValueError, validator, '00:-1:00') - self.assertRaises(ValueError, validator, '-1:00:00') - self.assertRaises(ValueError, validator, '00:00:60') - self.assertRaises(ValueError, validator, '00:60:00') + self.assertEqual(validator("230:00:00"), 230 * 60 * 60) + self.assertEqual(validator("23:00:00"), 23 * 60 * 60) + self.assertEqual(validator("00:59:00"), 59 * 60) + self.assertEqual(validator("00:00:59"), 59) + + self.assertEqual(validator.format(230 * 60 * 60), "230:00:00") + self.assertEqual(validator.format(23 * 60 * 60), "23:00:00") + self.assertEqual(validator.format(59 * 60), "00:59:00") + self.assertEqual(validator.format(59), "00:00:59") + + self.assertRaises(ValueError, validator, "-1") + self.assertRaises(ValueError, validator, "00:-1") + self.assertRaises(ValueError, validator, "-1:00") + self.assertRaises(ValueError, validator, "00:00:-1") + self.assertRaises(ValueError, validator, "00:-1:00") + self.assertRaises(ValueError, validator, "-1:00:00") + self.assertRaises(ValueError, validator, "00:00:60") + self.assertRaises(ValueError, validator, "00:60:00") def test_fieldname(self): pass def test_file(self): - # Create a file on $SPLUNK_HOME/var/run/splunk - file_name = 'TestValidators.test_file' + file_name = "TestValidators.test_file" tempdir = tempfile.gettempdir() full_path = os.path.join(tempdir, file_name) try: - validator = validators.File(mode='w', buffering=4096, directory=tempdir) + validator = validators.File(mode="w", buffering=4096, directory=tempdir) with validator(file_name) as f: - f.write('some text') + f.write("some text") - validator = validators.File(mode='a', directory=tempdir) + validator = validators.File(mode="a", directory=tempdir) with validator(full_path) as f: - f.write('\nmore text') + f.write("\nmore text") # Verify that you can read the file from a file using an absolute or relative path @@ -118,7 +120,7 @@ def test_file(self): for path in file_name, full_path: with validator(path) as f: - self.assertEqual(f.read(), 'some text\nmore text') + self.assertEqual(f.read(), "some text\nmore text") self.assertEqual(f.name, full_path) # Verify that a ValueError is raised, if the file does not exist @@ -132,7 +134,6 @@ def test_file(self): os.unlink(full_path) def test_integer(self): - # Point of interest: # # On all *nix operating systems an int is 32-bits long on 32-bit systems and 64-bits long on 64-bit systems so @@ -216,8 +217,8 @@ def test(float_val): test(0.0001) test(100101.011) test(2 * maxsize) - test('18.32123') - self.assertRaises(ValueError, 
validator.__call__, 'Splunk!') + test("18.32123") + self.assertRaises(ValueError, validator.__call__, "Splunk!") validator = validators.Float(minimum=0) self.assertEqual(validator.__call__(0), 0) @@ -245,42 +246,39 @@ def test(float_val): self.assertRaises(ValueError, validator.__call__, maxsize + 1) def test_list(self): - validator = validators.List() - self.assertEqual(validator.__call__(''), []) - self.assertEqual(validator.__call__('a,b,c'), ['a', 'b', 'c']) + self.assertEqual(validator.__call__(""), []) + self.assertEqual(validator.__call__("a,b,c"), ["a", "b", "c"]) self.assertRaises(ValueError, validator.__call__, '"a,b,c') self.assertEqual(validator.__call__([]), []) self.assertEqual(validator.__call__(None), None) validator = validators.List(validators.Integer(1, 10)) - self.assertEqual(validator.__call__(''), []) - self.assertEqual(validator.__call__('1,2,3'), [1,2,3]) - self.assertRaises(ValueError, validator.__call__, '1,2,0') + self.assertEqual(validator.__call__(""), []) + self.assertEqual(validator.__call__("1,2,3"), [1, 2, 3]) + self.assertRaises(ValueError, validator.__call__, "1,2,0") self.assertEqual(validator.__call__([]), []) self.assertEqual(validator.__call__(None), None) def test_map(self): - validator = validators.Map(a=1, b=2, c=3) - self.assertEqual(validator.__call__('a'), 1) - self.assertEqual(validator.__call__('b'), 2) - self.assertEqual(validator.__call__('c'), 3) - self.assertRaises(ValueError, validator.__call__, 'd') + self.assertEqual(validator.__call__("a"), 1) + self.assertEqual(validator.__call__("b"), 2) + self.assertEqual(validator.__call__("c"), 3) + self.assertRaises(ValueError, validator.__call__, "d") self.assertEqual(validator.__call__(None), None) def test_match(self): - - validator = validators.Match('social security number', r'\d{3}-\d{2}-\d{4}') - self.assertEqual(validator.__call__('123-45-6789'), '123-45-6789') - self.assertRaises(ValueError, validator.__call__, 'foo') + validator = validators.Match("social security number", r"\d{3}-\d{2}-\d{4}") + self.assertEqual(validator.__call__("123-45-6789"), "123-45-6789") + self.assertRaises(ValueError, validator.__call__, "foo") self.assertEqual(validator.__call__(None), None) self.assertEqual(validator.format(None), None) - self.assertEqual(validator.format('123-45-6789'), '123-45-6789') + self.assertEqual(validator.format("123-45-6789"), "123-45-6789") def test_option_name(self): pass @@ -289,19 +287,18 @@ def test_regular_expression(self): validator = validators.RegularExpression() # duck-type: act like it's a regex and allow failure if it isn't one - validator.__call__('a').match('a') + validator.__call__("a").match("a") self.assertEqual(validator.__call__(None), None) - self.assertRaises(ValueError, validator.__call__, '(a') + self.assertRaises(ValueError, validator.__call__, "(a") def test_set(self): - - validator = validators.Set('a', 'b', 'c') - self.assertEqual(validator.__call__('a'), 'a') - self.assertEqual(validator.__call__('b'), 'b') - self.assertEqual(validator.__call__('c'), 'c') + validator = validators.Set("a", "b", "c") + self.assertEqual(validator.__call__("a"), "a") + self.assertEqual(validator.__call__("b"), "b") + self.assertEqual(validator.__call__("c"), "c") self.assertEqual(validator.__call__(None), None) - self.assertRaises(ValueError, validator.__call__, 'd') + self.assertRaises(ValueError, validator.__call__, "d") if __name__ == "__main__": diff --git a/tests/test_all.py b/tests/test_all.py index 55b4d77f5..5c42e18eb 100755 --- a/tests/test_all.py +++ 
b/tests/test_all.py @@ -20,7 +20,7 @@ import unittest os.chdir(os.path.dirname(os.path.abspath(__file__))) -suite = unittest.defaultTestLoader.discover('.') +suite = unittest.defaultTestLoader.discover(".") -if __name__ == '__main__': +if __name__ == "__main__": unittest.TextTestRunner().run(suite) diff --git a/tests/test_app.py b/tests/test_app.py index 35be38146..a6194290a 100755 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -27,7 +27,7 @@ def setUp(self): super().setUp() if self.app is None: for app in self.service.apps: - if app.name.startswith('delete-me'): + if app.name.startswith("delete-me"): self.service.apps.delete(app.name) # Creating apps takes 0.8s, which is too long to wait for # each test in this test suite. Therefore we create one @@ -46,10 +46,10 @@ def tearDown(self): # The rest of this will leave Splunk in a state requiring a restart. # It doesn't actually matter, though. self.service = client.connect(**self.opts.kwargs) - app_name = '' + app_name = "" for app in self.service.apps: app_name = app.name - if app_name.startswith('delete-me'): + if app_name.startswith("delete-me"): self.service.apps.delete(app_name) self.assertEventuallyTrue(lambda: app_name not in self.service.apps) self.clear_restart_message() @@ -57,30 +57,30 @@ def tearDown(self): def test_app_integrity(self): self.check_entity(self.app) self.app.setupInfo - self.app['setupInfo'] + self.app["setupInfo"] def test_disable_enable(self): self.app.disable() self.app.refresh() - self.assertEqual(self.app['disabled'], '1') + self.assertEqual(self.app["disabled"], "1") self.app.enable() self.app.refresh() - self.assertEqual(self.app['disabled'], '0') + self.assertEqual(self.app["disabled"], "0") def test_update(self): kwargs = { - 'author': "Me", - 'description': "Test app description", - 'label': "SDK Test", - 'version': "1.2", - 'visible': True, + "author": "Me", + "description": "Test app description", + "label": "SDK Test", + "version": "1.2", + "visible": True, } self.app.update(**kwargs) self.app.refresh() - self.assertEqual(self.app['author'], "Me") - self.assertEqual(self.app['label'], "SDK Test") - self.assertEqual(self.app['version'], "1.2") - self.assertEqual(self.app['visible'], "1") + self.assertEqual(self.app["author"], "Me") + self.assertEqual(self.app["label"], "SDK Test") + self.assertEqual(self.app["version"], "1.2") + self.assertEqual(self.app["visible"], "1") def test_delete(self): name = testlib.tmpname() @@ -93,7 +93,7 @@ def test_delete(self): def test_package(self): p = self.app.package() self.assertEqual(p.name, self.app_name) - self.assertTrue(p.path.endswith(self.app_name + '.spl')) + self.assertTrue(p.path.endswith(self.app_name + ".spl")) # Assert string due to deprecation of this property in new Splunk versions self.assertIsInstance(p.url, str) @@ -104,4 +104,5 @@ def test_updateInfo(self): if __name__ == "__main__": import unittest + unittest.main() diff --git a/tests/test_binding.py b/tests/test_binding.py index 0db37b01d..4a43c2b05 100755 --- a/tests/test_binding.py +++ b/tests/test_binding.py @@ -100,7 +100,7 @@ def test_read_partial(self): self.assertFalse(response.empty) self.assertEqual(response.read(), txt) self.assertTrue(response.empty) - self.assertEqual(response.read(), b'') + self.assertEqual(response.read(), b"") def test_readable(self): txt = "abcd" @@ -139,65 +139,56 @@ def test_readinto_memoryview(self): class TestUrlEncoded(BindingTestCase): def test_idempotent(self): - a = UrlEncoded('abc') + a = UrlEncoded("abc") self.assertEqual(a, UrlEncoded(a)) def 
test_append(self): - self.assertEqual(UrlEncoded('a') + UrlEncoded('b'), - UrlEncoded('ab')) + self.assertEqual(UrlEncoded("a") + UrlEncoded("b"), UrlEncoded("ab")) def test_append_string(self): - self.assertEqual(UrlEncoded('a') + '%', - UrlEncoded('a%')) + self.assertEqual(UrlEncoded("a") + "%", UrlEncoded("a%")) def test_append_to_string(self): - self.assertEqual('%' + UrlEncoded('a'), - UrlEncoded('%a')) + self.assertEqual("%" + UrlEncoded("a"), UrlEncoded("%a")) def test_interpolation_fails(self): - self.assertRaises(TypeError, lambda: UrlEncoded('%s') % 'boris') + self.assertRaises(TypeError, lambda: UrlEncoded("%s") % "boris") def test_chars(self): - for char, code in [(' ', '%20'), - ('"', '%22'), - ('%', '%25')]: - self.assertEqual(UrlEncoded(char), - UrlEncoded(code, skip_encode=True)) + for char, code in [(" ", "%20"), ('"', "%22"), ("%", "%25")]: + self.assertEqual(UrlEncoded(char), UrlEncoded(code, skip_encode=True)) def test_repr(self): - self.assertEqual(repr(UrlEncoded('% %')), "UrlEncoded('% %')") + self.assertEqual(repr(UrlEncoded("% %")), "UrlEncoded('% %')") class TestAuthority(unittest.TestCase): def test_authority_default(self): - self.assertEqual(binding._authority(), - "https://localhost:8089") + self.assertEqual(binding._authority(), "https://localhost:8089") def test_ipv4_host(self): self.assertEqual( - binding._authority( - host="splunk.utopia.net"), - "https://splunk.utopia.net:8089") + binding._authority(host="splunk.utopia.net"), + "https://splunk.utopia.net:8089", + ) def test_ipv6_host(self): self.assertEqual( - binding._authority( - host="2001:0db8:85a3:0000:0000:8a2e:0370:7334"), - "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8089") + binding._authority(host="2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8089", + ) def test_ipv6_host_enclosed(self): self.assertEqual( - binding._authority( - host="[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"), - "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8089") + binding._authority(host="[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"), + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8089", + ) def test_all_fields(self): self.assertEqual( - binding._authority( - scheme="http", - host="splunk.utopia.net", - port="471"), - "http://splunk.utopia.net:471") + binding._authority(scheme="http", host="splunk.utopia.net", port="471"), + "http://splunk.utopia.net:471", + ) class TestUserManipulation(BindingTestCase): @@ -223,15 +214,18 @@ def tearDown(self): raise def test_user_without_role_fails(self): - self.assertRaises(binding.HTTPError, - self.context.post, - PATH_USERS, name=self.username, - password=self.password) + self.assertRaises( + binding.HTTPError, + self.context.post, + PATH_USERS, + name=self.username, + password=self.password, + ) def test_create_user(self): response = self.context.post( - PATH_USERS, name=self.username, - password=self.password, roles=self.roles) + PATH_USERS, name=self.username, password=self.password, roles=self.roles + ) self.assertEqual(response.status, 201) response = self.context.get(PATH_USERS + self.username) @@ -246,7 +240,8 @@ def test_update_user(self): roles=self.roles, defaultApp="search", realname="Renzo", - email="email.me@now.com") + email="email.me@now.com", + ) self.assertEqual(response.status, 200) response = self.context.get(PATH_USERS + self.username) @@ -266,12 +261,13 @@ def test_post_with_body_behaves(self): self.assertEqual(response.status, 200) def test_post_with_get_arguments_to_receivers_stream(self): - text = 
'Hello, world!' + text = "Hello, world!" response = self.context.post( - '/services/receivers/simple', - headers=[('x-splunk-input-mode', 'streaming')], - source='sdk', sourcetype='sdk_test', - body=text + "/services/receivers/simple", + headers=[("x-splunk-input-mode", "streaming")], + source="sdk", + sourcetype="sdk_test", + body=text, ) self.assertEqual(response.status, 200) @@ -279,12 +275,18 @@ def test_post_with_get_arguments_to_receivers_stream(self): class TestSocket(BindingTestCase): def test_socket(self): socket = self.context.connect() - socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) - socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) - socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) - socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) - socket.write("\r\n".encode('utf-8')) + socket.write( + ( + f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n" + ).encode("utf-8") + ) + socket.write( + (f"Host: {self.context.host}:{self.context.port}\r\n").encode("utf-8") + ) + socket.write("Accept-Encoding: identity\r\n".encode("utf-8")) + socket.write((f"Authorization: {self.context.token}\r\n").encode("utf-8")) + socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode("utf-8")) + socket.write("\r\n".encode("utf-8")) socket.close() # Sockets take bytes not strings @@ -311,7 +313,7 @@ def test_socket_gethostbyname(self): class TestUnicodeConnect(BindingTestCase): def test_unicode_connect(self): opts = self.opts.kwargs.copy() - opts['host'] = str(opts['host']) + opts["host"] = str(opts["host"]) context = binding.connect(**opts) # Just check to make sure the service is alive response = context.get("/services") @@ -330,16 +332,17 @@ def test_without_autologin(self): self.context.autologin = False self.assertEqual(self.context.get("/services").status, 200) self.context.logout() - self.assertRaises(AuthenticationError, - self.context.get, "/services") + self.assertRaises(AuthenticationError, self.context.get, "/services") class TestAbspath(BindingTestCase): def setUp(self): BindingTestCase.setUp(self) self.kwargs = self.opts.kwargs.copy() - if 'app' in self.kwargs: del self.kwargs['app'] - if 'owner' in self.kwargs: del self.kwargs['owner'] + if "app" in self.kwargs: + del self.kwargs["app"] + if "owner" in self.kwargs: + del self.kwargs["owner"] def test_default(self): path = self.context._abspath("foo", owner=None, app=None) @@ -377,14 +380,16 @@ def test_sharing_global(self): self.assertEqual(path, "/servicesNS/nobody/MyApp/foo") def test_sharing_system(self): - path = self.context._abspath("foo bar", owner="me", app="MyApp", sharing="system") + path = self.context._abspath( + "foo bar", owner="me", app="MyApp", sharing="system" + ) self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/system/foo%20bar") def test_url_forbidden_characters(self): - path = self.context._abspath('/a/b c/d') + path = self.context._abspath("/a/b c/d") self.assertTrue(isinstance(path, UrlEncoded)) - self.assertEqual(path, '/a/b%20c/d') + self.assertEqual(path, "/a/b%20c/d") def test_context_defaults(self): context = binding.connect(**self.kwargs) @@ -412,28 +417,30 @@ def test_context_with_both(self): def test_context_with_user_sharing(self): context = binding.connect( - owner="me", app="MyApp", sharing="user", **self.kwargs) + owner="me", app="MyApp", sharing="user", 
**self.kwargs + ) path = context._abspath("foo") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/me/MyApp/foo") def test_context_with_app_sharing(self): - context = binding.connect( - owner="me", app="MyApp", sharing="app", **self.kwargs) + context = binding.connect(owner="me", app="MyApp", sharing="app", **self.kwargs) path = context._abspath("foo") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/MyApp/foo") def test_context_with_global_sharing(self): context = binding.connect( - owner="me", app="MyApp", sharing="global", **self.kwargs) + owner="me", app="MyApp", sharing="global", **self.kwargs + ) path = context._abspath("foo") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/MyApp/foo") def test_context_with_system_sharing(self): context = binding.connect( - owner="me", app="MyApp", sharing="system", **self.kwargs) + owner="me", app="MyApp", sharing="system", **self.kwargs + ) path = context._abspath("foo") self.assertTrue(isinstance(path, UrlEncoded)) self.assertEqual(path, "/servicesNS/nobody/system/foo") @@ -449,53 +456,53 @@ def test_context_with_owner_as_email(self): # An urllib2 based HTTP request handler, used to test the binding layers # support for pluggable request handlers. def urllib2_handler(url, message, **kwargs): - method = message['method'].lower() - data = message.get('body', b"") if method == 'post' else None - headers = dict(message.get('headers', [])) + method = message["method"].lower() + data = message.get("body", b"") if method == "post" else None + headers = dict(message.get("headers", [])) req = Request(url, data, headers) try: response = urlopen(req, context=ssl._create_unverified_context()) # nosemgrep except HTTPError as response: pass # Propagate HTTP errors via the returned response message return { - 'status': response.code, - 'reason': response.msg, - 'headers': dict(response.info()), - 'body': BytesIO(response.read()) + "status": response.code, + "reason": response.msg, + "headers": dict(response.info()), + "body": BytesIO(response.read()), } def isatom(body): """Answers if the given response body looks like ATOM.""" root = XML(body) - return \ - root.tag == XNAME_FEED and \ - root.find(XNAME_AUTHOR) is not None and \ - root.find(XNAME_ID) is not None and \ - root.find(XNAME_TITLE) is not None + return ( + root.tag == XNAME_FEED + and root.find(XNAME_AUTHOR) is not None + and root.find(XNAME_ID) is not None + and root.find(XNAME_TITLE) is not None + ) class TestPluggableHTTP(testlib.SDKTestCase): # Verify pluggable HTTP reqeust handlers. 
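Any callable with the contract shown in urllib2_handler above can be plugged in
through binding.connect(handler=...). A minimal hedged sketch of that contract,
with opts standing in for real connection arguments:

    import logging

    from splunklib import binding

    def logging_handler(url, message, **kwargs):
        # message carries "method", "headers" and, for POST requests, "body";
        # the return value must expose status, reason, headers and a
        # file-like body, as urllib2_handler above demonstrates.
        logging.debug("%s %s", message["method"], url)
        # Delegate the actual request to the SDK's default handler.
        return binding.handler()(url, message, **kwargs)

    # context = binding.connect(handler=logging_handler, **opts)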
def test_handlers(self): - paths = ["/services", "authentication/users", - "search/jobs"] - handlers = [binding.handler(), # default handler - urllib2_handler] + paths = ["/services", "authentication/users", "search/jobs"] + handlers = [ + binding.handler(), # default handler + urllib2_handler, + ] for handler in handlers: logging.debug("Connecting with handler %s", handler) - context = binding.connect( - handler=handler, - **self.opts.kwargs) + context = binding.connect(handler=handler, **self.opts.kwargs) for path in paths: body = context.get(path).body.read() self.assertTrue(isatom(body)) def urllib2_insert_cookie_handler(url, message, **kwargs): - method = message['method'].lower() - data = message.get('body', b"") if method == 'post' else None - headers = dict(message.get('headers', [])) + method = message["method"].lower() + data = message.get("body", b"") if method == "post" else None + headers = dict(message.get("headers", [])) req = Request(url, data, headers) try: response = urlopen(req, context=ssl._create_unverified_context()) # nosemgrep @@ -506,26 +513,33 @@ def urllib2_insert_cookie_handler(url, message, **kwargs): # An example is "sticky session"/"insert cookie" persistence # of a load balancer for a SHC. header_list = list(response.info().items()) - header_list.append(("Set-Cookie", "BIGipServer_splunk-shc-8089=1234567890.12345.0000; path=/; Httponly; Secure")) + header_list.append( + ( + "Set-Cookie", + "BIGipServer_splunk-shc-8089=1234567890.12345.0000; path=/; Httponly; Secure", + ) + ) header_list.append(("Set-Cookie", "home_made=yummy")) return { - 'status': response.code, - 'reason': response.msg, - 'headers': header_list, - 'body': BytesIO(response.read()) + "status": response.code, + "reason": response.msg, + "headers": header_list, + "body": BytesIO(response.read()), } class TestCookiePersistence(testlib.SDKTestCase): # Verify persistence of 3rd party inserted cookies. 
def test_3rdPartyInsertedCookiePersistence(self): - paths = ["/services", "authentication/users", - "search/jobs"] - logging.debug("Connecting with urllib2_insert_cookie_handler %s", urllib2_insert_cookie_handler) + paths = ["/services", "authentication/users", "search/jobs"] + logging.debug( + "Connecting with urllib2_insert_cookie_handler %s", + urllib2_insert_cookie_handler, + ) context = binding.connect( - handler=urllib2_insert_cookie_handler, - **self.opts.kwargs) + handler=urllib2_insert_cookie_handler, **self.opts.kwargs + ) persisted_cookies = context.get_cookies() @@ -536,8 +550,10 @@ def test_3rdPartyInsertedCookiePersistence(self): break self.assertEqual(splunk_token_found, True) - self.assertEqual(persisted_cookies['BIGipServer_splunk-shc-8089'], "1234567890.12345.0000") - self.assertEqual(persisted_cookies['home_made'], "yummy") + self.assertEqual( + persisted_cookies["BIGipServer_splunk-shc-8089"], "1234567890.12345.0000" + ) + self.assertEqual(persisted_cookies["home_made"], "yummy") @pytest.mark.smoke @@ -548,12 +564,9 @@ def test_logout(self): self.context.logout() self.assertEqual(self.context.token, binding._NoAuthenticationToken) self.assertEqual(self.context.get_cookies(), {}) - self.assertRaises(AuthenticationError, - self.context.get, "/services") - self.assertRaises(AuthenticationError, - self.context.post, "/services") - self.assertRaises(AuthenticationError, - self.context.delete, "/services") + self.assertRaises(AuthenticationError, self.context.get, "/services") + self.assertRaises(AuthenticationError, self.context.post, "/services") + self.assertRaises(AuthenticationError, self.context.delete, "/services") self.context.login() response = self.context.get("/services") self.assertEqual(response.status, 200) @@ -566,18 +579,22 @@ def setUp(self): # Skip these tests if running below Splunk 6.2, cookie-auth didn't exist before from splunklib import client + service = client.Service(**self.opts.kwargs) # TODO: Workaround the fact that skipTest is not defined by unittest2.TestCase service.login() splver = service.splunk_version if splver[:2] < (6, 2): - self.skipTest("Skipping cookie-auth tests, running in %d.%d.%d, this feature was added in 6.2+" % splver) + self.skipTest( + "Skipping cookie-auth tests, running in %d.%d.%d, this feature was added in 6.2+" + % splver + ) - if getattr(unittest.TestCase, 'assertIsNotNone', None) is None: + if getattr(unittest.TestCase, "assertIsNotNone", None) is None: def assertIsNotNone(self, obj, msg=None): if obj is None: - raise self.failureException(msg or '%r is not None' % obj) + raise self.failureException(msg or "%r is not None" % obj) @pytest.mark.smoke def test_cookie_in_auth_headers(self): @@ -612,8 +629,7 @@ def test_cookie_without_autologin(self): self.assertTrue(self.context.has_cookies()) self.context.logout() self.assertFalse(self.context.has_cookies()) - self.assertRaises(AuthenticationError, - self.context.get, "/services") + self.assertRaises(AuthenticationError, self.context.get, "/services") @pytest.mark.smoke def test_got_updated_cookie_with_get(self): @@ -631,7 +647,9 @@ def test_got_updated_cookie_with_get(self): self.assertEqual(len(old_cookies), 1) self.assertTrue(len(list(new_cookies.values())), 1) self.assertEqual(old_cookies, new_cookies) - self.assertEqual(list(new_cookies.values())[0], list(old_cookies.values())[0]) + self.assertEqual( + list(new_cookies.values())[0], list(old_cookies.values())[0] + ) self.assertTrue(found) @pytest.mark.smoke @@ -658,8 +676,8 @@ def 
test_login_with_multiple_cookies(self): new_context.get_cookies()[key] = value self.assertEqual(len(new_context.get_cookies()), 2) - self.assertTrue('bad' in list(new_context.get_cookies().keys())) - self.assertTrue('cookie' in list(new_context.get_cookies().values())) + self.assertTrue("bad" in list(new_context.get_cookies().keys())) + self.assertTrue("cookie" in list(new_context.get_cookies().values())) for k, v in self.context.get_cookies().items(): self.assertEqual(new_context.get_cookies()[k], v) @@ -668,10 +686,7 @@ def test_login_with_multiple_cookies(self): @pytest.mark.smoke def test_login_fails_without_cookie_or_token(self): - opts = { - 'host': self.opts.kwargs['host'], - 'port': self.opts.kwargs['port'] - } + opts = {"host": self.opts.kwargs["host"], "port": self.opts.kwargs["port"]} try: binding.connect(**opts) self.fail() @@ -682,71 +697,80 @@ def test_login_fails_without_cookie_or_token(self): class TestNamespace(unittest.TestCase): def test_namespace(self): tests = [ - ({}, - {'sharing': None, 'owner': None, 'app': None}), - - ({'owner': "Bob"}, - {'sharing': None, 'owner': "Bob", 'app': None}), - - ({'app': "search"}, - {'sharing': None, 'owner': None, 'app': "search"}), - - ({'owner': "Bob", 'app': "search"}, - {'sharing': None, 'owner': "Bob", 'app': "search"}), - - ({'sharing': "user", 'owner': "Bob@bob.com"}, - {'sharing': "user", 'owner': "Bob@bob.com", 'app': None}), - - ({'sharing': "user"}, - {'sharing': "user", 'owner': None, 'app': None}), - - ({'sharing': "user", 'owner': "Bob"}, - {'sharing': "user", 'owner': "Bob", 'app': None}), - - ({'sharing': "user", 'app': "search"}, - {'sharing': "user", 'owner': None, 'app': "search"}), - - ({'sharing': "user", 'owner': "Bob", 'app': "search"}, - {'sharing': "user", 'owner': "Bob", 'app': "search"}), - - ({'sharing': "app"}, - {'sharing': "app", 'owner': "nobody", 'app': None}), - - ({'sharing': "app", 'owner': "Bob"}, - {'sharing': "app", 'owner': "nobody", 'app': None}), - - ({'sharing': "app", 'app': "search"}, - {'sharing': "app", 'owner': "nobody", 'app': "search"}), - - ({'sharing': "app", 'owner': "Bob", 'app': "search"}, - {'sharing': "app", 'owner': "nobody", 'app': "search"}), - - ({'sharing': "global"}, - {'sharing': "global", 'owner': "nobody", 'app': None}), - - ({'sharing': "global", 'owner': "Bob"}, - {'sharing': "global", 'owner': "nobody", 'app': None}), - - ({'sharing': "global", 'app': "search"}, - {'sharing': "global", 'owner': "nobody", 'app': "search"}), - - ({'sharing': "global", 'owner': "Bob", 'app': "search"}, - {'sharing': "global", 'owner': "nobody", 'app': "search"}), - - ({'sharing': "system"}, - {'sharing': "system", 'owner': "nobody", 'app': "system"}), - - ({'sharing': "system", 'owner': "Bob"}, - {'sharing': "system", 'owner': "nobody", 'app': "system"}), - - ({'sharing': "system", 'app': "search"}, - {'sharing': "system", 'owner': "nobody", 'app': "system"}), - - ({'sharing': "system", 'owner': "Bob", 'app': "search"}, - {'sharing': "system", 'owner': "nobody", 'app': "system"}), - - ({'sharing': 'user', 'owner': '-', 'app': '-'}, - {'sharing': 'user', 'owner': '-', 'app': '-'})] + ({}, {"sharing": None, "owner": None, "app": None}), + ({"owner": "Bob"}, {"sharing": None, "owner": "Bob", "app": None}), + ({"app": "search"}, {"sharing": None, "owner": None, "app": "search"}), + ( + {"owner": "Bob", "app": "search"}, + {"sharing": None, "owner": "Bob", "app": "search"}, + ), + ( + {"sharing": "user", "owner": "Bob@bob.com"}, + {"sharing": "user", "owner": "Bob@bob.com", "app": None}, + 
), + ({"sharing": "user"}, {"sharing": "user", "owner": None, "app": None}), + ( + {"sharing": "user", "owner": "Bob"}, + {"sharing": "user", "owner": "Bob", "app": None}, + ), + ( + {"sharing": "user", "app": "search"}, + {"sharing": "user", "owner": None, "app": "search"}, + ), + ( + {"sharing": "user", "owner": "Bob", "app": "search"}, + {"sharing": "user", "owner": "Bob", "app": "search"}, + ), + ({"sharing": "app"}, {"sharing": "app", "owner": "nobody", "app": None}), + ( + {"sharing": "app", "owner": "Bob"}, + {"sharing": "app", "owner": "nobody", "app": None}, + ), + ( + {"sharing": "app", "app": "search"}, + {"sharing": "app", "owner": "nobody", "app": "search"}, + ), + ( + {"sharing": "app", "owner": "Bob", "app": "search"}, + {"sharing": "app", "owner": "nobody", "app": "search"}, + ), + ( + {"sharing": "global"}, + {"sharing": "global", "owner": "nobody", "app": None}, + ), + ( + {"sharing": "global", "owner": "Bob"}, + {"sharing": "global", "owner": "nobody", "app": None}, + ), + ( + {"sharing": "global", "app": "search"}, + {"sharing": "global", "owner": "nobody", "app": "search"}, + ), + ( + {"sharing": "global", "owner": "Bob", "app": "search"}, + {"sharing": "global", "owner": "nobody", "app": "search"}, + ), + ( + {"sharing": "system"}, + {"sharing": "system", "owner": "nobody", "app": "system"}, + ), + ( + {"sharing": "system", "owner": "Bob"}, + {"sharing": "system", "owner": "nobody", "app": "system"}, + ), + ( + {"sharing": "system", "app": "search"}, + {"sharing": "system", "owner": "nobody", "app": "system"}, + ), + ( + {"sharing": "system", "owner": "Bob", "app": "search"}, + {"sharing": "system", "owner": "nobody", "app": "system"}, + ), + ( + {"sharing": "user", "owner": "-", "app": "-"}, + {"sharing": "user", "owner": "-", "app": "-"}, + ), + ] for kwargs, expected in tests: namespace = binding.namespace(**kwargs) @@ -768,12 +792,14 @@ def setUp(self): self.context = binding.connect(**opts) from splunklib import client + service = client.Service(**opts) - if getattr(unittest.TestCase, 'assertIsNotNone', None) is None: + if getattr(unittest.TestCase, "assertIsNotNone", None) is None: + def assertIsNotNone(self, obj, msg=None): if obj is None: - raise self.failureException(msg or '%r is not None' % obj) + raise self.failureException(msg or "%r is not None" % obj) def test_basic_in_auth_headers(self): self.assertIsNotNone(self.context._auth_headers) @@ -799,19 +825,25 @@ def test_preexisting_token(self): self.assertEqual(response.status, 200) socket = newContext.connect() - socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) - socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) - socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) - socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) - socket.write("\r\n".encode('utf-8')) + socket.write( + ( + f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n" + ).encode("utf-8") + ) + socket.write( + (f"Host: {self.context.host}:{self.context.port}\r\n").encode("utf-8") + ) + socket.write("Accept-Encoding: identity\r\n".encode("utf-8")) + socket.write((f"Authorization: {self.context.token}\r\n").encode("utf-8")) + socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode("utf-8")) + socket.write("\r\n".encode("utf-8")) socket.close() def test_preexisting_token_sans_splunk(self): token = self.context.token - if token.startswith('Splunk 
'): - token = token.split(' ', 1)[1] - self.assertFalse(token.startswith('Splunk ')) + if token.startswith("Splunk "): + token = token.split(" ", 1)[1] + self.assertFalse(token.startswith("Splunk ")) else: self.fail('Token did not start with "Splunk ".') opts = self.opts.kwargs.copy() @@ -824,69 +856,97 @@ def test_preexisting_token_sans_splunk(self): self.assertEqual(response.status, 200) socket = newContext.connect() - socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) - socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) - socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) - socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) - socket.write("\r\n".encode('utf-8')) + socket.write( + ( + f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n" + ).encode("utf-8") + ) + socket.write( + (f"Host: {self.context.host}:{self.context.port}\r\n").encode("utf-8") + ) + socket.write("Accept-Encoding: identity\r\n".encode("utf-8")) + socket.write((f"Authorization: {self.context.token}\r\n").encode("utf-8")) + socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode("utf-8")) + socket.write("\r\n".encode("utf-8")) socket.close() def test_connect_with_preexisting_token_sans_user_and_pass(self): token = self.context.token opts = self.opts.kwargs.copy() - del opts['username'] - del opts['password'] + del opts["username"] + del opts["password"] opts["token"] = token newContext = binding.connect(**opts) - response = newContext.get('/services') + response = newContext.get("/services") self.assertEqual(response.status, 200) socket = newContext.connect() - socket.write((f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n").encode('utf-8')) - socket.write((f"Host: {self.context.host}:{self.context.port}\r\n").encode('utf-8')) - socket.write("Accept-Encoding: identity\r\n".encode('utf-8')) - socket.write((f"Authorization: {self.context.token}\r\n").encode('utf-8')) - socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode('utf-8')) - socket.write("\r\n".encode('utf-8')) + socket.write( + ( + f"POST {self.context._abspath('some/path/to/post/to')} HTTP/1.1\r\n" + ).encode("utf-8") + ) + socket.write( + (f"Host: {self.context.host}:{self.context.port}\r\n").encode("utf-8") + ) + socket.write("Accept-Encoding: identity\r\n".encode("utf-8")) + socket.write((f"Authorization: {self.context.token}\r\n").encode("utf-8")) + socket.write("X-Splunk-Input-Mode: Streaming\r\n".encode("utf-8")) + socket.write("\r\n".encode("utf-8")) socket.close() class TestPostWithBodyParam(unittest.TestCase): - def test_post(self): def handler(url, message, **kwargs): assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" assert message["body"] == b"testkey=testvalue" - return splunklib.data.Record({ - "status": 200, - "headers": [], - }) + return splunklib.data.Record( + { + "status": 200, + "headers": [], + } + ) ctx = binding.Context(handler=handler) - ctx.post("foo/bar", owner="testowner", app="testapp", body={"testkey": "testvalue"}) + ctx.post( + "foo/bar", owner="testowner", app="testapp", body={"testkey": "testvalue"} + ) def test_post_with_params_and_body(self): def handler(url, message, **kwargs): - assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar?extrakey=extraval" + assert ( + url + == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar?extrakey=extraval" + ) 
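            # Hedged aside: Context.post splits keyword arguments from body= --
            # extra kwargs become the query string asserted above, while the
            # body dict is urlencoded into the payload asserted next, roughly:
            #
            #     from urllib.parse import urlencode
            #     urlencode({"testkey": "testvalue"}).encode()  # b'testkey=testvalue'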
assert message["body"] == b"testkey=testvalue" - return splunklib.data.Record({ - "status": 200, - "headers": [], - }) + return splunklib.data.Record( + { + "status": 200, + "headers": [], + } + ) ctx = binding.Context(handler=handler) - ctx.post("foo/bar", extrakey="extraval", owner="testowner", app="testapp", body={"testkey": "testvalue"}) + ctx.post( + "foo/bar", + extrakey="extraval", + owner="testowner", + app="testapp", + body={"testkey": "testvalue"}, + ) def test_post_with_params_and_no_body(self): def handler(url, message, **kwargs): assert url == "https://localhost:8089/servicesNS/testowner/testapp/foo/bar" assert message["body"] == b"extrakey=extraval" - return splunklib.data.Record({ - "status": 200, - "headers": [], - }) + return splunklib.data.Record( + { + "status": 200, + "headers": [], + } + ) ctx = binding.Context(handler=handler) ctx.post("foo/bar", extrakey="extraval", owner="testowner", app="testapp") @@ -908,16 +968,18 @@ def __init__(self, port=9093, **handlers): methods = {"do_" + k: _wrap_handler(v) for (k, v) in handlers.items()} def init(handler_self, socket, address, server): - BaseHTTPServer.BaseHTTPRequestHandler.__init__(handler_self, socket, address, server) + BaseHTTPServer.BaseHTTPRequestHandler.__init__( + handler_self, socket, address, server + ) def log(*args): # To silence server access logs pass methods["__init__"] = init methods["log_message"] = log - Handler = type("Handler", - (BaseHTTPServer.BaseHTTPRequestHandler, object), - methods) + Handler = type( + "Handler", (BaseHTTPServer.BaseHTTPRequestHandler, object), methods + ) self._svr = BaseHTTPServer.HTTPServer(("localhost", port), Handler) def run(): @@ -936,38 +998,43 @@ def __exit__(self, typ, value, traceback): class TestFullPost(unittest.TestCase): - def test_post_with_body_urlencoded(self): def check_response(handler): - length = int(handler.headers.get('content-length', 0)) + length = int(handler.headers.get("content-length", 0)) body = handler.rfile.read(length) - assert body.decode('utf-8') == "foo=bar" + assert body.decode("utf-8") == "foo=bar" with MockServer(POST=check_response): - ctx = binding.connect(port=9093, scheme='http', token="waffle") + ctx = binding.connect(port=9093, scheme="http", token="waffle") ctx.post("/", foo="bar") def test_post_with_body_string(self): def check_response(handler): - length = int(handler.headers.get('content-length', 0)) + length = int(handler.headers.get("content-length", 0)) body = handler.rfile.read(length) - assert handler.headers['content-type'] == 'application/json' + assert handler.headers["content-type"] == "application/json" assert json.loads(body)["baz"] == "baf" with MockServer(POST=check_response): - ctx = binding.connect(port=9093, scheme='http', token="waffle", - headers=[("Content-Type", "application/json")]) + ctx = binding.connect( + port=9093, + scheme="http", + token="waffle", + headers=[("Content-Type", "application/json")], + ) ctx.post("/", foo="bar", body='{"baz": "baf"}') def test_post_with_body_dict(self): def check_response(handler): - length = int(handler.headers.get('content-length', 0)) + length = int(handler.headers.get("content-length", 0)) body = handler.rfile.read(length) - assert handler.headers['content-type'] == 'application/x-www-form-urlencoded' - assert ensure_str(body) in ['baz=baf&hep=cat', 'hep=cat&baz=baf'] + assert ( + handler.headers["content-type"] == "application/x-www-form-urlencoded" + ) + assert ensure_str(body) in ["baz=baf&hep=cat", "hep=cat&baz=baf"] with MockServer(POST=check_response): - ctx = 
binding.connect(port=9093, scheme='http', token="waffle") + ctx = binding.connect(port=9093, scheme="http", token="waffle") ctx.post("/", foo="bar", body={"baz": "baf", "hep": "cat"}) diff --git a/tests/test_collection.py b/tests/test_collection.py index ec641a6d6..5055ea593 100755 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -22,31 +22,36 @@ from splunklib import client collections = [ - 'apps', - 'event_types', - 'indexes', - 'inputs', - 'jobs', - 'loggers', - 'messages', - 'roles', - 'users' + "apps", + "event_types", + "indexes", + "inputs", + "jobs", + "loggers", + "messages", + "roles", + "users", ] -expected_access_keys = set(['sharing', 'app', 'owner']) -expected_fields_keys = set(['required', 'optional', 'wildcard']) +expected_access_keys = set(["sharing", "app", "owner"]) +expected_fields_keys = set(["required", "optional", "wildcard"]) class CollectionTestCase(testlib.SDKTestCase): def setUp(self): super().setUp() - if self.service.splunk_version[0] >= 5 and 'modular_input_kinds' not in collections: - collections.append('modular_input_kinds') # Not supported before Splunk 5.0 + if ( + self.service.splunk_version[0] >= 5 + and "modular_input_kinds" not in collections + ): + collections.append("modular_input_kinds") # Not supported before Splunk 5.0 else: - logging.info("Skipping modular_input_kinds; not supported by Splunk %s" % \ - '.'.join(str(x) for x in self.service.splunk_version)) + logging.info( + "Skipping modular_input_kinds; not supported by Splunk %s" + % ".".join(str(x) for x in self.service.splunk_version) + ) for saved_search in self.service.saved_searches: - if saved_search.name.startswith('delete-me'): + if saved_search.name.startswith("delete-me"): try: for job in saved_search.history(): job.cancel() @@ -59,18 +64,22 @@ def test_metadata(self): self.assertRaises(client.NotSupportedError, self.service.loggers.itemmeta) self.assertRaises(TypeError, self.service.inputs.itemmeta) for c in collections: - if c in ['jobs', 'loggers', 'inputs', 'modular_input_kinds']: + if c in ["jobs", "loggers", "inputs", "modular_input_kinds"]: continue coll = getattr(self.service, c) metadata = coll.itemmeta() found_access_keys = set(metadata.access.keys()) found_fields_keys = set(metadata.fields.keys()) - self.assertTrue(found_access_keys >= expected_access_keys, - msg='metadata.access is missing keys on ' + \ - f'{coll} (found: {found_access_keys}, expected: {expected_access_keys})') - self.assertTrue(found_fields_keys >= expected_fields_keys, - msg='metadata.fields is missing keys on ' + \ - f'{coll} (found: {found_fields_keys}, expected: {expected_fields_keys})') + self.assertTrue( + found_access_keys >= expected_access_keys, + msg="metadata.access is missing keys on " + + f"{coll} (found: {found_access_keys}, expected: {expected_access_keys})", + ) + self.assertTrue( + found_fields_keys >= expected_fields_keys, + msg="metadata.fields is missing keys on " + + f"{coll} (found: {found_fields_keys}, expected: {expected_fields_keys})", + ) def test_list(self): for coll_name in collections: @@ -79,8 +88,11 @@ def test_list(self): if len(expected) == 0: logging.debug(f"No entities in collection {coll_name}; skipping test.") found = [ent.name for ent in coll.list()][:10] - self.assertEqual(expected, found, - msg=f'on {coll_name} (expected: {expected}, found: {found})') + self.assertEqual( + expected, + found, + msg=f"on {coll_name} (expected: {expected}, found: {found})", + ) def test_list_with_count(self): N = 5 @@ -89,18 +101,25 @@ def test_list_with_count(self): 
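The itemmeta contract exercised by test_metadata above, restated as a hedged
standalone sketch; service is assumed to be an already-connected client.Service:

    meta = service.apps.itemmeta()  # any collection except jobs, loggers,
                                    # inputs, and modular_input_kinds
    assert {"sharing", "app", "owner"} <= set(meta.access.keys())
    assert {"required", "optional", "wildcard"} <= set(meta.fields.keys())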
            expected = [ent.name for ent in coll.list(count=N + 5)][:N]
            N = len(expected)  # in case there are <N+5 entities

        result = data.load("<a></a>")
-        self.assertEqual(result, {'a': None})
+        self.assertEqual(result, {"a": None})

        result = data.load("<a>1</a>")
-        self.assertEqual(result, {'a': "1"})
+        self.assertEqual(result, {"a": "1"})

        result = data.load("<a><b></b></a>")
-        self.assertEqual(result, {'a': {'b': None}})
+        self.assertEqual(result, {"a": {"b": None}})

        result = data.load("<a><b>1</b></a>")
-        self.assertEqual(result, {'a': {'b': '1'}})
+        self.assertEqual(result, {"a": {"b": "1"}})

        result = data.load("<a><b></b><b></b></a>")
-        self.assertEqual(result, {'a': {'b': [None, None]}})
+        self.assertEqual(result, {"a": {"b": [None, None]}})

        result = data.load("<a><b>1</b><b>2</b></a>")
-        self.assertEqual(result, {'a': {'b': ['1', '2']}})
+        self.assertEqual(result, {"a": {"b": ["1", "2"]}})

        result = data.load("<a><b></b><c></c></a>")
-        self.assertEqual(result, {'a': {'b': None, 'c': None}})
+        self.assertEqual(result, {"a": {"b": None, "c": None}})

        result = data.load("<a><b>1</b><c>2</c></a>")
-        self.assertEqual(result, {'a': {'b': '1', 'c': '2'}})
+        self.assertEqual(result, {"a": {"b": "1", "c": "2"}})

        result = data.load("<a><b><c>1</c></b></a>")
-        self.assertEqual(result, {'a': {'b': {'c': '1'}}})
+        self.assertEqual(result, {"a": {"b": {"c": "1"}}})

        result = data.load("<a><b><c>1</c></b><b>2</b></a>")
-        self.assertEqual(result, {'a': {'b': [{'c': '1'}, '2']}})
+        self.assertEqual(result, {"a": {"b": [{"c": "1"}, "2"]}})

-        result = data.load('<e><a1>alpha</a1><a1>beta</a1></e>')
-        self.assertEqual(result, {'e': {'a1': ['alpha', 'beta']}})
+        result = data.load("<e><a1>alpha</a1><a1>beta</a1></e>")
+        self.assertEqual(result, {"e": {"a1": ["alpha", "beta"]}})

        result = data.load("<e a1='v1'><a1>v2</a1></e>")
-        self.assertEqual(result, {'e': {'a1': ['v2', 'v1']}})
+        self.assertEqual(result, {"e": {"a1": ["v2", "v1"]}})

    def test_attrs(self):
        result = data.load("<e a1='v1'/>")
-        self.assertEqual(result, {'e': {'a1': 'v1'}})
+        self.assertEqual(result, {"e": {"a1": "v1"}})

        result = data.load("<e a1='v1' a2='v2'/>")
-        self.assertEqual(result, {'e': {'a1': 'v1', 'a2': 'v2'}})
+        self.assertEqual(result, {"e": {"a1": "v1", "a2": "v2"}})

        result = data.load("<e a1='v1'>v2</e>")
-        self.assertEqual(result, {'e': {'$text': 'v2', 'a1': 'v1'}})
+        self.assertEqual(result, {"e": {"$text": "v2", "a1": "v1"}})

        result = data.load("<e a1='v1'><b>2</b></e>")
-        self.assertEqual(result, {'e': {'a1': 'v1', 'b': '2'}})
+        self.assertEqual(result, {"e": {"a1": "v1", "b": "2"}})

        result = data.load("<e a1='v1'>v2<b>bv2</b></e>")
-        self.assertEqual(result, {'e': {'a1': 'v1', 'b': 'bv2'}})
+        self.assertEqual(result, {"e": {"a1": "v1", "b": "bv2"}})

        result = data.load("<e a1='v1'><a1>v2</a1></e>")
-        self.assertEqual(result, {'e': {'a1': ['v2', 'v1']}})
+        self.assertEqual(result, {"e": {"a1": ["v2", "v1"]}})

        result = data.load("<e1 a1='v1'><e2 a1='v1'>v2</e2></e1>")
-        self.assertEqual(result,
-            {'e1': {'a1': 'v1', 'e2': {'$text': 'v2', 'a1': 'v1'}}})
+        self.assertEqual(
+            result, {"e1": {"a1": "v1", "e2": {"$text": "v2", "a1": "v1"}}}
+        )

    def test_real(self):
        """Test some real Splunk response examples."""
        testpath = path.dirname(path.abspath(__file__))

-        fh = open(path.join(testpath, "data/services.xml"), 'r')
+        fh = open(path.join(testpath, "data/services.xml"), "r")
        result = data.load(fh.read())
-        self.assertTrue('feed' in result)
-        self.assertTrue('author' in result.feed)
-        self.assertTrue('entry' in result.feed)
+        self.assertTrue("feed" in result)
+        self.assertTrue("author" in result.feed)
+        self.assertTrue("entry" in result.feed)
        titles = [item.title for item in result.feed.entry]
        self.assertEqual(
            titles,
-            ['alerts', 'apps', 'authentication', 'authorization', 'data',
-             'deployment', 'licenser', 'messages', 'configs', 'saved',
-             'scheduled', 'search', 'server', 'streams', 'broker', 'clustering',
-             'masterlm'])
-
-        fh = open(path.join(testpath, "data/services.server.info.xml"), 'r')
+            [
+                "alerts",
+                "apps",
+                "authentication",
+                "authorization",
+                "data",
+                "deployment",
+                "licenser",
+                "messages",
+                "configs",
+                "saved",
+                "scheduled",
+                "search",
+                "server",
+                "streams",
+                "broker",
+                "clustering",
+                "masterlm",
+            ],
+        )
+
+        fh = open(path.join(testpath, "data/services.server.info.xml"), "r")
        result = data.load(fh.read())
-        self.assertTrue('feed' in result)
-        self.assertTrue('author' in result.feed)
-        self.assertTrue('entry' in result.feed)
-        self.assertEqual(result.feed.title, 'server-info')
-        self.assertEqual(result.feed.author.name, 'Splunk')
-        self.assertEqual(result.feed.entry.content.cpu_arch, 'i386')
-        self.assertEqual(result.feed.entry.content.os_name, 'Darwin')
-        self.assertEqual(result.feed.entry.content.os_version, '10.8.0')
+        self.assertTrue("feed" in result)
+        self.assertTrue("author" in result.feed)
+        self.assertTrue("entry" in result.feed)
+        self.assertEqual(result.feed.title, "server-info")
+        self.assertEqual(result.feed.author.name, "Splunk")
+        self.assertEqual(result.feed.entry.content.cpu_arch, "i386")
+        self.assertEqual(result.feed.entry.content.os_name, "Darwin")
+        self.assertEqual(result.feed.entry.content.os_version, "10.8.0")

    def test_invalid(self):
        if sys.version_info[1] >= 7:
            self.assertRaises(et.ParseError, data.load, "<dict></dict")
        else:
            from xml.etree.ElementTree import ParseError
+
            self.assertRaises(ParseError, data.load, "<dict></dict")

        self.assertRaises(KeyError, data.load, "<dict><key>a</key></dict>")
@@ -135,7 +153,7 @@ def test_dict(self):
              <key name='n1'>v1</key>
              <key name='n2'>v2</key>
            </dict>""")
-        self.assertEqual(result, {'n1': "v1", 'n2': "v2"})
+        self.assertEqual(result, {"n1": "v1", "n2": "v2"})

        result = data.load("""
            <content>
@@ -144,7 +162,7 @@ def test_dict(self):
              <key name='n2'>v2</key>
            </dict>
            </content>""")
-        self.assertEqual(result, {'content': {'n1': "v1", 'n2': "v2"}})
+        self.assertEqual(result, {"content": {"n1": "v1", "n2": "v2"}})

        result = data.load("""
            <content>
@@ -161,8 +179,9 @@ def test_dict(self):
            </dict>
            </content>""")
-        self.assertEqual(result,
-            {'content': {'n1': {'n1n1': "n1v1"}, 'n2': {'n2n1': "n2v1"}}})
+        self.assertEqual(
+            result, {"content": {"n1": {"n1n1": "n1v1"}, "n2": {"n2n1": "n2v1"}}}
+        )

        result = data.load("""
            <content>
@@ -174,8 +193,7 @@ def test_dict(self):
            </dict>
            </content>""")
-        self.assertEqual(result,
-            {'content': {'n1': ['1', '2', '3', '4']}})
+        self.assertEqual(result, {"content": {"n1": ["1", "2", "3", "4"]}})

    def test_list(self):
        result = data.load("""<list></list>""")
@@ -185,7 +203,7 @@ def test_list(self):
            <item>1</item><item>2</item><item>3</item><item>4</item>
            </list>""")
-        self.assertEqual(result, ['1', '2', '3', '4'])
+        self.assertEqual(result, ["1", "2", "3", "4"])

        result = data.load("""
@@ -193,7 +211,7 @@
            <item>1</item><item>2</item><item>3</item><item>4</item>
            </list>
            </content>""")
-        self.assertEqual(result, {'content': ['1', '2', '3', '4']})
+        self.assertEqual(result, {"content": ["1", "2", "3", "4"]})

        result = data.load("""
@@ -206,7 +224,7 @@
            </list>
            </content>""")
-        self.assertEqual(result, {'content': [['1', '2'], ['3', '4']]})
+        self.assertEqual(result, {"content": [["1", "2"], ["3", "4"]]})

        result = data.load("""
@@ -217,8 +235,10 @@
            <item><dict><key name='n4'>v4</key></dict></item>
            </list>
            </content>""")
-        self.assertEqual(result,
-            {'content': [{'n1': "v1"}, {'n2': "v2"}, {'n3': "v3"}, {'n4': "v4"}]})
+        self.assertEqual(
+            result,
+            {"content": [{"n1": "v1"}, {"n2": "v2"}, {"n3": "v3"}, {"n4": "v4"}]},
+        )

        result = data.load("""
@@ -227,23 +247,20 @@ def test_list(self):
            <key name='isFree'>0</key>
            </dict>""")
-        self.assertEqual(result,
-            {'build': '101089', 'cpu_arch': 'i386', 'isFree': '0'})
+        self.assertEqual(result, {"build": "101089", "cpu_arch": "i386", "isFree": "0"})

    def test_record(self):
        d = data.record()
-        d.update({'foo': 5,
-            'bar.baz': 6,
-            'bar.qux': 7,
-            'bar.zrp.meep': 8,
- 
'bar.zrp.peem': 9}) - self.assertEqual(d['foo'], 5) - self.assertEqual(d['bar.baz'], 6) - self.assertEqual(d['bar'], {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem': 9}}) + d.update( + {"foo": 5, "bar.baz": 6, "bar.qux": 7, "bar.zrp.meep": 8, "bar.zrp.peem": 9} + ) + self.assertEqual(d["foo"], 5) + self.assertEqual(d["bar.baz"], 6) + self.assertEqual(d["bar"], {"baz": 6, "qux": 7, "zrp": {"meep": 8, "peem": 9}}) self.assertEqual(d.foo, 5) self.assertEqual(d.bar.baz, 6) - self.assertEqual(d.bar, {'baz': 6, 'qux': 7, 'zrp': {'meep': 8, 'peem': 9}}) - self.assertRaises(KeyError, d.__getitem__, 'boris') + self.assertEqual(d.bar, {"baz": 6, "qux": 7, "zrp": {"meep": 8, "peem": 9}}) + self.assertRaises(KeyError, d.__getitem__, "boris") if __name__ == "__main__": diff --git a/tests/test_event_type.py b/tests/test_event_type.py index c50e1ea3a..cacb95736 100755 --- a/tests/test_event_type.py +++ b/tests/test_event_type.py @@ -30,10 +30,10 @@ def test_create(self): self.assertFalse(self.event_type_name in event_types) kwargs = {} - kwargs['search'] = "index=_internal *" - kwargs['description'] = "An internal event" - kwargs['disabled'] = 1 - kwargs['priority'] = 2 + kwargs["search"] = "index=_internal *" + kwargs["description"] = "An internal event" + kwargs["disabled"] = 1 + kwargs["priority"] = 2 event_type = event_types.create(self.event_type_name, **kwargs) self.assertTrue(self.event_type_name in event_types) @@ -52,8 +52,8 @@ def setUp(self): super().setUp() self.event_type_name = testlib.tmpname() self.event_type = self.service.event_types.create( - self.event_type_name, - search="index=_internal *") + self.event_type_name, search="index=_internal *" + ) def tearDown(self): super().setUp() @@ -68,21 +68,25 @@ def tearDown(self): # self.assertFalse(self.event_type_name in self.service.event_types) def test_update(self): - kwargs = {'search': "index=_audit *", 'description': "An audit event", 'priority': '3'} + kwargs = { + "search": "index=_audit *", + "description": "An audit event", + "priority": "3", + } self.event_type.update(**kwargs) self.event_type.refresh() - self.assertEqual(self.event_type['search'], kwargs['search']) - self.assertEqual(self.event_type['description'], kwargs['description']) - self.assertEqual(self.event_type['priority'], kwargs['priority']) + self.assertEqual(self.event_type["search"], kwargs["search"]) + self.assertEqual(self.event_type["description"], kwargs["description"]) + self.assertEqual(self.event_type["priority"], kwargs["priority"]) def test_enable_disable(self): - self.assertEqual(self.event_type['disabled'], '0') + self.assertEqual(self.event_type["disabled"], "0") self.event_type.disable() self.event_type.refresh() - self.assertEqual(self.event_type['disabled'], '1') + self.assertEqual(self.event_type["disabled"], "1") self.event_type.enable() self.event_type.refresh() - self.assertEqual(self.event_type['disabled'], '0') + self.assertEqual(self.event_type["disabled"], "0") if __name__ == "__main__": diff --git a/tests/test_fired_alert.py b/tests/test_fired_alert.py index 9d16fddc1..803287e08 100755 --- a/tests/test_fired_alert.py +++ b/tests/test_fired_alert.py @@ -27,24 +27,26 @@ def setUp(self): self.saved_search_name = testlib.tmpname() self.assertFalse(self.saved_search_name in saved_searches) query = f"search index={self.index_name}" - kwargs = {'alert_type': 'always', - 'alert.severity': "3", - 'alert.suppress': "0", - 'alert.track': "1", - 'dispatch.earliest_time': "-1h", - 'dispatch.latest_time': "now", - 'is_scheduled': "1", - 'cron_schedule': "* 
* * * *"} + kwargs = { + "alert_type": "always", + "alert.severity": "3", + "alert.suppress": "0", + "alert.track": "1", + "dispatch.earliest_time": "-1h", + "dispatch.latest_time": "now", + "is_scheduled": "1", + "cron_schedule": "* * * * *", + } self.saved_search = saved_searches.create( - self.saved_search_name, - query, **kwargs) + self.saved_search_name, query, **kwargs + ) def tearDown(self): super().tearDown() if self.service.splunk_version >= (5,): self.service.indexes.delete(self.index_name) for saved_search in self.service.saved_searches: - if saved_search.name.startswith('delete-me'): + if saved_search.name.startswith("delete-me"): self.service.saved_searches.delete(saved_search.name) self.assertFalse(saved_search.name in self.service.saved_searches) self.assertFalse(saved_search.name in self.service.fired_alerts) @@ -60,18 +62,21 @@ def test_alerts_on_events(self): self.assertEqual(len(self.saved_search.fired_alerts), 0) self.index.enable() - self.assertEventuallyTrue(lambda: self.index.refresh() and self.index['disabled'] == '0', timeout=25) + self.assertEventuallyTrue( + lambda: self.index.refresh() and self.index["disabled"] == "0", timeout=25 + ) - eventCount = int(self.index['totalEventCount']) - self.assertEqual(self.index['sync'], '0') - self.assertEqual(self.index['disabled'], '0') + eventCount = int(self.index["totalEventCount"]) + self.assertEqual(self.index["sync"], "0") + self.assertEqual(self.index["disabled"], "0") self.index.refresh() - self.index.submit('This is a test ' + testlib.tmpname(), - sourcetype='sdk_use', host='boris') + self.index.submit( + "This is a test " + testlib.tmpname(), sourcetype="sdk_use", host="boris" + ) def f(): self.index.refresh() - return int(self.index['totalEventCount']) == eventCount + 1 + return int(self.index["totalEventCount"]) == eventCount + 1 self.assertEventuallyTrue(f, timeout=50) diff --git a/tests/test_index.py b/tests/test_index.py index 2582934bb..5135682ad 100755 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -21,13 +21,12 @@ from splunklib import client - class IndexTest(testlib.SDKTestCase): def setUp(self): super().setUp() self.index_name = testlib.tmpname() self.index = self.service.indexes.create(self.index_name) - self.assertEventuallyTrue(lambda: self.index.refresh()['disabled'] == '0') + self.assertEventuallyTrue(lambda: self.index.refresh()["disabled"] == "0") def tearDown(self): super().tearDown() @@ -39,21 +38,27 @@ def tearDown(self): if self.index_name in self.service.indexes: time.sleep(5) self.service.indexes.delete(self.index_name) - self.assertEventuallyTrue(lambda: self.index_name not in self.service.indexes) + self.assertEventuallyTrue( + lambda: self.index_name not in self.service.indexes + ) else: - logging.warning("test_index.py:TestDeleteIndex: Skipped: cannot " - "delete indexes via the REST API in Splunk 4.x") + logging.warning( + "test_index.py:TestDeleteIndex: Skipped: cannot " + "delete indexes via the REST API in Splunk 4.x" + ) def totalEventCount(self): self.index.refresh() - return int(self.index['totalEventCount']) + return int(self.index["totalEventCount"]) def test_delete(self): if self.service.splunk_version >= (5,): self.assertTrue(self.index_name in self.service.indexes) time.sleep(5) self.service.indexes.delete(self.index_name) - self.assertEventuallyTrue(lambda: self.index_name not in self.service.indexes) + self.assertEventuallyTrue( + lambda: self.index_name not in self.service.indexes + ) def test_integrity(self): self.check_entity(self.index) @@ -65,10 +70,10 @@ def 
test_default(self): def test_disable_enable(self): self.index.disable() self.index.refresh() - self.assertEqual(self.index['disabled'], '1') + self.assertEqual(self.index["disabled"], "1") self.index.enable() self.index.refresh() - self.assertEqual(self.index['disabled'], '0') + self.assertEqual(self.index["disabled"], "0") # def test_submit_and_clean(self): # self.index.refresh() @@ -85,63 +90,78 @@ def test_disable_enable(self): # self.assertEqual(self.index['totalEventCount'], '0') def test_prefresh(self): - self.assertEqual(self.index['disabled'], '0') # Index is prefreshed + self.assertEqual(self.index["disabled"], "0") # Index is prefreshed def test_submit(self): - event_count = int(self.index['totalEventCount']) - self.assertEqual(self.index['sync'], '0') - self.assertEqual(self.index['disabled'], '0') + event_count = int(self.index["totalEventCount"]) + self.assertEqual(self.index["sync"], "0") + self.assertEqual(self.index["disabled"], "0") self.index.submit("Hello again!", sourcetype="Boris", host="meep") - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=50) + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 1, timeout=50 + ) def test_submit_namespaced(self): - s = client.connect(**{ - "username": self.service.username, - "password": self.service.password, - "owner": "nobody", - "app": "search" - }) + s = client.connect( + **{ + "username": self.service.username, + "password": self.service.password, + "owner": "nobody", + "app": "search", + } + ) i = s.indexes[self.index_name] - event_count = int(i['totalEventCount']) - self.assertEqual(i['sync'], '0') - self.assertEqual(i['disabled'], '0') + event_count = int(i["totalEventCount"]) + self.assertEqual(i["sync"], "0") + self.assertEqual(i["disabled"], "0") i.submit("Hello again namespaced!", sourcetype="Boris", host="meep") - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=50) + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 1, timeout=50 + ) def test_submit_via_attach(self): - event_count = int(self.index['totalEventCount']) + event_count = int(self.index["totalEventCount"]) cn = self.index.attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 1, timeout=60 + ) def test_submit_via_attach_using_token_header(self): # Remove the prefix from the token - s = client.connect(**{'token': self.service.token.replace("Splunk ", "")}) + s = client.connect(**{"token": self.service.token.replace("Splunk ", "")}) i = s.indexes[self.index_name] - event_count = int(i['totalEventCount']) + event_count = int(i["totalEventCount"]) if s.has_cookies(): del s.http._cookies cn = i.attach() cn.send(b"Hello Boris 5!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 1, timeout=60 + ) def test_submit_via_attached_socket(self): - event_count = int(self.index['totalEventCount']) + event_count = int(self.index["totalEventCount"]) f = self.index.attached_socket with f() as sock: - sock.send(b'Hello world!\r\n') - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) + sock.send(b"Hello world!\r\n") + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 1, timeout=60 + ) def 
test_submit_via_attach_with_cookie_header(self): # Skip this test if running below Splunk 6.2, cookie-auth didn't exist before splver = self.service.splunk_version if splver[:2] < (6, 2): - self.skipTest("Skipping cookie-auth tests, running in %d.%d.%d, this feature was added in 6.2+" % splver) + self.skipTest( + "Skipping cookie-auth tests, running in %d.%d.%d, this feature was added in 6.2+" + % splver + ) - event_count = int(self.service.indexes[self.index_name]['totalEventCount']) + event_count = int(self.service.indexes[self.index_name]["totalEventCount"]) cookie = "%s=%s" % (list(self.service.http._cookies.items())[0]) service = client.Service(**{"cookie": cookie}) @@ -149,32 +169,41 @@ def test_submit_via_attach_with_cookie_header(self): cn = service.indexes[self.index_name].attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 1, timeout=60 + ) def test_submit_via_attach_with_multiple_cookie_headers(self): # Skip this test if running below Splunk 6.2, cookie-auth didn't exist before splver = self.service.splunk_version if splver[:2] < (6, 2): - self.skipTest("Skipping cookie-auth tests, running in %d.%d.%d, this feature was added in 6.2+" % splver) + self.skipTest( + "Skipping cookie-auth tests, running in %d.%d.%d, this feature was added in 6.2+" + % splver + ) - event_count = int(self.service.indexes[self.index_name]['totalEventCount']) - service = client.Service(**{"cookie": 'a bad cookie'}) + event_count = int(self.service.indexes[self.index_name]["totalEventCount"]) + service = client.Service(**{"cookie": "a bad cookie"}) service.http._cookies.update(self.service.http._cookies) service.login() cn = service.indexes[self.index_name].attach() cn.send(b"Hello Boris!\r\n") cn.close() - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 1, timeout=60) + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 1, timeout=60 + ) @pytest.mark.app def test_upload(self): self.install_app_from_collection("file_to_upload") - event_count = int(self.index['totalEventCount']) + event_count = int(self.index["totalEventCount"]) path = self.pathInApp("file_to_upload", ["log.txt"]) self.index.upload(path) - self.assertEventuallyTrue(lambda: self.totalEventCount() == event_count + 4, timeout=60) + self.assertEventuallyTrue( + lambda: self.totalEventCount() == event_count + 4, timeout=60 + ) if __name__ == "__main__": diff --git a/tests/test_input.py b/tests/test_input.py index 53436f73f..fa1663ec4 100755 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -25,7 +25,7 @@ def highest_port(service, base_port, *kinds): """Find the first port >= base_port not in use by any input in kinds.""" highest_port = base_port for input in service.inputs.list(*kinds): - port = int(input.name.split(':')[-1]) + port = int(input.name.split(":")[-1]) highest_port = max(port, highest_port) return highest_port @@ -33,11 +33,13 @@ def highest_port(service, base_port, *kinds): class TestTcpInputNameHandling(testlib.SDKTestCase): def setUp(self): super().setUp() - self.base_port = highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp') + 1 + self.base_port = ( + highest_port(self.service, 10000, "tcp", "splunktcp", "udp") + 1 + ) def tearDown(self): - for input in self.service.inputs.list('tcp', 'splunktcp'): - port = int(input.name.split(':')[-1]) + for input in self.service.inputs.list("tcp", "splunktcp"): + 
port = int(input.name.split(":")[-1]) if port >= self.base_port: input.delete() super().tearDown() @@ -53,22 +55,25 @@ def create_tcp_input(self, base_port, kind, **options): port += 1 def test_create_tcp_port(self): - for kind in ['tcp', 'splunktcp']: + for kind in ["tcp", "splunktcp"]: input = self.service.inputs.create(str(self.base_port), kind) self.check_entity(input) input.delete() def test_cannot_create_with_restrictToHost_in_name(self): self.assertRaises( - client.HTTPError, - lambda: self.service.inputs.create('boris:10000', 'tcp') + client.HTTPError, lambda: self.service.inputs.create("boris:10000", "tcp") ) def test_create_tcp_ports_with_restrictToHost(self): - for kind in ['tcp', 'splunktcp']: # Multiplexed UDP ports are not supported + for kind in ["tcp", "splunktcp"]: # Multiplexed UDP ports are not supported # Make sure we can create two restricted inputs on the same port - boris = self.service.inputs.create(str(self.base_port), kind, restrictToHost='boris') - natasha = self.service.inputs.create(str(self.base_port), kind, restrictToHost='natasha') + boris = self.service.inputs.create( + str(self.base_port), kind, restrictToHost="boris" + ) + natasha = self.service.inputs.create( + str(self.base_port), kind, restrictToHost="natasha" + ) # And that they both function boris.refresh() natasha.refresh() @@ -90,22 +95,24 @@ def test_create_tcp_ports_with_restrictToHost(self): # restricted.delete() def test_unrestricted_to_restricted_collision(self): - for kind in ['tcp', 'splunktcp', 'udp']: + for kind in ["tcp", "splunktcp", "udp"]: unrestricted = self.service.inputs.create(str(self.base_port), kind) self.assertTrue(str(self.base_port) in self.service.inputs) self.assertRaises( client.HTTPError, - lambda: self.service.inputs.create(str(self.base_port), kind, restrictToHost='boris') + lambda: self.service.inputs.create( + str(self.base_port), kind, restrictToHost="boris" + ), ) unrestricted.delete() def test_update_restrictToHost_fails(self): - for kind in ['tcp', 'splunktcp']: # No UDP, since it's broken in Splunk - boris = self.create_tcp_input(self.base_port, kind, restrictToHost='boris') + for kind in ["tcp", "splunktcp"]: # No UDP, since it's broken in Splunk + boris = self.create_tcp_input(self.base_port, kind, restrictToHost="boris") self.assertRaises( client.IllegalOperationException, - lambda: boris.update(restrictToHost='hilda') + lambda: boris.update(restrictToHost="hilda"), ) @@ -128,7 +135,7 @@ def test_read_kind(self): self.assertEqual(item.kind, kind) def test_inputs_list_on_one_kind(self): - self.service.inputs.list('monitor') + self.service.inputs.list("monitor") def test_read_invalid_input(self): name = testlib.tmpname() @@ -139,72 +146,81 @@ def test_read_invalid_input(self): self.assertTrue("HTTP 404 Not Found" in str(he)) def test_inputs_list_on_one_kind_with_count(self): - expected = [x.name for x in self.service.inputs.list('monitor')[:10]] - found = [x.name for x in self.service.inputs.list('monitor', count=10)] + expected = [x.name for x in self.service.inputs.list("monitor")[:10]] + found = [x.name for x in self.service.inputs.list("monitor", count=10)] self.assertEqual(expected, found) def test_inputs_list_on_one_kind_with_offset(self): N = 2 - expected = [x.name for x in self.service.inputs.list('monitor')[N:]] - found = [x.name for x in self.service.inputs.list('monitor', offset=N)] + expected = [x.name for x in self.service.inputs.list("monitor")[N:]] + found = [x.name for x in self.service.inputs.list("monitor", offset=N)] self.assertEqual(expected, 
found) def test_inputs_list_on_one_kind_with_search(self): search = "SPLUNK" - expected = [x.name for x in self.service.inputs.list('monitor') if search in x.name] - found = [x.name for x in self.service.inputs.list('monitor', search=search)] + expected = [ + x.name for x in self.service.inputs.list("monitor") if search in x.name + ] + found = [x.name for x in self.service.inputs.list("monitor", search=search)] self.assertEqual(expected, found) @pytest.mark.app def test_oneshot(self): - self.install_app_from_collection('file_to_upload') + self.install_app_from_collection("file_to_upload") index_name = testlib.tmpname() index = self.service.indexes.create(index_name) - self.assertEventuallyTrue(lambda: index.refresh() and index['disabled'] == '0') + self.assertEventuallyTrue(lambda: index.refresh() and index["disabled"] == "0") - eventCount = int(index['totalEventCount']) + eventCount = int(index["totalEventCount"]) path = self.pathInApp("file_to_upload", ["log.txt"]) self.service.inputs.oneshot(path, index=index_name) def f(): index.refresh() - return int(index['totalEventCount']) == eventCount + 4 + return int(index["totalEventCount"]) == eventCount + 4 self.assertEventuallyTrue(f, timeout=60) def test_oneshot_on_nonexistant_file(self): name = testlib.tmpname() - self.assertRaises(HTTPError, - self.service.inputs.oneshot, name) + self.assertRaises(HTTPError, self.service.inputs.oneshot, name) class TestInput(testlib.SDKTestCase): def setUp(self): super().setUp() inputs = self.service.inputs - unrestricted_port = str(highest_port(self.service, 10000, 'tcp', 'splunktcp', 'udp') + 1) - restricted_port = str(highest_port(self.service, int(unrestricted_port) + 1, 'tcp', 'splunktcp') + 1) - test_inputs = [{'kind': 'tcp', 'name': unrestricted_port, 'host': 'sdk-test'}, - {'kind': 'udp', 'name': unrestricted_port, 'host': 'sdk-test'}, - {'kind': 'tcp', 'name': 'boris:' + restricted_port, 'host': 'sdk-test'}] + unrestricted_port = str( + highest_port(self.service, 10000, "tcp", "splunktcp", "udp") + 1 + ) + restricted_port = str( + highest_port(self.service, int(unrestricted_port) + 1, "tcp", "splunktcp") + + 1 + ) + test_inputs = [ + {"kind": "tcp", "name": unrestricted_port, "host": "sdk-test"}, + {"kind": "udp", "name": unrestricted_port, "host": "sdk-test"}, + {"kind": "tcp", "name": "boris:" + restricted_port, "host": "sdk-test"}, + ] self._test_entities = {} - self._test_entities['tcp'] = \ - inputs.create(unrestricted_port, 'tcp', host='sdk-test') - self._test_entities['udp'] = \ - inputs.create(unrestricted_port, 'udp', host='sdk-test') - self._test_entities['restrictedTcp'] = \ - inputs.create(restricted_port, 'tcp', restrictToHost='boris') + self._test_entities["tcp"] = inputs.create( + unrestricted_port, "tcp", host="sdk-test" + ) + self._test_entities["udp"] = inputs.create( + unrestricted_port, "udp", host="sdk-test" + ) + self._test_entities["restrictedTcp"] = inputs.create( + restricted_port, "tcp", restrictToHost="boris" + ) def tearDown(self): super().tearDown() for entity in self._test_entities.values(): try: - self.service.inputs.delete( - kind=entity.kind, - name=entity.name) + self.service.inputs.delete(kind=entity.kind, name=entity.name) except KeyError: pass @@ -223,11 +239,11 @@ def test_lists_modular_inputs(self): self.uncheckedRestartSplunk() inputs = self.service.inputs - if ('abcd', 'test2') not in inputs: - inputs.create('abcd', 'test2', field1='boris') + if ("abcd", "test2") not in inputs: + inputs.create("abcd", "test2", field1="boris") - input = inputs['abcd', 
'test2']
-        self.assertEqual(input.field1, 'boris')
+        input = inputs["abcd", "test2"]
+        self.assertEqual(input.field1, "boris")

     def test_create(self):
         inputs = self.service.inputs
@@ -238,7 +254,7 @@ def test_create(self):
     def test_get_kind_list(self):
         inputs = self.service.inputs
         kinds = inputs._get_kind_list()
-        self.assertTrue('tcp/raw' in kinds)
+        self.assertTrue("tcp/raw" in kinds)

     def test_read(self):
         inputs = self.service.inputs
@@ -250,22 +266,23 @@ def test_read(self):
             self.assertEqual(this_entity.host, read_entity.host)

     def test_read_indiviually(self):
-        tcp_input = self.service.input(self._test_entities['tcp'].path,
-                                       self._test_entities['tcp'].kind)
+        tcp_input = self.service.input(
+            self._test_entities["tcp"].path, self._test_entities["tcp"].kind
+        )
         self.assertIsNotNone(tcp_input)
-        self.assertTrue('tcp', tcp_input.kind)
-        self.assertTrue(self._test_entities['tcp'].name, tcp_input.name)
+        self.assertEqual("tcp", tcp_input.kind)
+        self.assertEqual(self._test_entities["tcp"].name, tcp_input.name)

     def test_update(self):
         inputs = self.service.inputs
         for entity in self._test_entities.values():
             kind, name = entity.kind, entity.name
-            kwargs = {'host': 'foo'}
+            kwargs = {"host": "foo"}
             entity.update(**kwargs)
             entity.refresh()
-            self.assertEqual(entity.host, kwargs['host'])
+            self.assertEqual(entity.host, kwargs["host"])

-    @pytest.mark.skip('flaky')
+    @pytest.mark.skip("flaky")
     def test_delete(self):
         inputs = self.service.inputs
         remaining = len(self._test_entities) - 1
@@ -278,13 +295,13 @@ def test_delete(self):
                     inputs.delete(name)
                     self.assertFalse(name in inputs)
                 else:
-                    if not name.startswith('boris'):
-                        self.assertRaises(client.AmbiguousReferenceException,
-                                          inputs.delete, name)
+                    if not name.startswith("boris"):
+                        self.assertRaises(
+                            client.AmbiguousReferenceException, inputs.delete, name
+                        )
                     self.service.inputs.delete(name, kind)
                     self.assertFalse((name, kind) in inputs)
-            self.assertRaises(client.HTTPError,
-                              input_entity.refresh)
+            self.assertRaises(client.HTTPError, input_entity.refresh)
             remaining -= 1

diff --git a/tests/test_job.py b/tests/test_job.py
index a276e212b..c1c9ab004 100755
--- a/tests/test_job.py
+++ b/tests/test_job.py
@@ -33,12 +33,14 @@ class TestUtilities(testlib.SDKTestCase):
     def test_service_search(self):
-        job = self.service.search('search index=_internal earliest=-1m | head 3')
+        job = self.service.search("search index=_internal earliest=-1m | head 3")
         self.assertTrue(job.sid in self.service.jobs)
         job.cancel()

     def test_create_job_with_output_mode_json(self):
-        job = self.service.jobs.create(query='search index=_internal earliest=-1m | head 3', output_mode='json')
+        job = self.service.jobs.create(
+            query="search index=_internal earliest=-1m | head 3", output_mode="json"
+        )
         self.assertTrue(job.sid in self.service.jobs)
         job.cancel()

@@ -49,12 +51,13 @@ def test_oneshot_with_garbage_fails(self):
     @pytest.mark.smoke
     def test_oneshot(self):
         jobs = self.service.jobs
-        stream = jobs.oneshot("search index=_internal earliest=-1m | head 3", output_mode='json')
+        stream = jobs.oneshot(
+            "search index=_internal earliest=-1m | head 3", output_mode="json"
+        )
         result = results.JSONResultsReader(stream)
         ds = list(result)
         self.assertEqual(result.is_preview, False)
-        self.assertTrue(isinstance(ds[0], dict) or \
                        isinstance(ds[0], results.Message))
+        self.assertTrue(isinstance(ds[0], dict) or isinstance(ds[0], results.Message))
         nonmessages = [d for d in ds if isinstance(d, dict)]
         self.assertTrue(len(nonmessages) <= 3)

@@ -64,75 +67,84 @@ def 
test_export_with_garbage_fails(self): def test_export(self): jobs = self.service.jobs - stream = jobs.export("search index=_internal earliest=-1m | head 3", output_mode='json') + stream = jobs.export( + "search index=_internal earliest=-1m | head 3", output_mode="json" + ) result = results.JSONResultsReader(stream) ds = list(result) self.assertEqual(result.is_preview, False) - self.assertTrue(isinstance(ds[0], dict) or \ - isinstance(ds[0], results.Message)) + self.assertTrue(isinstance(ds[0], dict) or isinstance(ds[0], results.Message)) nonmessages = [d for d in ds if isinstance(d, dict)] self.assertTrue(len(nonmessages) <= 3) def test_export_docstring_sample(self): from splunklib import client from splunklib import results - service = self.service # cheat - rr = results.JSONResultsReader(service.jobs.export("search * | head 5", output_mode='json')) + + service = self.service # cheat + rr = results.JSONResultsReader( + service.jobs.export("search * | head 5", output_mode="json") + ) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print(f'{result.type}: {result.message}') + pass # print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print(result) + pass # print(result) assert rr.is_preview == False def test_results_docstring_sample(self): from splunklib import results + service = self.service # cheat job = service.jobs.create("search * | head 5") while not job.is_done(): sleep(0.2) - rr = results.JSONResultsReader(job.results(output_mode='json')) + rr = results.JSONResultsReader(job.results(output_mode="json")) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print(f'{result.type}: {result.message}') + pass # print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print(result) + pass # print(result) assert rr.is_preview == False def test_preview_docstring_sample(self): from splunklib import client from splunklib import results - service = self.service # cheat + + service = self.service # cheat job = service.jobs.create("search * | head 5") - rr = results.JSONResultsReader(job.preview(output_mode='json')) + rr = results.JSONResultsReader(job.preview(output_mode="json")) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print(f'{result.type}: {result.message}') + pass # print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print(result) + pass # print(result) if rr.is_preview: - pass #print("Preview of a running search job.") + pass # print("Preview of a running search job.") else: - pass #print("Job is finished. Results are final.") + pass # print("Job is finished. 
Results are final.") def test_oneshot_docstring_sample(self): from splunklib import client from splunklib import results - service = self.service # cheat - rr = results.JSONResultsReader(service.jobs.oneshot("search * | head 5", output_mode='json')) + + service = self.service # cheat + rr = results.JSONResultsReader( + service.jobs.oneshot("search * | head 5", output_mode="json") + ) for result in rr: if isinstance(result, results.Message): # Diagnostic messages may be returned in the results - pass #print(f'{result.type}: {result.message}') + pass # print(f'{result.type}: {result.message}') elif isinstance(result, dict): # Normal events are returned as dicts - pass #print(result) + pass # print(result) assert rr.is_preview == False def test_normal_job_with_garbage_fails(self): @@ -141,41 +153,76 @@ def test_normal_job_with_garbage_fails(self): bad_search = "abcd|asfwqqq" jobs.create(bad_search) except client.HTTPError as he: - self.assertTrue('abcd' in str(he)) + self.assertTrue("abcd" in str(he)) return self.fail("Job with garbage search failed to raise TypeError.") def test_cancel(self): jobs = self.service.jobs - job = jobs.create(query="search index=_internal | head 3", - earliest_time="-1m", - latest_time="now") + job = jobs.create( + query="search index=_internal | head 3", + earliest_time="-1m", + latest_time="now", + ) self.assertTrue(job.sid in jobs) job.cancel() self.assertFalse(job.sid in jobs) def test_cancel_is_idempotent(self): jobs = self.service.jobs - job = jobs.create(query="search index=_internal | head 3", - earliest_time="-1m", - latest_time="now") + job = jobs.create( + query="search index=_internal | head 3", + earliest_time="-1m", + latest_time="now", + ) self.assertTrue(job.sid in jobs) job.cancel() - job.cancel() # Second call should be nop + job.cancel() # Second call should be nop def check_job(self, job): self.check_entity(job) - keys = ['cursorTime', 'delegate', 'diskUsage', 'dispatchState', - 'doneProgress', 'dropCount', 'earliestTime', 'eventAvailableCount', - 'eventCount', 'eventFieldCount', 'eventIsStreaming', - 'eventIsTruncated', 'eventSearch', 'eventSorting', 'isDone', - 'isFailed', 'isFinalized', 'isPaused', 'isPreviewEnabled', - 'isRealTimeSearch', 'isRemoteTimeline', 'isSaved', 'isSavedSearch', - 'isZombie', 'keywords', 'label', 'messages', - 'numPreviews', 'priority', 'remoteSearch', 'reportSearch', - 'resultCount', 'resultIsStreaming', 'resultPreviewCount', - 'runDuration', 'scanCount', 'searchProviders', 'sid', - 'statusBuckets', 'ttl'] + keys = [ + "cursorTime", + "delegate", + "diskUsage", + "dispatchState", + "doneProgress", + "dropCount", + "earliestTime", + "eventAvailableCount", + "eventCount", + "eventFieldCount", + "eventIsStreaming", + "eventIsTruncated", + "eventSearch", + "eventSorting", + "isDone", + "isFailed", + "isFinalized", + "isPaused", + "isPreviewEnabled", + "isRealTimeSearch", + "isRemoteTimeline", + "isSaved", + "isSavedSearch", + "isZombie", + "keywords", + "label", + "messages", + "numPreviews", + "priority", + "remoteSearch", + "reportSearch", + "resultCount", + "resultIsStreaming", + "resultPreviewCount", + "runDuration", + "scanCount", + "searchProviders", + "sid", + "statusBuckets", + "ttl", + ] for key in keys: self.assertTrue(key in job.content) @@ -199,6 +246,7 @@ def test_get_job(self): self.assertEqual(10, int(job["eventCount"])) self.assertEqual(10, int(job["resultCount"])) + class TestJobWithDelayedDone(testlib.SDKTestCase): def setUp(self): super().setUp() @@ -216,20 +264,18 @@ def test_enable_preview(self): 
sleep_duration = 100 self.query = "search index=_internal | sleep %d" % sleep_duration self.job = self.service.jobs.create( - query=self.query, - earliest_time="-1m", - priority=5, - latest_time="now") + query=self.query, earliest_time="-1m", priority=5, latest_time="now" + ) while not self.job.is_ready(): pass - self.assertEqual(self.job.content['isPreviewEnabled'], '0') + self.assertEqual(self.job.content["isPreviewEnabled"], "0") self.job.enable_preview() def is_preview_enabled(): is_done = self.job.is_done() if is_done: - self.fail('Job finished before preview enabled.') - return self.job.content['isPreviewEnabled'] == '1' + self.fail("Job finished before preview enabled.") + return self.job.content["isPreviewEnabled"] == "1" self.assertEventuallyTrue(is_preview_enabled) @@ -239,10 +285,8 @@ def test_setpriority(self): sleep_duration = 100 self.query = "search index=_internal | sleep %s" % sleep_duration self.job = self.service.jobs.create( - query=self.query, - earliest_time="-1m", - priority=5, - latest_time="now") + query=self.query, earliest_time="-1m", priority=5, latest_time="now" + ) # Note: You can only *decrease* the priority (i.e., 5 decreased to 3) of # a job unless Splunk is running as root. This is because Splunk jobs @@ -250,7 +294,7 @@ def test_setpriority(self): if self.service._splunk_version[0] < 6: # BUG: Splunk 6 doesn't return priority until job is ready - old_priority = int(self.job.content['priority']) + old_priority = int(self.job.content["priority"]) self.assertEqual(5, old_priority) new_priority = 3 @@ -264,7 +308,7 @@ def test_setpriority(self): def f(): if self.job.is_done(): self.fail("Job already done before priority was set.") - return int(self.job.content['priority']) == new_priority + return int(self.job.content["priority"]) == new_priority self.assertEventuallyTrue(f, timeout=sleep_duration + 5) @@ -274,9 +318,8 @@ def setUp(self): super().setUp() self.query = "search index=_internal | head 3" self.job = self.service.jobs.create( - query=self.query, - earliest_time="-1m", - latest_time="now") + query=self.query, earliest_time="-1m", latest_time="now" + ) def tearDown(self): super().tearDown() @@ -285,13 +328,13 @@ def tearDown(self): @_log_duration def test_get_preview_and_events(self): self.assertEventuallyTrue(self.job.is_done) - self.assertLessEqual(int(self.job['eventCount']), 3) + self.assertLessEqual(int(self.job["eventCount"]), 3) - preview_stream = self.job.preview(output_mode='json') + preview_stream = self.job.preview(output_mode="json") preview_r = results.JSONResultsReader(preview_stream) self.assertFalse(preview_r.is_preview) - events_stream = self.job.events(output_mode='json') + events_stream = self.job.events(output_mode="json") events_r = results.JSONResultsReader(events_stream) n_events = len([x for x in events_r if isinstance(x, dict)]) @@ -299,40 +342,41 @@ def test_get_preview_and_events(self): self.assertEqual(n_events, n_preview) def test_pause(self): - if self.job['isPaused'] == '1': + if self.job["isPaused"] == "1": self.job.unpause() self.job.refresh() - self.assertEqual(self.job['isPaused'], '0') + self.assertEqual(self.job["isPaused"], "0") self.job.pause() - self.assertEventuallyTrue(lambda: self.job.refresh()['isPaused'] == '1') + self.assertEventuallyTrue(lambda: self.job.refresh()["isPaused"] == "1") def test_unpause(self): - if self.job['isPaused'] == '0': + if self.job["isPaused"] == "0": self.job.pause() self.job.refresh() - self.assertEqual(self.job['isPaused'], '1') + self.assertEqual(self.job["isPaused"], "1") 
self.job.unpause() - self.assertEventuallyTrue(lambda: self.job.refresh()['isPaused'] == '0') + self.assertEventuallyTrue(lambda: self.job.refresh()["isPaused"] == "0") def test_finalize(self): - if self.job['isFinalized'] == '1': + if self.job["isFinalized"] == "1": self.fail("Job is already finalized; can't test .finalize() method.") else: self.job.finalize() - self.assertEventuallyTrue(lambda: self.job.refresh()['isFinalized'] == '1') + self.assertEventuallyTrue(lambda: self.job.refresh()["isFinalized"] == "1") def test_setttl(self): - old_ttl = int(self.job['ttl']) + old_ttl = int(self.job["ttl"]) new_ttl = old_ttl + 1000 from datetime import datetime + start_time = datetime.now() self.job.set_ttl(new_ttl) tries = 3 while True: self.job.refresh() - ttl = int(self.job['ttl']) + ttl = int(self.job["ttl"]) if ttl <= new_ttl and ttl > old_ttl: break else: @@ -354,16 +398,15 @@ def test_touch(self): # Touch will increase the updated time self.assertLess(old_updated, new_updated) - def test_search_invalid_query_as_json(self): - args = { - 'output_mode': 'json', - 'exec_mode': 'normal' - } + args = {"output_mode": "json", "exec_mode": "normal"} try: - self.service.jobs.create('invalid query', **args) + self.service.jobs.create("invalid query", **args) except SyntaxError as pe: - self.fail("Something went wrong with parsing the REST API response. %s" % pe.message) + self.fail( + "Something went wrong with parsing the REST API response. %s" + % pe.message + ) except HTTPError as he: self.assertEqual(he.status, 400) except Exception as e: @@ -372,18 +415,18 @@ def test_search_invalid_query_as_json(self): @pytest.mark.smoke def test_v1_job_fallback(self): self.assertEventuallyTrue(self.job.is_done) - self.assertLessEqual(int(self.job['eventCount']), 3) + self.assertLessEqual(int(self.job["eventCount"]), 3) - preview_stream = self.job.preview(output_mode='json', search='| head 1') + preview_stream = self.job.preview(output_mode="json", search="| head 1") preview_r = results.JSONResultsReader(preview_stream) self.assertFalse(preview_r.is_preview) - events_stream = self.job.events(output_mode='json', search='| head 1') + events_stream = self.job.events(output_mode="json", search="| head 1") events_r = results.JSONResultsReader(events_stream) - - results_stream = self.job.results(output_mode='json', search='| head 1') + + results_stream = self.job.results(output_mode="json", search="| head 1") results_r = results.JSONResultsReader(results_stream) - + n_events = len([x for x in events_r if isinstance(x, dict)]) n_preview = len([x for x in preview_r if isinstance(x, dict)]) n_results = len([x for x in results_r if isinstance(x, dict)]) @@ -399,15 +442,17 @@ def test_results_reader(self): # Run jobs.export("search index=_internal | stats count", # earliest_time="rt", latest_time="rt") and you get a # streaming sequence of XML fragments containing results. 
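        # (A rough live equivalent of the file-based fixture below, as a
        # minimal sketch assuming a connected `service`; ResultsReader
        # filters out the XML declarations that separate the concatenated
        # fragments and yields one dict per result plus results.Message
        # objects for diagnostics:
        #
        #     stream = service.jobs.export("search index=_internal | head 3")
        #     for item in results.ResultsReader(stream):
        #         print(item)
        # )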
- with io.open('data/results.xml', mode='br') as input: + with io.open("data/results.xml", mode="br") as input: reader = results.ResultsReader(input) self.assertFalse(reader.is_preview) N_results = 0 N_messages = 0 for r in reader: from collections import OrderedDict - self.assertTrue(isinstance(r, OrderedDict) - or isinstance(r, results.Message)) + + self.assertTrue( + isinstance(r, OrderedDict) or isinstance(r, results.Message) + ) if isinstance(r, OrderedDict): N_results += 1 elif isinstance(r, results.Message): @@ -419,14 +464,16 @@ def test_results_reader_with_streaming_results(self): # Run jobs.export("search index=_internal | stats count", # earliest_time="rt", latest_time="rt") and you get a # streaming sequence of XML fragments containing results. - with io.open('data/streaming_results.xml', 'br') as input: + with io.open("data/streaming_results.xml", "br") as input: reader = results.ResultsReader(input) N_results = 0 N_messages = 0 for r in reader: from collections import OrderedDict - self.assertTrue(isinstance(r, OrderedDict) - or isinstance(r, results.Message)) + + self.assertTrue( + isinstance(r, OrderedDict) or isinstance(r, results.Message) + ) if isinstance(r, OrderedDict): N_results += 1 elif isinstance(r, results.Message): @@ -435,15 +482,21 @@ def test_results_reader_with_streaming_results(self): self.assertEqual(N_messages, 3) def test_xmldtd_filter(self): - s = results._XMLDTDFilter(BytesIO(b"""Other stuf ab""")) + s = results._XMLDTDFilter( + BytesIO( + b"""Other stuf ab""" + ) + ) self.assertEqual(s.read(), b"Other stuf ab") def test_concatenated_stream(self): - s = results._ConcatenatedStream(BytesIO(b"This is a test "), - BytesIO(b"of the emergency broadcast system.")) + s = results._ConcatenatedStream( + BytesIO(b"This is a test "), BytesIO(b"of the emergency broadcast system.") + ) self.assertEqual(s.read(3), b"Thi") - self.assertEqual(s.read(20), b's is a test of the e') - self.assertEqual(s.read(), b'mergency broadcast system.') + self.assertEqual(s.read(20), b"s is a test of the e") + self.assertEqual(s.read(), b"mergency broadcast system.") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_kvstore_batch.py b/tests/test_kvstore_batch.py index a17a3e9d4..5cba9085a 100755 --- a/tests/test_kvstore_batch.py +++ b/tests/test_kvstore_batch.py @@ -20,52 +20,55 @@ class KVStoreBatchTestCase(testlib.SDKTestCase): def setUp(self): super().setUp() - self.service.namespace['app'] = 'search' + self.service.namespace["app"] = "search" confs = self.service.kvstore - if 'test' in confs: - confs['test'].delete() - confs.create('test') + if "test" in confs: + confs["test"].delete() + confs.create("test") - self.col = confs['test'].data + self.col = confs["test"].data def test_insert_find_update_data(self): - data = [{'_key': str(x), 'data': '#' + str(x), 'num': x} for x in range(1000)] + data = [{"_key": str(x), "data": "#" + str(x), "num": x} for x in range(1000)] self.col.batch_save(*data) - testData = self.col.query(sort='num') + testData = self.col.query(sort="num") self.assertEqual(len(testData), 1000) for x in range(1000): - self.assertEqual(testData[x]['_key'], str(x)) - self.assertEqual(testData[x]['data'], '#' + str(x)) - self.assertEqual(testData[x]['num'], x) - - data = [{'_key': str(x), 'data': '#' + str(x + 1), 'num': x + 1} for x in range(1000)] + self.assertEqual(testData[x]["_key"], str(x)) + self.assertEqual(testData[x]["data"], "#" + str(x)) + self.assertEqual(testData[x]["num"], x) + + data = [ + {"_key": str(x), "data": "#" + str(x + 1), 
"num": x + 1} + for x in range(1000) + ] self.col.batch_save(*data) - testData = self.col.query(sort='num') + testData = self.col.query(sort="num") self.assertEqual(len(testData), 1000) for x in range(1000): - self.assertEqual(testData[x]['_key'], str(x)) - self.assertEqual(testData[x]['data'], '#' + str(x + 1)) - self.assertEqual(testData[x]['num'], x + 1) + self.assertEqual(testData[x]["_key"], str(x)) + self.assertEqual(testData[x]["data"], "#" + str(x + 1)) + self.assertEqual(testData[x]["num"], x + 1) query = [{"query": {"num": x + 1}} for x in range(100)] testData = self.col.batch_find(*query) self.assertEqual(len(testData), 100) - testData.sort(key=lambda x: x[0]['num']) + testData.sort(key=lambda x: x[0]["num"]) for x in range(100): - self.assertEqual(testData[x][0]['_key'], str(x)) - self.assertEqual(testData[x][0]['data'], '#' + str(x + 1)) - self.assertEqual(testData[x][0]['num'], x + 1) + self.assertEqual(testData[x][0]["_key"], str(x)) + self.assertEqual(testData[x][0]["data"], "#" + str(x + 1)) + self.assertEqual(testData[x][0]["num"], x + 1) def tearDown(self): confs = self.service.kvstore - if 'test' in confs: - confs['test'].delete() + if "test" in confs: + confs["test"].delete() if __name__ == "__main__": diff --git a/tests/test_kvstore_conf.py b/tests/test_kvstore_conf.py index fe098f66a..78e2e67d5 100755 --- a/tests/test_kvstore_conf.py +++ b/tests/test_kvstore_conf.py @@ -22,73 +22,77 @@ class KVStoreConfTestCase(testlib.SDKTestCase): def setUp(self): super().setUp() - self.service.namespace['app'] = 'search' + self.service.namespace["app"] = "search" self.confs = self.service.kvstore - if ('test' in self.confs): - self.confs['test'].delete() + if "test" in self.confs: + self.confs["test"].delete() def test_owner_restriction(self): - self.service.kvstore_owner = 'admin' + self.service.kvstore_owner = "admin" self.assertRaises(client.HTTPError, lambda: self.confs.list()) - self.service.kvstore_owner = 'nobody' + self.service.kvstore_owner = "nobody" def test_create_delete_collection(self): - self.confs.create('test') - self.assertTrue('test' in self.confs) - self.confs['test'].delete() - self.assertTrue('test' not in self.confs) + self.confs.create("test") + self.assertTrue("test" in self.confs) + self.confs["test"].delete() + self.assertTrue("test" not in self.confs) def test_create_fields(self): - self.confs.create('test', accelerated_fields={'ind1':{'a':1}}, fields={'a':'number1'}) - self.assertEqual(self.confs['test']['field.a'], 'number1') - self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {"a": 1}) - self.confs['test'].delete() + self.confs.create( + "test", accelerated_fields={"ind1": {"a": 1}}, fields={"a": "number1"} + ) + self.assertEqual(self.confs["test"]["field.a"], "number1") + self.assertEqual(self.confs["test"]["accelerated_fields.ind1"], {"a": 1}) + self.confs["test"].delete() def test_update_collection(self): - self.confs.create('test') + self.confs.create("test") val = {"a": 1} - self.confs['test'].post(**{'accelerated_fields.ind1': json.dumps(val), 'field.a': 'number'}) - self.assertEqual(self.confs['test']['field.a'], 'number') - self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {"a": 1}) - self.confs['test'].delete() + self.confs["test"].post( + **{"accelerated_fields.ind1": json.dumps(val), "field.a": "number"} + ) + self.assertEqual(self.confs["test"]["field.a"], "number") + self.assertEqual(self.confs["test"]["accelerated_fields.ind1"], {"a": 1}) + self.confs["test"].delete() def test_update_accelerated_fields(self): 
- self.confs.create('test', accelerated_fields={'ind1':{'a':1}}) - self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {'a': 1}) + self.confs.create("test", accelerated_fields={"ind1": {"a": 1}}) + self.assertEqual(self.confs["test"]["accelerated_fields.ind1"], {"a": 1}) # update accelerated_field value - self.confs['test'].update_accelerated_field('ind1', {'a': -1}) - self.assertEqual(self.confs['test']['accelerated_fields.ind1'], {'a': -1}) - self.confs['test'].delete() + self.confs["test"].update_accelerated_field("ind1", {"a": -1}) + self.assertEqual(self.confs["test"]["accelerated_fields.ind1"], {"a": -1}) + self.confs["test"].delete() def test_update_fields(self): - self.confs.create('test') - self.confs['test'].post(**{'field.a': 'number'}) - self.assertEqual(self.confs['test']['field.a'], 'number') - self.confs['test'].update_field('a', 'string') - self.assertEqual(self.confs['test']['field.a'], 'string') - self.confs['test'].delete() + self.confs.create("test") + self.confs["test"].post(**{"field.a": "number"}) + self.assertEqual(self.confs["test"]["field.a"], "number") + self.confs["test"].update_field("a", "string") + self.assertEqual(self.confs["test"]["field.a"], "string") + self.confs["test"].delete() def test_create_unique_collection(self): - self.confs.create('test') - self.assertTrue('test' in self.confs) - self.assertRaises(client.HTTPError, lambda: self.confs.create('test')) - self.confs['test'].delete() + self.confs.create("test") + self.assertTrue("test" in self.confs) + self.assertRaises(client.HTTPError, lambda: self.confs.create("test")) + self.confs["test"].delete() def test_overlapping_collections(self): - self.service.namespace['app'] = 'system' - self.confs.create('test') - self.service.namespace['app'] = 'search' - self.confs.create('test') - self.assertEqual(self.confs['test']['eai:appName'], 'search') - self.service.namespace['app'] = 'system' - self.assertEqual(self.confs['test']['eai:appName'], 'system') - self.service.namespace['app'] = 'search' - self.confs['test'].delete() - self.confs['test'].delete() + self.service.namespace["app"] = "system" + self.confs.create("test") + self.service.namespace["app"] = "search" + self.confs.create("test") + self.assertEqual(self.confs["test"]["eai:appName"], "search") + self.service.namespace["app"] = "system" + self.assertEqual(self.confs["test"]["eai:appName"], "system") + self.service.namespace["app"] = "search" + self.confs["test"].delete() + self.confs["test"].delete() def tearDown(self): - if 'test' in self.confs: - self.confs['test'].delete() + if "test" in self.confs: + self.confs["test"].delete() if __name__ == "__main__": diff --git a/tests/test_kvstore_data.py b/tests/test_kvstore_data.py index 5860f6fcf..40c892644 100755 --- a/tests/test_kvstore_data.py +++ b/tests/test_kvstore_data.py @@ -23,72 +23,82 @@ class KVStoreDataTestCase(testlib.SDKTestCase): def setUp(self): super().setUp() - self.service.namespace['app'] = 'search' + self.service.namespace["app"] = "search" self.confs = self.service.kvstore - if ('test' in self.confs): - self.confs['test'].delete() - self.confs.create('test') + if "test" in self.confs: + self.confs["test"].delete() + self.confs.create("test") - self.col = self.confs['test'].data + self.col = self.confs["test"].data def test_insert_query_delete_data(self): for x in range(50): - self.col.insert(json.dumps({'_key': str(x), 'data': '#' + str(x), 'num': x})) + self.col.insert( + json.dumps({"_key": str(x), "data": "#" + str(x), "num": x}) + ) 
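        # query() and delete() take MongoDB-style JSON selectors; for example,
        # json.dumps({"num": {"$gt": 39}}) below matches every record whose
        # num field exceeds 39.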
self.assertEqual(len(self.col.query()), 50) self.assertEqual(len(self.col.query(query='{"num": 10}')), 1) - self.assertEqual(self.col.query(query='{"num": 10}')[0]['data'], '#10') - self.col.delete(json.dumps({'num': {'$gt': 39}})) + self.assertEqual(self.col.query(query='{"num": 10}')[0]["data"], "#10") + self.col.delete(json.dumps({"num": {"$gt": 39}})) self.assertEqual(len(self.col.query()), 40) self.col.delete() self.assertEqual(len(self.col.query()), 0) def test_update_delete_data(self): for x in range(50): - self.col.insert(json.dumps({'_key': str(x), 'data': '#' + str(x), 'num': x})) + self.col.insert( + json.dumps({"_key": str(x), "data": "#" + str(x), "num": x}) + ) self.assertEqual(len(self.col.query()), 50) - self.assertEqual(self.col.query(query='{"num": 49}')[0]['data'], '#49') - self.col.update(str(49), json.dumps({'data': '#50', 'num': 50})) + self.assertEqual(self.col.query(query='{"num": 49}')[0]["data"], "#49") + self.col.update(str(49), json.dumps({"data": "#50", "num": 50})) self.assertEqual(len(self.col.query()), 50) - self.assertEqual(self.col.query(query='{"num": 50}')[0]['data'], '#50') + self.assertEqual(self.col.query(query='{"num": 50}')[0]["data"], "#50") self.assertEqual(len(self.col.query(query='{"num": 49}')), 0) self.col.delete_by_id(49) self.assertEqual(len(self.col.query(query='{"num": 50}')), 0) def test_query_data(self): - if 'test1' in self.confs: - self.confs['test1'].delete() - self.confs.create('test1') - self.col = self.confs['test1'].data + if "test1" in self.confs: + self.confs["test1"].delete() + self.confs.create("test1") + self.col = self.confs["test1"].data for x in range(10): - self.col.insert(json.dumps({'_key': str(x), 'data': '#' + str(x), 'num': x})) - data = self.col.query(sort='data:-1', skip=9) + self.col.insert( + json.dumps({"_key": str(x), "data": "#" + str(x), "num": x}) + ) + data = self.col.query(sort="data:-1", skip=9) self.assertEqual(len(data), 1) - self.assertEqual(data[0]['data'], '#0') - data = self.col.query(sort='data:1') - self.assertEqual(data[0]['data'], '#0') + self.assertEqual(data[0]["data"], "#0") + data = self.col.query(sort="data:1") + self.assertEqual(data[0]["data"], "#0") data = self.col.query(limit=2, skip=9) self.assertEqual(len(data), 1) def test_invalid_insert_update(self): - self.assertRaises(client.HTTPError, lambda: self.col.insert('NOT VALID DATA')) - id = self.col.insert(json.dumps({'foo': 'bar'}))['_key'] - self.assertRaises(client.HTTPError, lambda: self.col.update(id, 'NOT VALID DATA')) - self.assertEqual(self.col.query_by_id(id)['foo'], 'bar') + self.assertRaises(client.HTTPError, lambda: self.col.insert("NOT VALID DATA")) + id = self.col.insert(json.dumps({"foo": "bar"}))["_key"] + self.assertRaises( + client.HTTPError, lambda: self.col.update(id, "NOT VALID DATA") + ) + self.assertEqual(self.col.query_by_id(id)["foo"], "bar") def test_params_data_type_conversion(self): - self.confs['test'].post(**{'field.data': 'number', 'accelerated_fields.data': '{"data": -1}'}) + self.confs["test"].post( + **{"field.data": "number", "accelerated_fields.data": '{"data": -1}'} + ) for x in range(50): - self.col.insert(json.dumps({'_key': str(x), 'data': str(x), 'ignore': x})) - data = self.col.query(sort='data:-1', limit=20, fields='data,_id:0', skip=10) + self.col.insert(json.dumps({"_key": str(x), "data": str(x), "ignore": x})) + data = self.col.query(sort="data:-1", limit=20, fields="data,_id:0", skip=10) self.assertEqual(len(data), 20) for x in range(20): - self.assertEqual(data[x]['data'], 39 - x) - 
self.assertTrue('ignore' not in data[x]) - self.assertTrue('_key' not in data[x]) + self.assertEqual(data[x]["data"], 39 - x) + self.assertTrue("ignore" not in data[x]) + self.assertTrue("_key" not in data[x]) def tearDown(self): - if 'test' in self.confs: - self.confs['test'].delete() + if "test" in self.confs: + self.confs["test"].delete() if __name__ == "__main__": diff --git a/tests/test_logger.py b/tests/test_logger.py index 46623e363..0dca55171 100755 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -24,25 +24,25 @@ class LoggerTestCase(testlib.SDKTestCase): def check_logger(self, logger): self.check_entity(logger) - self.assertTrue(logger['level'] in LEVELS) + self.assertTrue(logger["level"] in LEVELS) def test_read(self): for logger in self.service.loggers.list(count=10): self.check_logger(logger) def test_crud(self): - self.assertTrue('AuditLogger' in self.service.loggers) - logger = self.service.loggers['AuditLogger'] + self.assertTrue("AuditLogger" in self.service.loggers) + logger = self.service.loggers["AuditLogger"] - saved = logger['level'] + saved = logger["level"] for level in LEVELS: logger.update(level=level) logger.refresh() - self.assertEqual(self.service.loggers['AuditLogger']['level'], level) + self.assertEqual(self.service.loggers["AuditLogger"]["level"], level) logger.update(level=saved) logger.refresh() - self.assertEqual(self.service.loggers['AuditLogger']['level'], saved) + self.assertEqual(self.service.loggers["AuditLogger"]["level"], saved) if __name__ == "__main__": diff --git a/tests/test_macro.py b/tests/test_macro.py index 25b72e4d9..52debc571 100755 --- a/tests/test_macro.py +++ b/tests/test_macro.py @@ -22,6 +22,7 @@ import pytest + @pytest.mark.smoke class TestMacro(testlib.SDKTestCase): def setUp(self): @@ -35,70 +36,66 @@ def setUp(self): def tearDown(self): super(TestMacro, self).setUp() for macro in self.service.macros: - if macro.name.startswith('delete-me'): + if macro.name.startswith("delete-me"): self.service.macros.delete(macro.name) def check_macro(self, macro): self.check_entity(macro) - expected_fields = ['definition', - 'iseval', - 'args', - 'validation', - 'errormsg'] + expected_fields = ["definition", "iseval", "args", "validation", "errormsg"] for f in expected_fields: macro[f] is_eval = macro.iseval - self.assertTrue(is_eval == '1' or is_eval == '0') + self.assertTrue(is_eval == "1" or is_eval == "0") def test_create(self): self.assertTrue(self.macro_name in self.service.macros) self.check_macro(self.macro) def test_create_with_args(self): - macro_name = testlib.tmpname() + '(1)' + macro_name = testlib.tmpname() + "(1)" definition = '| eval value="$value$"' kwargs = { - 'args': 'value', - 'validation': '$value$ > 10', - 'errormsg': 'value must be greater than 10' + "args": "value", + "validation": "$value$ > 10", + "errormsg": "value must be greater than 10", } macro = self.service.macros.create(macro_name, definition=definition, **kwargs) self.assertTrue(macro_name in self.service.macros) self.check_macro(macro) - self.assertEqual(macro.iseval, '0') - self.assertEqual(macro.args, kwargs.get('args')) - self.assertEqual(macro.validation, kwargs.get('validation')) - self.assertEqual(macro.errormsg, kwargs.get('errormsg')) + self.assertEqual(macro.iseval, "0") + self.assertEqual(macro.args, kwargs.get("args")) + self.assertEqual(macro.validation, kwargs.get("validation")) + self.assertEqual(macro.errormsg, kwargs.get("errormsg")) self.service.macros.delete(macro_name) def test_delete(self): self.assertTrue(self.macro_name in 
self.service.macros) self.service.macros.delete(self.macro_name) self.assertFalse(self.macro_name in self.service.macros) - self.assertRaises(client.HTTPError, - self.macro.refresh) + self.assertRaises(client.HTTPError, self.macro.refresh) def test_update(self): new_definition = '| eval updated="true"' self.macro.update(definition=new_definition) self.macro.refresh() - self.assertEqual(self.macro['definition'], new_definition) + self.assertEqual(self.macro["definition"], new_definition) - is_eval = testlib.to_bool(self.macro['iseval']) + is_eval = testlib.to_bool(self.macro["iseval"]) self.macro.update(iseval=not is_eval) self.macro.refresh() - self.assertEqual(testlib.to_bool(self.macro['iseval']), not is_eval) + self.assertEqual(testlib.to_bool(self.macro["iseval"]), not is_eval) def test_cannot_update_name(self): - new_name = self.macro_name + '-alteration' - self.assertRaises(client.IllegalOperationException, - self.macro.update, name=new_name) + new_name = self.macro_name + "-alteration" + self.assertRaises( + client.IllegalOperationException, self.macro.update, name=new_name + ) def test_name_collision(self): opts = self.opts.kwargs.copy() - opts['owner'] = '-' - opts['app'] = '-' - opts['sharing'] = 'user' + opts["owner"] = "-" + opts["app"] = "-" + opts["sharing"] = "user" service = client.connect(**opts) logging.debug("Namespace for collision testing: %s", service.namespace) macros = service.macros @@ -106,45 +103,43 @@ def test_name_collision(self): dispatch1 = '| eval macro_one="1"' dispatch2 = '| eval macro_two="2"' - namespace1 = client.namespace(app='search', sharing='app') - namespace2 = client.namespace(owner='admin', app='search', sharing='user') - new_macro2 = macros.create( - name, dispatch2, - namespace=namespace1) - new_macro1 = macros.create( - name, dispatch1, - namespace=namespace2) - - self.assertRaises(client.AmbiguousReferenceException, - macros.__getitem__, name) + namespace1 = client.namespace(app="search", sharing="app") + namespace2 = client.namespace(owner="admin", app="search", sharing="user") + new_macro2 = macros.create(name, dispatch2, namespace=namespace1) + new_macro1 = macros.create(name, dispatch1, namespace=namespace2) + + self.assertRaises(client.AmbiguousReferenceException, macros.__getitem__, name) macro1 = macros[name, namespace1] self.check_macro(macro1) - macro1.update(**{'definition': '| eval number=1'}) + macro1.update(**{"definition": "| eval number=1"}) macro1.refresh() - self.assertEqual(macro1['definition'], '| eval number=1') + self.assertEqual(macro1["definition"], "| eval number=1") macro2 = macros[name, namespace2] - macro2.update(**{'definition': '| eval number=2'}) + macro2.update(**{"definition": "| eval number=2"}) macro2.refresh() - self.assertEqual(macro2['definition'], '| eval number=2') + self.assertEqual(macro2["definition"], "| eval number=2") self.check_macro(macro2) def test_no_equality(self): - self.assertRaises(client.IncomparableException, - self.macro.__eq__, self.macro) + self.assertRaises(client.IncomparableException, self.macro.__eq__, self.macro) def test_acl(self): self.assertEqual(self.macro.access["perms"], None) - self.macro.acl_update(sharing="app", owner="admin", **{"perms.read": "admin, nobody"}) + self.macro.acl_update( + sharing="app", owner="admin", **{"perms.read": "admin, nobody"} + ) self.assertEqual(self.macro.access["owner"], "admin") self.assertEqual(self.macro.access["sharing"], "app") - self.assertEqual(self.macro.access["perms"]["read"], ['admin', 'nobody']) + 
self.assertEqual(self.macro.access["perms"]["read"], ["admin", "nobody"]) def test_acl_fails_without_sharing(self): self.assertRaisesRegex( ValueError, "Required argument 'sharing' is missing.", self.macro.acl_update, - owner="admin", app="search", **{"perms.read": "admin, nobody"} + owner="admin", + app="search", + **{"perms.read": "admin, nobody"}, ) def test_acl_fails_without_owner(self): @@ -152,7 +147,9 @@ def test_acl_fails_without_owner(self): ValueError, "Required argument 'owner' is missing.", self.macro.acl_update, - sharing="app", app="search", **{"perms.read": "admin, nobody"} + sharing="app", + app="search", + **{"perms.read": "admin, nobody"}, ) diff --git a/tests/test_message.py b/tests/test_message.py index 29f6a8694..b4026a00e 100755 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -24,8 +24,8 @@ def setUp(self): testlib.SDKTestCase.setUp(self) self.message_name = testlib.tmpname() self.message = self.service.messages.create( - self.message_name, - value='Test message created by the SDK') + self.message_name, value="Test message created by the SDK" + ) def tearDown(self): testlib.SDKTestCase.tearDown(self) @@ -35,9 +35,8 @@ def tearDown(self): class TestCreateDelete(testlib.SDKTestCase): def test_create_delete(self): message_name = testlib.tmpname() - message_value = 'Test message' - message = self.service.messages.create( - message_name, value=message_value) + message_value = "Test message" + message = self.service.messages.create(message_name, value=message_value) self.assertTrue(message_name in self.service.messages) self.assertEqual(message.value, message_value) self.check_entity(message) @@ -45,9 +44,24 @@ def test_create_delete(self): self.assertFalse(message_name in self.service.messages) def test_invalid_name(self): - self.assertRaises(client.InvalidNameException, self.service.messages.create, None, value="What?") - self.assertRaises(client.InvalidNameException, self.service.messages.create, 42, value="Who, me?") - self.assertRaises(client.InvalidNameException, self.service.messages.create, [1, 2, 3], value="Who, me?") + self.assertRaises( + client.InvalidNameException, + self.service.messages.create, + None, + value="What?", + ) + self.assertRaises( + client.InvalidNameException, + self.service.messages.create, + 42, + value="Who, me?", + ) + self.assertRaises( + client.InvalidNameException, + self.service.messages.create, + [1, 2, 3], + value="Who, me?", + ) if __name__ == "__main__": diff --git a/tests/test_modular_input.py b/tests/test_modular_input.py index 50a49d230..5e92359b7 100755 --- a/tests/test_modular_input.py +++ b/tests/test_modular_input.py @@ -32,22 +32,22 @@ def test_lists_modular_inputs(self): self.uncheckedRestartSplunk() inputs = self.service.inputs - if ('abcd', 'test2') not in inputs: - inputs.create('abcd', 'test2', field1='boris') + if ("abcd", "test2") not in inputs: + inputs.create("abcd", "test2", field1="boris") - input = inputs['abcd', 'test2'] - self.assertEqual(input.field1, 'boris') + input = inputs["abcd", "test2"] + self.assertEqual(input.field1, "boris") for m in self.service.modular_input_kinds: self.check_modular_input_kind(m) def check_modular_input_kind(self, m): print(m.name) - if m.name == 'test1': - self.assertEqual('Test "Input" - 1', m['title']) - self.assertEqual("xml", m['streaming_mode']) - elif m.name == 'test2': - self.assertEqual('test2', m['title']) - self.assertEqual('simple', m['streaming_mode']) + if m.name == "test1": + self.assertEqual('Test "Input" - 1', m["title"]) + self.assertEqual("xml", 
m["streaming_mode"]) + elif m.name == "test2": + self.assertEqual("test2", m["title"]) + self.assertEqual("simple", m["streaming_mode"]) if __name__ == "__main__": diff --git a/tests/test_modular_input_kinds.py b/tests/test_modular_input_kinds.py index 1304f269f..5f33d2a08 100755 --- a/tests/test_modular_input_kinds.py +++ b/tests/test_modular_input_kinds.py @@ -34,11 +34,22 @@ def test_list_arguments(self): # Not implemented before 5.0 return - test1 = self.service.modular_input_kinds['test1'] - - expected_args = {"name", "resname", "key_id", "no_description", "empty_description", "arg_required_on_edit", - "not_required_on_edit", "required_on_create", "not_required_on_create", "number_field", - "string_field", "boolean_field"} + test1 = self.service.modular_input_kinds["test1"] + + expected_args = { + "name", + "resname", + "key_id", + "no_description", + "empty_description", + "arg_required_on_edit", + "not_required_on_edit", + "required_on_create", + "not_required_on_create", + "number_field", + "string_field", + "boolean_field", + } found_args = set(test1.arguments.keys()) self.assertEqual(expected_args, found_args) @@ -51,16 +62,16 @@ def test_update_raises_exception(self): # Not implemented before 5.0 return - test1 = self.service.modular_input_kinds['test1'] + test1 = self.service.modular_input_kinds["test1"] self.assertRaises(client.IllegalOperationException, test1.update, a="b") def check_modular_input_kind(self, m): - if m.name == 'test1': - self.assertEqual('Test "Input" - 1', m['title']) - self.assertEqual("xml", m['streaming_mode']) - elif m.name == 'test2': - self.assertEqual('test2', m['title']) - self.assertEqual('simple', m['streaming_mode']) + if m.name == "test1": + self.assertEqual('Test "Input" - 1', m["title"]) + self.assertEqual("xml", m["streaming_mode"]) + elif m.name == "test2": + self.assertEqual("test2", m["title"]) + self.assertEqual("simple", m["streaming_mode"]) @pytest.mark.app def test_list_modular_inputs(self): diff --git a/tests/test_results.py b/tests/test_results.py index bde1c4ab4..5e82cb676 100755 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -27,7 +27,16 @@ def test_read_from_empty_result_set(self): job = self.service.jobs.create("search index=_internal_does_not_exist | head 2") while not job.is_done(): sleep(0.5) - self.assertEqual(0, len(list(results.JSONResultsReader(io.BufferedReader(job.results(output_mode='json')))))) + self.assertEqual( + 0, + len( + list( + results.JSONResultsReader( + io.BufferedReader(job.results(output_mode="json")) + ) + ) + ), + ) def test_read_normal_results(self): xml_text = """ @@ -86,27 +95,30 @@ def test_read_normal_results(self): """.strip() expected_results = [ - results.Message('DEBUG', 'base lispy: [ AND ]'), - results.Message('DEBUG', "search context: user='admin', app='search', bs-pathname='/some/path'"), + results.Message("DEBUG", "base lispy: [ AND ]"), + results.Message( + "DEBUG", + "search context: user='admin', app='search', bs-pathname='/some/path'", + ), { - 'series': 'twitter', - 'sum(kb)': '14372242.758775', + "series": "twitter", + "sum(kb)": "14372242.758775", }, { - 'series': 'splunkd', - 'sum(kb)': '267802.333926', + "series": "splunkd", + "sum(kb)": "267802.333926", }, { - 'series': 'flurry', - 'sum(kb)': '12576.454102', + "series": "flurry", + "sum(kb)": "12576.454102", }, { - 'series': 'splunkd_access', - 'sum(kb)': '5979.036338', + "series": "splunkd_access", + "sum(kb)": "5979.036338", }, { - 'series': 'splunk_web_access', - 'sum(kb)': '5838.935649', + "series": 
"splunk_web_access", + "sum(kb)": "5838.935649", }, ] @@ -128,7 +140,7 @@ def test_read_raw_field(self): """.strip() expected_results = [ { - '_raw': '07-13-2012 09:27:27.307 -0700 INFO Metrics - group=search_concurrency, system total, active_hist_searches=0, active_realtime_searches=0', + "_raw": "07-13-2012 09:27:27.307 -0700 INFO Metrics - group=search_concurrency, system total, active_hist_searches=0, active_realtime_searches=0", }, ] @@ -150,14 +162,14 @@ def test_read_raw_field_with_segmentation(self): """.strip() expected_results = [ { - '_raw': '07-13-2012 09:27:27.307 -0700 INFO Metrics - group=search_concurrency, system total, active_hist_searches=0, active_realtime_searches=0', + "_raw": "07-13-2012 09:27:27.307 -0700 INFO Metrics - group=search_concurrency, system total, active_hist_searches=0, active_realtime_searches=0", }, ] self.assert_parsed_results_equals(xml_text, expected_results) def assert_parsed_results_equals(self, xml_text, expected_results): - results_reader = results.ResultsReader(BytesIO(xml_text.encode('utf-8'))) + results_reader = results.ResultsReader(BytesIO(xml_text.encode("utf-8"))) actual_results = list(results_reader) self.assertEqual(expected_results, actual_results) diff --git a/tests/test_role.py b/tests/test_role.py index 2c087603f..768787204 100755 --- a/tests/test_role.py +++ b/tests/test_role.py @@ -29,7 +29,7 @@ def setUp(self): def tearDown(self): super().tearDown() for role in self.service.roles: - if role.name.startswith('delete-me'): + if role.name.startswith("delete-me"): self.service.roles.delete(role.name) def check_role(self, role): @@ -61,48 +61,52 @@ def test_delete(self): self.assertRaises(client.HTTPError, self.role.refresh) def test_grant_and_revoke(self): - self.assertFalse('edit_user' in self.role.capabilities) - self.role.grant('edit_user') + self.assertFalse("edit_user" in self.role.capabilities) + self.role.grant("edit_user") self.role.refresh() - self.assertTrue('edit_user' in self.role.capabilities) + self.assertTrue("edit_user" in self.role.capabilities) - self.assertFalse('change_own_password' in self.role.capabilities) - self.role.grant('change_own_password') + self.assertFalse("change_own_password" in self.role.capabilities) + self.role.grant("change_own_password") self.role.refresh() - self.assertTrue('edit_user' in self.role.capabilities) - self.assertTrue('change_own_password' in self.role.capabilities) + self.assertTrue("edit_user" in self.role.capabilities) + self.assertTrue("change_own_password" in self.role.capabilities) - self.role.revoke('edit_user') + self.role.revoke("edit_user") self.role.refresh() - self.assertFalse('edit_user' in self.role.capabilities) - self.assertTrue('change_own_password' in self.role.capabilities) + self.assertFalse("edit_user" in self.role.capabilities) + self.assertTrue("change_own_password" in self.role.capabilities) - self.role.revoke('change_own_password') + self.role.revoke("change_own_password") self.role.refresh() - self.assertFalse('edit_user' in self.role.capabilities) - self.assertFalse('change_own_password' in self.role.capabilities) + self.assertFalse("edit_user" in self.role.capabilities) + self.assertFalse("change_own_password" in self.role.capabilities) def test_invalid_grant(self): - self.assertRaises(client.NoSuchCapability, self.role.grant, 'i-am-an-invalid-capability') + self.assertRaises( + client.NoSuchCapability, self.role.grant, "i-am-an-invalid-capability" + ) def test_invalid_revoke(self): - self.assertRaises(client.NoSuchCapability, self.role.revoke, 
'i-am-an-invalid-capability') + self.assertRaises( + client.NoSuchCapability, self.role.revoke, "i-am-an-invalid-capability" + ) def test_revoke_capability_not_granted(self): - self.role.revoke('change_own_password') + self.role.revoke("change_own_password") def test_update(self): kwargs = {} - if 'user' in self.role['imported_roles']: - kwargs['imported_roles'] = '' + if "user" in self.role["imported_roles"]: + kwargs["imported_roles"] = "" else: - kwargs['imported_roles'] = ['user'] - if self.role['srchJobsQuota'] is not None: - kwargs['srchJobsQuota'] = int(self.role['srchJobsQuota']) + 1 + kwargs["imported_roles"] = ["user"] + if self.role["srchJobsQuota"] is not None: + kwargs["srchJobsQuota"] = int(self.role["srchJobsQuota"]) + 1 self.role.update(**kwargs) self.role.refresh() - self.assertEqual(self.role['imported_roles'], kwargs['imported_roles']) - self.assertEqual(int(self.role['srchJobsQuota']), kwargs['srchJobsQuota']) + self.assertEqual(self.role["imported_roles"], kwargs["imported_roles"]) + self.assertEqual(int(self.role["srchJobsQuota"]), kwargs["srchJobsQuota"]) if __name__ == "__main__": diff --git a/tests/test_saved_search.py b/tests/test_saved_search.py index a78d9420c..39d3c6517 100755 --- a/tests/test_saved_search.py +++ b/tests/test_saved_search.py @@ -24,7 +24,6 @@ from splunklib import client - @pytest.mark.smoke class TestSavedSearch(testlib.SDKTestCase): def setUp(self): @@ -38,7 +37,7 @@ def setUp(self): def tearDown(self): super().setUp() for saved_search in self.service.saved_searches: - if saved_search.name.startswith('delete-me'): + if saved_search.name.startswith("delete-me"): try: for job in saved_search.history(): job.cancel() @@ -48,36 +47,38 @@ def tearDown(self): def check_saved_search(self, saved_search): self.check_entity(saved_search) - expected_fields = ['alert.expires', - 'alert.severity', - 'alert.track', - 'alert_type', - 'dispatch.buckets', - 'dispatch.lookups', - 'dispatch.max_count', - 'dispatch.max_time', - 'dispatch.reduce_freq', - 'dispatch.spawn_process', - 'dispatch.time_format', - 'dispatch.ttl', - 'max_concurrent', - 'realtime_schedule', - 'restart_on_searchpeer_add', - 'run_on_startup', - 'search', - 'action.email', - 'action.populate_lookup', - 'action.rss', - 'action.script', - 'action.summary_index'] + expected_fields = [ + "alert.expires", + "alert.severity", + "alert.track", + "alert_type", + "dispatch.buckets", + "dispatch.lookups", + "dispatch.max_count", + "dispatch.max_time", + "dispatch.reduce_freq", + "dispatch.spawn_process", + "dispatch.time_format", + "dispatch.ttl", + "max_concurrent", + "realtime_schedule", + "restart_on_searchpeer_add", + "run_on_startup", + "search", + "action.email", + "action.populate_lookup", + "action.rss", + "action.script", + "action.summary_index", + ] for f in expected_fields: saved_search[f] self.assertGreaterEqual(saved_search.suppressed, 0) - self.assertGreaterEqual(saved_search['suppressed'], 0) - is_scheduled = saved_search.content['is_scheduled'] - self.assertTrue(is_scheduled in ('1', '0')) - is_visible = saved_search.content['is_visible'] - self.assertTrue(is_visible in ('1', '0')) + self.assertGreaterEqual(saved_search["suppressed"], 0) + is_scheduled = saved_search.content["is_scheduled"] + self.assertTrue(is_scheduled in ("1", "0")) + is_visible = saved_search.content["is_visible"] + self.assertTrue(is_visible in ("1", "0")) def test_create(self): self.assertTrue(self.saved_search_name in self.service.saved_searches) @@ -87,52 +88,51 @@ def test_delete(self): 
self.assertTrue(self.saved_search_name in self.service.saved_searches) self.service.saved_searches.delete(self.saved_search_name) self.assertFalse(self.saved_search_name in self.service.saved_searches) - self.assertRaises(client.HTTPError, - self.saved_search.refresh) + self.assertRaises(client.HTTPError, self.saved_search.refresh) def test_update(self): - is_visible = testlib.to_bool(self.saved_search['is_visible']) + is_visible = testlib.to_bool(self.saved_search["is_visible"]) self.saved_search.update(is_visible=not is_visible) self.saved_search.refresh() - self.assertEqual(testlib.to_bool(self.saved_search['is_visible']), not is_visible) + self.assertEqual( + testlib.to_bool(self.saved_search["is_visible"]), not is_visible + ) def test_cannot_update_name(self): - new_name = self.saved_search_name + '-alteration' - self.assertRaises(client.IllegalOperationException, - self.saved_search.update, name=new_name) + new_name = self.saved_search_name + "-alteration" + self.assertRaises( + client.IllegalOperationException, self.saved_search.update, name=new_name + ) def test_name_collision(self): opts = self.opts.kwargs.copy() - opts['owner'] = '-' - opts['app'] = '-' - opts['sharing'] = 'user' + opts["owner"] = "-" + opts["app"] = "-" + opts["sharing"] = "user" service = client.connect(**opts) logging.debug("Namespace for collision testing: %s", service.namespace) saved_searches = service.saved_searches name = testlib.tmpname() - query1 = '* earliest=-1m | head 1' - query2 = '* earliest=-2m | head 2' - namespace1 = client.namespace(app='search', sharing='app') - namespace2 = client.namespace(owner='admin', app='search', sharing='user') - saved_search2 = saved_searches.create( - name, query2, - namespace=namespace1) - saved_search1 = saved_searches.create( - name, query1, - namespace=namespace2) - - self.assertRaises(client.AmbiguousReferenceException, - saved_searches.__getitem__, name) + query1 = "* earliest=-1m | head 1" + query2 = "* earliest=-2m | head 2" + namespace1 = client.namespace(app="search", sharing="app") + namespace2 = client.namespace(owner="admin", app="search", sharing="user") + saved_search2 = saved_searches.create(name, query2, namespace=namespace1) + saved_search1 = saved_searches.create(name, query1, namespace=namespace2) + + self.assertRaises( + client.AmbiguousReferenceException, saved_searches.__getitem__, name + ) search1 = saved_searches[name, namespace1] self.check_saved_search(search1) - search1.update(**{'action.email.from': 'nobody@nowhere.com'}) + search1.update(**{"action.email.from": "nobody@nowhere.com"}) search1.refresh() - self.assertEqual(search1['action.email.from'], 'nobody@nowhere.com') + self.assertEqual(search1["action.email.from"], "nobody@nowhere.com") search2 = saved_searches[name, namespace2] - search2.update(**{'action.email.from': 'nemo@utopia.com'}) + search2.update(**{"action.email.from": "nemo@utopia.com"}) search2.refresh() - self.assertEqual(search2['action.email.from'], 'nemo@utopia.com') + self.assertEqual(search2["action.email.from"], "nemo@utopia.com") self.check_saved_search(search2) def test_dispatch(self): @@ -146,7 +146,7 @@ def test_dispatch(self): def test_dispatch_with_options(self): try: - kwargs = {'dispatch.buckets': 100} + kwargs = {"dispatch.buckets": 100} job = self.saved_search.dispatch(**kwargs) while not job.is_ready(): sleep(0.1) @@ -190,46 +190,53 @@ def test_history_with_options(self): job.cancel() def test_scheduled_times(self): - self.saved_search.update(cron_schedule='*/5 * * * *', is_scheduled=True) + 
self.saved_search.update(cron_schedule="*/5 * * * *", is_scheduled=True) scheduled_times = self.saved_search.scheduled_times() logging.debug("Scheduled times: %s", scheduled_times) - self.assertTrue(all([isinstance(x, datetime.datetime) - for x in scheduled_times])) + self.assertTrue( + all([isinstance(x, datetime.datetime) for x in scheduled_times]) + ) time_pairs = list(zip(scheduled_times[:-1], scheduled_times[1:])) for earlier, later in time_pairs: diff = later - earlier self.assertEqual(diff.total_seconds() / 60.0, 5) def test_no_equality(self): - self.assertRaises(client.IncomparableException, - self.saved_search.__eq__, self.saved_search) + self.assertRaises( + client.IncomparableException, self.saved_search.__eq__, self.saved_search + ) def test_suppress(self): - suppressed_time = self.saved_search['suppressed'] + suppressed_time = self.saved_search["suppressed"] self.assertGreaterEqual(suppressed_time, 0) new_suppressed_time = suppressed_time + 100 self.saved_search.suppress(new_suppressed_time) - self.assertLessEqual(self.saved_search['suppressed'], - new_suppressed_time) - self.assertGreater(self.saved_search['suppressed'], - suppressed_time) + self.assertLessEqual(self.saved_search["suppressed"], new_suppressed_time) + self.assertGreater(self.saved_search["suppressed"], suppressed_time) self.saved_search.unsuppress() - self.assertEqual(self.saved_search['suppressed'], 0) + self.assertEqual(self.saved_search["suppressed"], 0) def test_acl(self): self.assertEqual(self.saved_search.access["perms"], None) - self.saved_search.acl_update(sharing="app", owner="admin", app="search", **{"perms.read": "admin, nobody"}) + self.saved_search.acl_update( + sharing="app", + owner="admin", + app="search", + **{"perms.read": "admin, nobody"}, + ) self.assertEqual(self.saved_search.access["owner"], "admin") self.assertEqual(self.saved_search.access["app"], "search") self.assertEqual(self.saved_search.access["sharing"], "app") - self.assertEqual(self.saved_search.access["perms"]["read"], ['admin', 'nobody']) + self.assertEqual(self.saved_search.access["perms"]["read"], ["admin", "nobody"]) def test_acl_fails_without_sharing(self): self.assertRaisesRegex( ValueError, "Required argument 'sharing' is missing.", self.saved_search.acl_update, - owner="admin", app="search", **{"perms.read": "admin, nobody"} + owner="admin", + app="search", + **{"perms.read": "admin, nobody"}, ) def test_acl_fails_without_owner(self): @@ -237,9 +244,12 @@ def test_acl_fails_without_owner(self): ValueError, "Required argument 'owner' is missing.", self.saved_search.acl_update, - sharing="app", app="search", **{"perms.read": "admin, nobody"} + sharing="app", + app="search", + **{"perms.read": "admin, nobody"}, ) + if __name__ == "__main__": import unittest diff --git a/tests/test_service.py b/tests/test_service.py index 6433b56b5..971764c0e 100755 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -25,7 +25,6 @@ class ServiceTestCase(testlib.SDKTestCase): - def test_autologin(self): service = client.connect(autologin=True, **self.opts.kwargs) self.service.restart(timeout=120) @@ -36,13 +35,29 @@ def test_capabilities(self): capabilities = self.service.capabilities self.assertTrue(isinstance(capabilities, list)) self.assertTrue(all([isinstance(c, str) for c in capabilities])) - self.assertTrue('change_own_password' in capabilities) # This should always be there... + self.assertTrue( + "change_own_password" in capabilities + ) # This should always be there... 
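    # service.info reads the server's /services/server/info endpoint and
    # returns a dict-like record; the keys asserted below are standard across
    # supported Splunk versions. A minimal sketch, assuming a connected
    # `service`:
    #
    #     info = service.info
    #     print(info["serverName"], info["version"], info["os_name"])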
     def test_info(self):
         info = self.service.info
-        keys = ["build", "cpu_arch", "guid", "isFree", "isTrial", "licenseKeys",
-                "licenseSignature", "licenseState", "master_guid", "mode",
-                "os_build", "os_name", "os_version", "serverName", "version"]
+        keys = [
+            "build",
+            "cpu_arch",
+            "guid",
+            "isFree",
+            "isTrial",
+            "licenseKeys",
+            "licenseSignature",
+            "licenseState",
+            "master_guid",
+            "mode",
+            "os_build",
+            "os_name",
+            "os_version",
+            "serverName",
+            "version",
+        ]
         for key in keys:
             self.assertTrue(key in list(info.keys()))

@@ -55,7 +70,7 @@ def test_info_with_namespace(self):
         self.service.namespace["owner"] = self.service.username
         self.service.namespace["app"] = "search"
         try:
-            self.assertEqual(self.service.info.licenseState, 'OK')
+            self.assertEqual(self.service.info.licenseState, "OK")
         except HTTPError as he:
             self.fail(f"Couldn't get the server info, probably got a 403! {he.message}")

@@ -68,31 +83,31 @@ def test_without_namespace(self):

     def test_app_namespace(self):
         kwargs = self.opts.kwargs.copy()
-        kwargs.update({'app': "search", 'owner': None})
+        kwargs.update({"app": "search", "owner": None})
         service_ns = client.connect(**kwargs)
         service_ns.apps.list()

     def test_owner_wildcard(self):
         kwargs = self.opts.kwargs.copy()
-        kwargs.update({'app': "search", 'owner': "-"})
+        kwargs.update({"app": "search", "owner": "-"})
         service_ns = client.connect(**kwargs)
         service_ns.apps.list()

     def test_default_app(self):
         kwargs = self.opts.kwargs.copy()
-        kwargs.update({'app': None, 'owner': "admin"})
+        kwargs.update({"app": None, "owner": "admin"})
         service_ns = client.connect(**kwargs)
         service_ns.apps.list()

     def test_app_wildcard(self):
         kwargs = self.opts.kwargs.copy()
-        kwargs.update({'app': "-", 'owner': "admin"})
+        kwargs.update({"app": "-", "owner": "admin"})
         service_ns = client.connect(**kwargs)
         service_ns.apps.list()

     def test_user_namespace(self):
         kwargs = self.opts.kwargs.copy()
-        kwargs.update({'app': "search", 'owner': "admin"})
+        kwargs.update({"app": "search", "owner": "admin"})
         service_ns = client.connect(**kwargs)
         service_ns.apps.list()

@@ -107,7 +122,7 @@ def test_parse(self):

     def test_parse_fail(self):
         try:
             self.service.parse("xyzzy")
-            self.fail('Parse on nonsense did not fail')
+            self.fail("Parse on nonsense did not fail")
         except HTTPError as e:
             self.assertEqual(e.status, 400)

@@ -119,14 +134,14 @@ def test_restart(self):

     def test_read_outputs_with_type(self):
         name = testlib.tmpname()
         service = client.connect(**self.opts.kwargs)
-        service.post('data/outputs/tcp/syslog', name=name, type='tcp')
-        entity = client.Entity(service, 'data/outputs/tcp/syslog/' + name)
-        self.assertTrue('tcp', entity.content.type)
+        service.post("data/outputs/tcp/syslog", name=name, type="tcp")
+        entity = client.Entity(service, "data/outputs/tcp/syslog/" + name)
+        # assertTrue("tcp", ...) treated "tcp" as the assertion and always passed.
+        self.assertEqual("tcp", entity.content.type)

         if service.restart_required:
             self.restartSplunk()
         service = client.connect(**self.opts.kwargs)
-        client.Entity(service, 'data/outputs/tcp/syslog/' + name).delete()
+        client.Entity(service, "data/outputs/tcp/syslog/" + name).delete()
         if service.restart_required:
             self.restartSplunk()

@@ -152,7 +167,7 @@ def test_query_without_login_raises_http_401(self):
         service = self._create_unauthenticated_service()
         try:
             service.indexes.list()
-            self.fail('Expected HTTP 401.')
+            self.fail("Expected HTTP 401.")
         except HTTPError as he:
             if he.status == 401:
                 # Good
@@ -161,19 +176,28 @@ def test_query_without_login_raises_http_401(self):
                 raise

     def _create_unauthenticated_service(self):
-        return Service(**{
-            'host': self.opts.kwargs['host'],
-            'port': self.opts.kwargs['port'],
-            'scheme': self.opts.kwargs['scheme']
-        })
+        return Service(
+            **{
+                "host": self.opts.kwargs["host"],
+                "port": self.opts.kwargs["port"],
+                "scheme": self.opts.kwargs["scheme"],
+            }
+        )

     # To check the HEC event endpoint using Endpoint instance
     @pytest.mark.smoke
     def test_hec_event(self):
         import json
-        service_hec = client.connect(host='localhost', scheme='https', port=8088,
-                                     token="11111111-1111-1111-1111-1111111111113")
-        event_collector_endpoint = client.Endpoint(service_hec, "/services/collector/event")
+
+        service_hec = client.connect(
+            host="localhost",
+            scheme="https",
+            port=8088,
+            token="11111111-1111-1111-1111-1111111111113",
+        )
+        event_collector_endpoint = client.Endpoint(
+            service_hec, "/services/collector/event"
+        )
         msg = {"index": "main", "event": "Hello World"}
         response = event_collector_endpoint.post("", body=json.dumps(msg))
         self.assertEqual(response.status, 200)
@@ -184,11 +208,11 @@ def setUp(self):
         self.opts = testlib.parse([], {}, ".env")
         self.service = client.Service(**self.opts.kwargs)

-    if getattr(unittest.TestCase, 'assertIsNotNone', None) is None:
+    if getattr(unittest.TestCase, "assertIsNotNone", None) is None:
+
         def assertIsNotNone(self, obj, msg=None):
             if obj is None:
-                raise self.failureException(msg or f'{obj} is not None')
+                raise self.failureException(msg or f"{obj} is not None")

     def test_login_and_store_cookie(self):
         self.assertIsNotNone(self.service.get_cookies())
@@ -202,7 +226,9 @@ def test_login_with_cookie(self):
         self.service.login()
         self.assertIsNotNone(self.service.get_cookies())
         # Use the cookie from the other service as the only auth param (don't need user/password)
-        service2 = client.Service(**{"cookie": "%s=%s" % list(self.service.get_cookies().items())[0]})
+        service2 = client.Service(
+            **{"cookie": "%s=%s" % list(self.service.get_cookies().items())[0]}
+        )
         service2.login()
         self.assertEqual(len(service2.get_cookies()), 1)
         self.assertEqual(service2.get_cookies(), self.service.get_cookies())
@@ -211,11 +237,11 @@ def test_login_with_cookie(self):
         self.assertEqual(service2.apps.get().status, 200)

     def test_login_fails_with_bad_cookie(self):
-        bad_cookie = {'bad': 'cookie'}
+        bad_cookie = {"bad": "cookie"}
         service2 = client.Service()
         self.assertEqual(len(service2.get_cookies()), 0)
         service2.get_cookies().update(bad_cookie)
-        self.assertEqual(service2.get_cookies(), {'bad': 'cookie'})
+        self.assertEqual(service2.get_cookies(), {"bad": "cookie"})

         # Should get an error with a bad cookie
         try:
@@ -230,7 +256,8 @@ def test_autologin_with_cookie(self):
         service = client.connect(
             autologin=True,
             cookie="%s=%s" % list(self.service.get_cookies().items())[0],
-            **self.opts.kwargs)
+            **self.opts.kwargs,
+        )
         self.assertTrue(service.has_cookies())
         self.service.restart(timeout=120)
         reader = service.jobs.oneshot("search index=internal | head 1")
@@ -248,10 +275,7 @@ def test_login_fails_with_no_cookie(self):
         self.assertEqual(str(ae), "Login failed.")

     def test_login_with_multiple_cookie_headers(self):
-        cookies = {
-            'bad': 'cookie',
-            'something_else': 'bad'
-        }
+        cookies = {"bad": "cookie", "something_else": "bad"}
         self.service.logout()
         self.service.get_cookies().update(cookies)
@@ -259,7 +283,7 @@ def test_login_with_multiple_cookie_headers(self):
         self.assertEqual(self.service.apps.get().status, 200)

     def test_login_with_multiple_cookies(self):
-        bad_cookie = 'bad=cookie'
+        bad_cookie = "bad=cookie"
         self.service.login()
         self.assertIsNotNone(self.service.get_cookies())

@@ -276,13 +300,17 @@ def test_login_with_multiple_cookies(self):
         service2.get_cookies().update(self.service.get_cookies())
         self.assertEqual(len(service2.get_cookies()), 2)

-        self.service.get_cookies().update({'bad': 'cookie'})
+        self.service.get_cookies().update({"bad": "cookie"})
         self.assertEqual(service2.get_cookies(), self.service.get_cookies())
         self.assertEqual(len(service2.get_cookies()), 2)
-        self.assertTrue([cookie for cookie in service2.get_cookies() if "splunkd_" in cookie])
-        self.assertTrue('bad' in service2.get_cookies())
-        self.assertEqual(service2.get_cookies()['bad'], 'cookie')
-        self.assertEqual(set(self.service.get_cookies()), set(service2.get_cookies()))
+        self.assertTrue(
+            [cookie for cookie in service2.get_cookies() if "splunkd_" in cookie]
+        )
+        self.assertTrue("bad" in service2.get_cookies())
+        self.assertEqual(service2.get_cookies()["bad"], "cookie")
+        self.assertEqual(
+            set(self.service.get_cookies()), set(service2.get_cookies())
+        )

         service2.login()
         self.assertEqual(service2.apps.get().status, 200)
@@ -292,9 +320,18 @@ def test_read_settings(self):
         settings = self.service.settings
         # Verify that settings contains the keys we expect
         keys = [
-            "SPLUNK_DB", "SPLUNK_HOME", "enableSplunkWebSSL", "host",
-            "httpport", "mgmtHostPort", "minFreeSpace", "pass4SymmKey",
-            "serverName", "sessionTimeout", "startwebserver", "trustedIP"
+            "SPLUNK_DB",
+            "SPLUNK_HOME",
+            "enableSplunkWebSSL",
+            "host",
+            "httpport",
+            "mgmtHostPort",
+            "minFreeSpace",
+            "pass4SymmKey",
+            "serverName",
+            "sessionTimeout",
+            "startwebserver",
+            "trustedIP",
         ]
         for key in keys:
             self.assertTrue(key in settings)
@@ -302,68 +339,78 @@ def test_update_settings(self):
         settings = self.service.settings
         # Verify that we can update the settings
-        original = settings['sessionTimeout']
+        original = settings["sessionTimeout"]
         self.assertTrue(original != "42h")
         settings.update(sessionTimeout="42h")
         settings.refresh()
-        updated = settings['sessionTimeout']
+        updated = settings["sessionTimeout"]
         self.assertEqual(updated, "42h")

         # Restore (and verify) original value
         settings.update(sessionTimeout=original)
         settings.refresh()
-        updated = settings['sessionTimeout']
+        updated = settings["sessionTimeout"]
         self.assertEqual(updated, original)
         self.restartSplunk()


 class TestTrailing(unittest.TestCase):
-    template = '/servicesNS/boris/search/another/path/segment/that runs on'
+    template = "/servicesNS/boris/search/another/path/segment/that runs on"

     def test_raises_when_not_found_first(self):
-        self.assertRaises(ValueError, client._trailing, 'this is a test', 'boris')
+        self.assertRaises(ValueError, client._trailing, "this is a test", "boris")

     def test_raises_when_not_found_second(self):
-        self.assertRaises(ValueError, client._trailing, 'this is a test', 's is', 'boris')
+        self.assertRaises(
+            ValueError, client._trailing, "this is a test", "s is", "boris"
+        )

     def test_no_args_is_identity(self):
         self.assertEqual(self.template, client._trailing(self.template))

     def test_trailing_with_one_arg_works(self):
-        self.assertEqual('boris/search/another/path/segment/that runs on',
-                         client._trailing(self.template, 'ervicesNS/'))
+        self.assertEqual(
+            "boris/search/another/path/segment/that runs on",
+            client._trailing(self.template, "ervicesNS/"),
+        )

     def test_trailing_with_n_args_works(self):
         self.assertEqual(
-            'another/path/segment/that runs on',
-            client._trailing(self.template, 'servicesNS/', '/', '/')
+            "another/path/segment/that runs on",
+            client._trailing(self.template, "servicesNS/", "/", "/"),
         )


 class TestEntityNamespacing(testlib.SDKTestCase):
     def test_proper_namespace_with_arguments(self):
-        entity = self.service.apps['search']
-        self.assertEqual((None, None, "global"), entity._proper_namespace(sharing="global"))
-        self.assertEqual((None, "search", "app"), entity._proper_namespace(sharing="app", app="search"))
+        entity = self.service.apps["search"]
+        self.assertEqual(
+            (None, None, "global"), entity._proper_namespace(sharing="global")
+        )
+        self.assertEqual(
+            (None, "search", "app"),
+            entity._proper_namespace(sharing="app", app="search"),
+        )
         self.assertEqual(
             ("admin", "search", "user"),
-            entity._proper_namespace(sharing="user", app="search", owner="admin")
+            entity._proper_namespace(sharing="user", app="search", owner="admin"),
         )

     def test_proper_namespace_with_entity_namespace(self):
-        entity = self.service.apps['search']
+        entity = self.service.apps["search"]
         namespace = (entity.access.owner, entity.access.app, entity.access.sharing)
         self.assertEqual(namespace, entity._proper_namespace())

     def test_proper_namespace_with_service_namespace(self):
         entity = client.Entity(self.service, client.PATH_APPS + "search")
-        del entity._state['access']
-        namespace = (self.service.namespace.owner,
-                     self.service.namespace.app,
-                     self.service.namespace.sharing)
+        del entity._state["access"]
+        namespace = (
+            self.service.namespace.owner,
+            self.service.namespace.app,
+            self.service.namespace.sharing,
+        )
         self.assertEqual(namespace, entity._proper_namespace())


 if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/test_storage_passwords.py b/tests/test_storage_passwords.py
index bda832dd9..972157835 100644
--- a/tests/test_storage_passwords.py
+++ b/tests/test_storage_passwords.py
@@ -100,14 +100,14 @@ def test_create_with_colons(self):
         username = testlib.tmpname()
         realm = testlib.tmpname()

-        p = self.storage_passwords.create("changeme", username + ":end",
-                                          ":start" + realm)
+        p = self.storage_passwords.create(
+            "changeme", username + ":end", ":start" + realm
+        )
         self.assertEqual(start_count + 1, len(self.storage_passwords))
         self.assertEqual(p.realm, ":start" + realm)
         self.assertEqual(p.username, username + ":end")
         # self.assertEqual(p.clear_password, "changeme")
-        self.assertEqual(p.name,
-                         "\\:start" + realm + ":" + username + "\\:end:")
+        self.assertEqual(p.name, "\\:start" + realm + ":" + username + "\\:end:")

         p.delete()
         self.assertEqual(start_count, len(self.storage_passwords))
@@ -120,8 +120,9 @@ def test_create_with_colons(self):
         self.assertEqual(p.realm, realm)
         self.assertEqual(p.username, user)
         # self.assertEqual(p.clear_password, "changeme")
-        self.assertEqual(p.name,
-                         prefix + "\\:r\\:e\\:a\\:l\\:m\\::\\:u\\:s\\:e\\:r\\::")
+        self.assertEqual(
+            p.name, prefix + "\\:r\\:e\\:a\\:l\\:m\\::\\:u\\:s\\:e\\:r\\::"
+        )

         p.delete()
         self.assertEqual(start_count, len(self.storage_passwords))
@@ -131,15 +132,23 @@ def test_create_crazy(self):
         username = testlib.tmpname()
         realm = testlib.tmpname()

-        p = self.storage_passwords.create("changeme",
-                                          username + ":end!@#$%^&*()_+{}:|<>?",
-                                          ":start::!@#$%^&*()_+{}:|<>?" + realm)
+        p = self.storage_passwords.create(
+            "changeme",
+            username + ":end!@#$%^&*()_+{}:|<>?",
+            ":start::!@#$%^&*()_+{}:|<>?" + realm,
+        )
         self.assertEqual(start_count + 1, len(self.storage_passwords))
         self.assertEqual(p.realm, ":start::!@#$%^&*()_+{}:|<>?" + realm)
         self.assertEqual(p.username, username + ":end!@#$%^&*()_+{}:|<>?")
         # self.assertEqual(p.clear_password, "changeme")
-        self.assertEqual(p.name,
-                         "\\:start\\:\\:!@#$%^&*()_+{}\\:|<>?" + realm + ":" + username + "\\:end!@#$%^&*()_+{}\\:|<>?:")
+        self.assertEqual(
+            p.name,
+            "\\:start\\:\\:!@#$%^&*()_+{}\\:|<>?"
+            + realm
+            + ":"
+            + username
+            + "\\:end!@#$%^&*()_+{}\\:|<>?:",
+        )

         p.delete()
         self.assertEqual(start_count, len(self.storage_passwords))
@@ -205,15 +214,17 @@ def test_delete(self):
         self.assertEqual(start_count, len(self.storage_passwords))

         # Test named parameters
-        self.storage_passwords.create(password="changeme", username=username,
-                                      realm="myrealm")
+        self.storage_passwords.create(
+            password="changeme", username=username, realm="myrealm"
+        )
         self.assertEqual(start_count + 1, len(self.storage_passwords))
         self.storage_passwords.delete(username, "myrealm")
         self.assertEqual(start_count, len(self.storage_passwords))

-        self.storage_passwords.create(password="changeme", username=username + "/foo",
-                                      realm="/myrealm")
+        self.storage_passwords.create(
+            password="changeme", username=username + "/foo", realm="/myrealm"
+        )
         self.assertEqual(start_count + 1, len(self.storage_passwords))
         self.storage_passwords.delete(username + "/foo", "/myrealm")
diff --git a/tests/test_user.py b/tests/test_user.py
index e20b96946..94f525290 100755
--- a/tests/test_user.py
+++ b/tests/test_user.py
@@ -23,20 +23,19 @@ class UserTestCase(testlib.SDKTestCase):
     def check_user(self, user):
         self.check_entity(user)
         # Verify expected fields exist
-        [user[f] for f in ['email', 'password', 'realname', 'roles']]
+        [user[f] for f in ["email", "password", "realname", "roles"]]

     def setUp(self):
         super().setUp()
         self.username = testlib.tmpname()
         self.user = self.service.users.create(
-            self.username,
-            password='changeme!',
-            roles=['power', 'user'])
+            self.username, password="changeme!", roles=["power", "user"]
+        )

     def tearDown(self):
         super().tearDown()
         for user in self.service.users:
-            if user.name.startswith('delete-me'):
+            if user.name.startswith("delete-me"):
                 self.service.users.delete(user.name)

     def test_read(self):
@@ -45,8 +44,7 @@ def test_read(self):
         for role in user.role_entities:
             self.assertTrue(isinstance(role, client.Entity))
             self.assertTrue(role.name in self.service.roles)
-        self.assertEqual(user.roles,
-                         [role.name for role in user.role_entities])
+        self.assertEqual(user.roles, [role.name for role in user.role_entities])

     def test_create(self):
         self.assertTrue(self.username in self.service.users)
@@ -59,10 +57,10 @@ def test_delete(self):
             self.user.refresh()

     def test_update(self):
-        self.assertTrue(self.user['email'] is None)
+        self.assertTrue(self.user["email"] is None)
         self.user.update(email="foo@bar.com")
         self.user.refresh()
-        self.assertTrue(self.user['email'] == "foo@bar.com")
+        self.assertTrue(self.user["email"] == "foo@bar.com")

     def test_in_is_case_insensitive(self):
         # Splunk lowercases user names, verify the casing works as expected
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 40ed53197..d8feafa5b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,11 +4,11 @@
 from utils import dslice

 TEST_DICT = {
-    'username': 'admin',
-    'password': 'changeme',
-    'port': 8089,
-    'host': 'localhost',
-    'scheme': 'https'
+    "username": "admin",
+    "password": "changeme",
+    "port": 8089,
+    "host": "localhost",
+    "scheme": "https",
 }


@@ -19,80 +19,65 @@ def setUp(self):

     # Test dslice when a dict is passed to change key names
     def test_dslice_dict_args(self):
         args = {
-            'username': 'user-name',
-            'password': 'new_password',
-            'port': 'admin_port',
-            'foo': 'bar'
+            "username": "user-name",
+            "password": "new_password",
+            "port": "admin_port",
+            "foo": "bar",
         }
         expected = {
-            'user-name': 'admin',
-            'new_password': 'changeme',
-            'admin_port': 8089
+            "user-name": "admin",
+            "new_password": "changeme",
+            "admin_port": 8089,
         }
         self.assertTrue(expected == dslice(TEST_DICT, args))

     # Test dslice when a list is passed
     def test_dslice_list_args(self):
-        test_list = [
-            'username',
-            'password',
-            'port',
-            'host',
-            'foo'
-        ]
+        test_list = ["username", "password", "port", "host", "foo"]
         expected = {
-            'username': 'admin',
-            'password': 'changeme',
-            'port': 8089,
-            'host': 'localhost'
+            "username": "admin",
+            "password": "changeme",
+            "port": 8089,
+            "host": "localhost",
         }
         self.assertTrue(expected == dslice(TEST_DICT, test_list))

     # Test dslice when a single string is passed
     def test_dslice_arg(self):
-        test_arg = 'username'
-        expected = {
-            'username': 'admin'
-        }
+        test_arg = "username"
+        expected = {"username": "admin"}
         self.assertTrue(expected == dslice(TEST_DICT, test_arg))

     # Test dslice using all three types of arguments
     def test_dslice_all_args(self):
-        test_args = [
-            {'username': 'new_username'},
-            ['password',
-             'host'],
-            'port'
-        ]
+        test_args = [{"username": "new_username"}, ["password", "host"], "port"]
         expected = {
-            'new_username': 'admin',
-            'password': 'changeme',
-            'host': 'localhost',
-            'port': 8089
+            "new_username": "admin",
+            "password": "changeme",
+            "host": "localhost",
+            "port": 8089,
         }
         self.assertTrue(expected == dslice(TEST_DICT, *test_args))


 class FilePermissionTest(unittest.TestCase):
-
     def setUp(self):
         super().setUp()

     # Check for any change in the default file permission(i.e 644) for all files within splunklib
     def test_filePermissions(self):
-
         def checkFilePermissions(dir_path):
             for file in os.listdir(dir_path):
-                if file.__contains__('pycache'):
+                if file.__contains__("pycache"):
                     continue
                 path = os.path.join(dir_path, file)
                 if os.path.isfile(path):
                     permission = oct(os.stat(path).st_mode)
-                    self.assertEqual(permission, '0o100644')
+                    self.assertEqual(permission, "0o100644")
                 else:
                     checkFilePermissions(path)

-        dir_path = os.path.join('..', 'splunklib')
+        dir_path = os.path.join("..", "splunklib")
         checkFilePermissions(dir_path)
diff --git a/tests/testlib.py b/tests/testlib.py
index e7c7b6a70..57ab1f2da 100644
--- a/tests/testlib.py
+++ b/tests/testlib.py
@@ -15,6 +15,7 @@
 # under the License.

 """Shared unit test utilities."""
+
 import contextlib
 import os

@@ -23,7 +24,7 @@
 import sys

 # Run the test suite on the SDK without installing it.
-sys.path.insert(0, '../')
+sys.path.insert(0, "../")

 from time import sleep
 from datetime import datetime, timedelta
@@ -35,11 +36,11 @@

 from splunklib import client

-
 logging.basicConfig(
-    filename='test.log',
+    filename="test.log",
     level=logging.DEBUG,
-    format="%(asctime)s:%(levelname)s:%(message)s")
+    format="%(asctime)s:%(levelname)s:%(message)s",
+)


 class NoRestartRequiredError(Exception):
@@ -51,15 +52,15 @@ class WaitTimedOutError(Exception):


 def to_bool(x):
-    if x == '1':
+    if x == "1":
         return True
-    if x == '0':
+    if x == "0":
         return False
     raise ValueError(f"Not a boolean value: {x}")


 def tmpname():
-    name = 'delete-me-' + str(os.getpid()) + str(time.time()).replace('.', '-')
+    name = "delete-me-" + str(os.getpid()) + str(time.time()).replace(".", "-")
     return name


@@ -79,8 +80,13 @@ class SDKTestCase(unittest.TestCase):
     restart_already_required = False
     installedApps = []

-    def assertEventuallyTrue(self, predicate, timeout=30, pause_time=0.5,
-                             timeout_message="Operation timed out."):
+    def assertEventuallyTrue(
+        self,
+        predicate,
+        timeout=30,
+        pause_time=0.5,
+        timeout_message="Operation timed out.",
+    ):
         assert pause_time < timeout
         start = datetime.now()
         diff = timedelta(seconds=timeout)
@@ -157,7 +163,7 @@ def fake_splunk_version(self, version):
             self.service._splunk_version = original_version

     def install_app_from_collection(self, name):
-        collectionName = 'sdkappcollection'
+        collectionName = "sdkappcollection"
         if collectionName not in self.service.apps:
             raise ValueError("sdk-test-application not installed in splunkd")
         appPath = self.pathInApp(collectionName, ["build", name + ".tar"])
@@ -173,7 +179,7 @@ def install_app_from_collection(self, name):
         self.installedApps.append(name)

     def app_collection_installed(self):
-        collectionName = 'sdkappcollection'
+        collectionName = "sdkappcollection"
         return collectionName in self.service.apps

     def pathInApp(self, appName, pathComponents):
@@ -201,7 +207,7 @@ def pathInApp(self, appName, pathComponents):

         :return: A string giving the path.
         """
-        splunkHome = self.service.settings['SPLUNK_HOME']
+        splunkHome = self.service.settings["SPLUNK_HOME"]
         if "\\" in splunkHome:
             # This clause must come first, since Windows machines may
             # have mixed \ and / in their paths.
@@ -209,7 +215,9 @@ def pathInApp(self, appName, pathComponents):
         elif "/" in splunkHome:
             separator = "/"
         else:
-            raise ValueError("No separators in $SPLUNK_HOME. Can't determine what file separator to use.")
+            raise ValueError(
+                "No separators in $SPLUNK_HOME. Can't determine what file separator to use."
+            )
         appPath = separator.join([splunkHome, "etc", "apps", appName] + pathComponents)
         return appPath
@@ -225,7 +233,7 @@ def restartSplunk(self, timeout=240):
     @classmethod
     def setUpClass(cls):
         cls.opts = parse([], {}, ".env")
-        cls.opts.kwargs.update({'retries': 3})
+        cls.opts.kwargs.update({"retries": 3})
         # Before we start, make sure splunk doesn't need a restart.
         service = client.connect(**cls.opts.kwargs)
         if service.restart_required:
@@ -233,14 +241,17 @@ def setUpClass(cls):

     def setUp(self):
         unittest.TestCase.setUp(self)
-        self.opts.kwargs.update({'retries': 3})
+        self.opts.kwargs.update({"retries": 3})
         self.service = client.connect(**self.opts.kwargs)
         # If Splunk is in a state requiring restart, go ahead
         # and restart. That way we'll be sane for the rest of
         # the test.
         if self.service.restart_required:
             self.restartSplunk()
-        logging.debug("Connected to splunkd version %s", '.'.join(str(x) for x in self.service.splunk_version))
+        logging.debug(
+            "Connected to splunkd version %s",
+            ".".join(str(x) for x in self.service.splunk_version),
+        )

     def tearDown(self):
         from splunklib.binding import HTTPError
@@ -254,8 +265,10 @@ def tearDown(self):
                     self.service.apps.delete(appName)
                     wait(lambda: appName not in self.service.apps)
                 except HTTPError as error:
-                    if not (os.name == 'nt' and error.status == 500):
+                    if not (os.name == "nt" and error.status == 500):
                         raise
-                    print(f'Ignoring failure to delete {appName} during tear down: {error}')
+                    print(
+                        f"Ignoring failure to delete {appName} during tear down: {error}"
+                    )
         if self.service.restart_required:
             self.clear_restart_message()
diff --git a/utils/__init__.py b/utils/__init__.py
index b6c455656..9711f0a25 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -24,62 +24,56 @@ def config(option, opt, value, parser):

 # Default Splunk cmdline rules
 RULES_SPLUNK = {
-    'config': {
-        'flags': ["--config"],
-        'action': "callback",
-        'callback': config,
-        'type': "string",
-        'nargs': "1",
-        'help': "Load options from config file"
+    "config": {
+        "flags": ["--config"],
+        "action": "callback",
+        "callback": config,
+        "type": "string",
+        "nargs": "1",
+        "help": "Load options from config file",
     },
-    'scheme': {
-        'flags': ["--scheme"],
-        'default': "https",
-        'help': "Scheme (default 'https')",
+    "scheme": {
+        "flags": ["--scheme"],
+        "default": "https",
+        "help": "Scheme (default 'https')",
     },
-    'host': {
-        'flags': ["--host"],
-        'default': "localhost",
-        'help': "Host name (default 'localhost')"
+    "host": {
+        "flags": ["--host"],
+        "default": "localhost",
+        "help": "Host name (default 'localhost')",
     },
-    'port': {
-        'flags': ["--port"],
-        'default': "8089",
-        'help': "Port number (default 8089)"
+    "port": {
+        "flags": ["--port"],
+        "default": "8089",
+        "help": "Port number (default 8089)",
    },
-    'app': {
-        'flags': ["--app"],
-        'help': "The app context (optional)"
+    "app": {"flags": ["--app"], "help": "The app context (optional)"},
+    "owner": {"flags": ["--owner"], "help": "The user context (optional)"},
+    "username": {
+        "flags": ["--username"],
+        "default": None,
+        "help": "Username to login with",
     },
-    'owner': {
-        'flags': ["--owner"],
-        'help': "The user context (optional)"
+    "password": {
+        "flags": ["--password"],
+        "default": None,
+        "help": "Password to login with",
     },
-    'username': {
-        'flags': ["--username"],
-        'default': None,
-        'help': "Username to login with"
+    "version": {
+        "flags": ["--version"],
+        "default": None,
+        "help": "Ignore. Used by JavaScript SDK.",
     },
-    'password': {
-        'flags': ["--password"],
-        'default': None,
-        'help': "Password to login with"
+    "splunkToken": {
+        "flags": ["--bearerToken"],
+        "default": None,
+        "help": "Bearer token for authentication",
     },
-    'version': {
-        'flags': ["--version"],
-        'default': None,
-        'help': 'Ignore. Used by JavaScript SDK.'
+ "token": { + "flags": ["--sessionKey"], + "default": None, + "help": "Session key for authentication", }, - 'splunkToken': { - 'flags': ["--bearerToken"], - 'default': None, - 'help': 'Bearer token for authentication' - }, - 'token': { - 'flags': ["--sessionKey"], - 'default': None, - 'help': 'Session key for authentication' - } } FLAGS_SPLUNK = list(RULES_SPLUNK.keys()) @@ -88,14 +82,14 @@ def config(option, opt, value, parser): # value: dict, args: [(dict | list | str)*] def dslice(value, *args): """Returns a 'slice' of the given dictionary value containing only the - requested keys. The keys can be requested in a variety of ways, as an - arg list of keys, as a list of keys, or as a dict whose key(s) represent - the source keys and whose corresponding values represent the resulting - key(s) (enabling key rename), or any combination of the above.""" + requested keys. The keys can be requested in a variety of ways, as an + arg list of keys, as a list of keys, or as a dict whose key(s) represent + the source keys and whose corresponding values represent the resulting + key(s) (enabling key rename), or any combination of the above.""" result = {} for arg in args: if isinstance(arg, dict): - for k, v in (list(arg.items())): + for k, v in list(arg.items()): if k in value: result[v] = value[k] elif isinstance(arg, list): diff --git a/utils/cmdopts.py b/utils/cmdopts.py index 63cdfb1d4..cd0d08a61 100644 --- a/utils/cmdopts.py +++ b/utils/cmdopts.py @@ -19,19 +19,21 @@ import sys from dotenv import dotenv_values -__all__ = [ "error", "Parser", "cmdline" ] +__all__ = ["error", "Parser", "cmdline"] + # Print the given message to stderr, and optionally exit -def error(message, exitcode = None): +def error(message, exitcode=None): print(f"Error: {message}", file=sys.stderr) - if exitcode is not None: sys.exit(exitcode) + if exitcode is not None: + sys.exit(exitcode) class record(dict): def __getattr__(self, name): - try: - return self[name] - except KeyError: + try: + return self[name] + except KeyError: raise AttributeError(name) def __setattr__(self, name, value): @@ -39,11 +41,12 @@ def __setattr__(self, name, value): class Parser(OptionParser): - def __init__(self, rules = None, **kwargs): + def __init__(self, rules=None, **kwargs): OptionParser.__init__(self, **kwargs) self.dests = set({}) - self.result = record({ 'args': [], 'kwargs': record() }) - if rules is not None: self.init(rules) + self.result = record({"args": [], "kwargs": record()}) + if rules is not None: + self.init(rules) def init(self, rules): """Initialize the parser with the given command rules.""" @@ -54,14 +57,15 @@ def init(self, rules): # Assign defaults ourselves here, instead of in the option parser # itself in order to allow for multiple calls to parse (dont want # subsequent calls to override previous values with default vals). - if 'default' in rule: - self.result['kwargs'][dest] = rule['default'] + if "default" in rule: + self.result["kwargs"][dest] = rule["default"] - flags = rule['flags'] - kwargs = { 'action': rule.get('action', "store") } + flags = rule["flags"] + kwargs = {"action": rule.get("action", "store")} # NOTE: Don't provision the parser with defaults here, per above. 
-        for key in ['callback', 'help', 'metavar', 'type']:
-            if key in rule: kwargs[key] = rule[key]
+        for key in ["callback", "help", "metavar", "type"]:
+            if key in rule:
+                kwargs[key] = rule[key]
         self.add_option(*flags, dest=dest, **kwargs)

         # Remember the dest vars that we see, so that we can merge results
@@ -78,9 +82,10 @@ def load(self, filepath):
         # update result kwargs value with .env file data
         for key, value in filedata.items():
             value = value.strip()
-            if len(value) == 0 or value is None: continue  # Skip blank value
+            if len(value) == 0 or value is None:
+                continue  # Skip blank value
             elif key in self.dests:
-                self.result['kwargs'][key] = value
+                self.result["kwargs"][key] = value
             else:
                 raise NameError("No such option --" + key)

@@ -88,24 +93,25 @@ def load(self, filepath):

     def loadif(self, filepath):
         """Load the given filepath if it exists, otherwise ignore."""
-        if path.isfile(filepath): self.load(filepath)
+        if path.isfile(filepath):
+            self.load(filepath)
         return self

     def loadenv(self, filename):
         dir_path = path.dirname(path.realpath(__file__))
-        filepath = path.join(dir_path, '..', filename)
+        filepath = path.join(dir_path, "..", filename)
         self.loadif(filepath)
         return self

     def parse(self, argv):
         """Parse the given argument vector."""
         kwargs, args = self.parse_args(argv)
-        self.result['args'] += args
+        self.result["args"] += args
         # Annoying that parse_args doesn't just return a dict
         for dest in self.dests:
             value = getattr(kwargs, dest)
             if value is not None:
-                self.result['kwargs'][dest] = value
+                self.result["kwargs"][dest] = value
         return self

     def format_epilog(self, formatter):
@@ -114,8 +120,8 @@ def format_epilog(self, formatter):

 def cmdline(argv, rules=None, config=None, **kwargs):
     """Simplified cmdopts interface that does not default any parsing rules
-    and that does not allow compounding calls to the parser."""
+    and that does not allow compounding calls to the parser."""
     parser = Parser(rules, **kwargs)
-    if config is not None: parser.loadenv(config)
+    if config is not None:
+        parser.loadenv(config)
     return parser.parse(argv).result
-
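
For reviewers checking the reformatted `dslice` helper and its tests above, here is a minimal
usage sketch of the three argument forms those tests exercise. The input dict and expected
outputs mirror TEST_DICT and the expectations in tests/test_utils.py; the `conn` variable
name is illustrative, and the snippet assumes it runs from a checkout where the `utils`
package is importable (as the test suite does via sys.path).

    # Usage sketch for utils.dslice, mirroring tests/test_utils.py.
    from utils import dslice

    conn = {"username": "admin", "password": "changeme", "port": 8089,
            "host": "localhost", "scheme": "https"}

    # A dict argument selects and renames keys; keys absent from the
    # source ("foo" here) are silently skipped.
    assert dslice(conn, {"username": "user-name", "foo": "bar"}) == {"user-name": "admin"}

    # A list argument selects keys verbatim.
    assert dslice(conn, ["host", "port"]) == {"host": "localhost", "port": 8089}

    # A bare string selects a single key, and the three forms can be mixed
    # in one call, as test_dslice_all_args does.
    assert dslice(conn, {"username": "login"}, ["password"], "port") == {
        "login": "admin", "password": "changeme", "port": 8089}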