diff --git a/src/mplfinance/_version.py b/src/mplfinance/_version.py index 44439d7d..ff6c7f0b 100644 --- a/src/mplfinance/_version.py +++ b/src/mplfinance/_version.py @@ -1,4 +1,4 @@ -version_info = (0, 12, 9, 'beta', 2) +version_info = (0, 12, 9, 'beta', 3) _specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''} diff --git a/src/mplfinance/plotting.py b/src/mplfinance/plotting.py index b3d4f341..fbde618a 100644 --- a/src/mplfinance/plotting.py +++ b/src/mplfinance/plotting.py @@ -119,7 +119,17 @@ def _valid_plot_kwargs(): 'mav' : { 'Default' : None, 'Description' : 'Moving Average window size(s); (int or tuple of ints)', 'Validator' : _mav_validator }, + + 'ema' : { 'Default' : None, + 'Description' : 'Exponential Moving Average window size(s); (int or tuple of ints)', + 'Validator' : _mav_validator }, + 'mavcolors' : { 'Default' : None, + 'Description' : 'color cycle for moving averages (list or tuple of colors)'+ + '(overrides mpf style mavcolors).', + 'Validator' : lambda value: isinstance(value,(list,tuple)) and + all([mcolors.is_color_like(v) for v in value]) }, + 'renko_params' : { 'Default' : dict(), 'Description' : 'dict of renko parameters; call `mpf.kwarg_help("renko_params")`', 'Validator' : lambda value: isinstance(value,dict) }, @@ -450,6 +460,13 @@ def plot( data, **kwargs ): else: raise TypeError('style should be a `dict`; why is it not?') + if config['mavcolors'] is not None: + config['_ma_color_cycle'] = cycle(config['mavcolors']) + elif style['mavcolors'] is not None: + config['_ma_color_cycle'] = cycle(style['mavcolors']) + else: + config['_ma_color_cycle'] = None + if not external_axes_mode: fig = plt.figure() _adjust_figsize(fig,config) @@ -528,8 +545,10 @@ def plot( data, **kwargs ): if ptype in VALID_PMOVE_TYPES: mavprices = _plot_mav(axA1,config,xdates,pmove_avgvals) + emaprices = _plot_ema(axA1, config, xdates, pmove_avgvals) else: mavprices = _plot_mav(axA1,config,xdates,closes) + emaprices = _plot_ema(axA1, config, xdates, closes) avg_dist_between_points = (xdates[-1] - xdates[0]) / float(len(xdates)) if not config['tight_layout']: @@ -595,6 +614,13 @@ def plot( data, **kwargs ): else: for jj in range(0,len(mav)): retdict['mav' + str(mav[jj])] = mavprices[jj] + if config['ema'] is not None: + ema = config['ema'] + if len(ema) != len(emaprices): + warnings.warn('len(ema)='+str(len(ema))+' BUT len(emaprices)='+str(len(emaprices))) + else: + for jj in range(0, len(ema)): + retdict['ema' + str(ema[jj])] = emaprices[jj] retdict['minx'] = minx retdict['maxx'] = maxx retdict['miny'] = miny @@ -1129,10 +1155,7 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None): if len(mavgs) > 7: mavgs = mavgs[0:7] # take at most 7 - if style['mavcolors'] is not None: - mavc = cycle(style['mavcolors']) - else: - mavc = None + mavc = config['_ma_color_cycle'] for idx,mav in enumerate(mavgs): mean = pd.Series(prices).rolling(mav).mean() @@ -1147,6 +1170,42 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None): mavp_list.append(mavprices) return mavp_list + +def _plot_ema(ax,config,xdates,prices,apmav=None,apwidth=None): + '''ema: exponential moving average''' + style = config['style'] + if apmav is not None: + mavgs = apmav + else: + mavgs = config['ema'] + mavp_list = [] + if mavgs is not None: + shift = None + if isinstance(mavgs,dict): + shift = mavgs['shift'] + mavgs = mavgs['period'] + if isinstance(mavgs,int): + mavgs = mavgs, # convert to tuple + if len(mavgs) > 7: + mavgs = mavgs[0:7] # take at most 7 + + mavc = config['_ma_color_cycle'] + + for idx,mav in enumerate(mavgs): + # mean = pd.Series(prices).rolling(mav).mean() + mean = pd.Series(prices).ewm(span=mav,adjust=False).mean() + if shift is not None: + mean = mean.shift(periods=shift[idx]) + emaprices = mean.values + lw = config['_width_config']['line_width'] + if mavc: + ax.plot(xdates, emaprices, linewidth=lw, color=next(mavc)) + else: + ax.plot(xdates, emaprices, linewidth=lw) + mavp_list.append(emaprices) + return mavp_list + + def _auto_secondary_y( panels, panid, ylo, yhi ): # If mag(nitude) for this panel is not yet set, then set it # here, as this is the first ydata to be plotted on this panel: diff --git a/tests/reference_images/ema01.png b/tests/reference_images/ema01.png new file mode 100644 index 00000000..e21f3921 Binary files /dev/null and b/tests/reference_images/ema01.png differ diff --git a/tests/reference_images/ema02.png b/tests/reference_images/ema02.png new file mode 100644 index 00000000..5b8807f5 Binary files /dev/null and b/tests/reference_images/ema02.png differ diff --git a/tests/reference_images/ema03.png b/tests/reference_images/ema03.png new file mode 100644 index 00000000..561b02d2 Binary files /dev/null and b/tests/reference_images/ema03.png differ diff --git a/tests/test_ema.py b/tests/test_ema.py new file mode 100644 index 00000000..f691e4fd --- /dev/null +++ b/tests/test_ema.py @@ -0,0 +1,124 @@ +import os +import os.path +import glob +import mplfinance as mpf +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.testing.compare import compare_images + +print('mpf.__version__ =',mpf.__version__) # for the record +print('mpf.__file__ =',mpf.__file__) # for the record +print("plt.rcParams['backend'] =",plt.rcParams['backend']) # for the record + +base='ema' +tdir = os.path.join('tests','test_images') +refd = os.path.join('tests','reference_images') + +globpattern = os.path.join(tdir,base+'*.png') +oldtestfiles = glob.glob(globpattern) +for fn in oldtestfiles: + try: + os.remove(fn) + except: + print('Error removing file "'+fn+'"') + +IMGCOMP_TOLERANCE = 10.0 # this works fine for linux +# IMGCOMP_TOLERANCE = 11.0 # required for a windows pass. (really 10.25 may do it). + +_df = pd.DataFrame() +def get_ema_data(): + global _df + if len(_df) == 0: + _df = pd.read_csv('./examples/data/yahoofinance-GOOG-20040819-20180120.csv', + index_col='Date',parse_dates=True) + return _df + + +def create_ema_image(tname): + + df = get_ema_data() + df = df[-50:] # show last 50 data points only + + ema25 = df['Close'].ewm(span=25.0, adjust=False).mean() + mav25 = df['Close'].rolling(window=25).mean() + + ap = [ + mpf.make_addplot(df, panel=1, type='ohlc', color='c', + ylabel='mpf mav', mav=25, secondary_y=False), + mpf.make_addplot(ema25, panel=2, type='line', width=2, color='c', + ylabel='calculated', secondary_y=False), + mpf.make_addplot(mav25, panel=2, type='line', width=2, color='blue', + ylabel='calculated', secondary_y=False) + ] + + # plot and save in `tname` path + mpf.plot(df, ylabel="mpf ema", type='ohlc', + ema=25, addplot=ap, panel_ratios=(1, 1), savefig=tname + ) + + +def test_ema01(): + + fname = base+'01.png' + tname = os.path.join(tdir,fname) + rname = os.path.join(refd,fname) + + create_ema_image(tname) + + tsize = os.path.getsize(tname) + print(glob.glob(tname),'[',tsize,'bytes',']') + + rsize = os.path.getsize(rname) + print(glob.glob(rname),'[',rsize,'bytes',']') + + result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE) + if result is not None: + print('result=',result) + assert result is None + +def test_ema02(): + fname = base+'02.png' + tname = os.path.join(tdir,fname) + rname = os.path.join(refd,fname) + + df = get_ema_data() + df = df[-125:-35] + + mpf.plot(df, type='candle', ema=(5,15,25), mav=(5,15,25), savefig=tname) + + tsize = os.path.getsize(tname) + print(glob.glob(tname),'[',tsize,'bytes',']') + + rsize = os.path.getsize(rname) + print(glob.glob(rname),'[',rsize,'bytes',']') + + result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE) + if result is not None: + print('result=',result) + assert result is None + +def test_ema03(): + fname = base+'03.png' + tname = os.path.join(tdir,fname) + rname = os.path.join(refd,fname) + + df = get_ema_data() + df = df[-125:-35] + + mac = ['red','orange','yellow','green','blue','purple'] + + mpf.plot(df, type='candle', ema=(5,10,15,25), mav=(5,15,25), + mavcolors=mac, savefig=tname) + + + tsize = os.path.getsize(tname) + print(glob.glob(tname),'[',tsize,'bytes',']') + + rsize = os.path.getsize(rname) + print(glob.glob(rname),'[',rsize,'bytes',']') + + result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE) + if result is not None: + print('result=',result) + assert result is None +