Python 使用Matplotlib绘制可拖动的折线
效果图:
既可以直接拖动曲线上的点进行调整，也可以拖动右侧对应日期的滑块（Slider）进行调整。
代码如下:
import json
import os

import matplotlib.animation as animation
from matplotlib.widgets import Slider, Button
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
import scipy.interpolate as inter
import numpy as np

# Baseline curve: supplies the default weights and is drawn as the dashed
# 'original' reference line.
func = lambda x: np.zeros_like(x)


def load_cache_weight(cache_file):
    """Overwrite the global ``yvals`` in place with weights saved by ``save_p``.

    The file is the JSON object written by ``save_p`` ('YYYYMMDD' -> weight).
    Values are assigned positionally, in file order, to ``yvals[0]``,
    ``yvals[1]``, ...; the date keys themselves are not checked, so a saved
    profile can be reused after changing the start date ``st``.

    A missing file is reported and skipped: by default ``cache_file`` is the
    same path as ``save_file``, which does not exist until Save is clicked.
    Entries beyond ``len(yvals)`` are ignored; if the file has fewer entries,
    the remaining ``yvals`` keep their current values.
    """
    if not os.path.isfile(cache_file):
        print('cache file {} not found, using default weights'.format(cache_file))
        return
    with open(cache_file, 'r') as f:
        # Parse the whole file as JSON, the format save_p writes. (PyYAML
        # would read exponent floats such as 1e-05, which json.dump can
        # emit, as strings.)
        d = json.load(f)
    if len(d) != len(yvals):
        print('cache file {} has {} entries, expected {}'.format(
            cache_file, len(d), len(yvals)))
    for i, value in zip(range(len(yvals)), d.values()):
        yvals[i] = float(value)


# ---- user input config ----
N = 30                                    # number of control points, one per day
st = pd.to_datetime('20230414')           # date of the first control point
cache_file = None
cache_file = './tmp/saved_weight_2.json'  # initial weights to load (None: use func)
save_file = './tmp/saved_weight_2.json'   # written by the Save button (None: print only)

# Dates labelling the sliders; also the keys written by save_p.
ed = st + pd.Timedelta(days=N - 1)
datelst = pd.date_range(st, ed)

# Control-point x positions 1, 2, ..., N (xmax is an exclusive upper bound).
xmin = 1
xmax = N + 1
x = np.arange(xmin, xmax, dtype=float)

# Starting weights: func(x), overridden by the cache file when one is given.
yvals = func(x)
if cache_file is not None:
    load_cache_weight(cache_file)
spline = inter.InterpolatedUnivariateSpline(x, yvals)

# Keep the right 20% of the figure free for the sliders and buttons.
mpl.rcParams['figure.subplot.left'] = 0.1
mpl.rcParams['figure.subplot.right'] = 0.8

# set up a plot
fig, axes = plt.subplots(1, 1, figsize=(16, 5), sharex=True)
ax1 = axes

pind = None   # index of the control point being dragged, None when idle
epsilon = 5   # max pixel distance between a click and a point to grab it


def update(val):
    """Slider callback: copy all slider values into ``yvals`` and redraw.

    ``val`` (the changed slider's value) is unused; every slider is re-read.
    """
    global spline
    yvals[:] = [s.val for s in sliders]
    l.set_ydata(yvals)
    spline = inter.InterpolatedUnivariateSpline(x, yvals)
    m.set_ydata(spline(X))
    # redraw canvas while idle
    fig.canvas.draw_idle()


def reset(event):
    """Reset button: return every slider, and so the curve, to its initial value.

    Each slider's initial value is the startup weight (from the cache file or
    ``func``). Slider events are muted while resetting so the spline is
    refitted once at the end rather than once per changed slider.
    """
    for s in sliders:
        s.eventson = False
        try:
            s.reset()
        finally:
            s.eventson = True
    update(None)


def save_p(event):
    """Save button: print the current weights keyed by 'YYYYMMDD' date and
    write them as JSON to ``save_file`` (skipped when it is None)."""
    r = {d.strftime('%Y%m%d'): float(v) for d, v in zip(datelst, yvals)}
    print(r)
    if save_file is not None:
        # The target directory (./tmp by default) may not exist yet.
        folder = os.path.dirname(save_file)
        if folder:
            os.makedirs(folder, exist_ok=True)
        with open(save_file, 'w') as f:
            json.dump(r, f)


def button_press_callback(event):
    """On left-click inside the plot axes, start dragging the point under the cursor."""
    global pind
    # Clicks on the slider/button axes must not grab a curve point.
    if event.inaxes is not ax1:
        return
    if event.button != 1:
        return
    pind = get_ind_under_point(event)


def button_release_callback(event):
    """On left-button release, stop dragging."""
    global pind
    if event.button != 1:
        return
    pind = None


def get_ind_under_point(event):
    """Return the index of the control point nearest the mouse if it is within
    ``epsilon`` pixels, else None."""
    # Control points converted to display (pixel) coords, like event.x / event.y.
    pts = ax1.transData.transform(np.column_stack([x, yvals]))
    d = np.hypot(pts[:, 0] - event.x, pts[:, 1] - event.y)
    ind = int(np.argmin(d))
    return ind if d[ind] < epsilon else None


def motion_notify_callback(event):
    """While dragging with the left button, move the grabbed point to the mouse's y."""
    if pind is None:
        return
    # event.ydata is expressed in the coordinates of whichever axes is under
    # the cursor; only values measured in the plot axes are meaningful here.
    if event.inaxes is not ax1:
        return
    if event.button != 1:
        return
    yvals[pind] = np.clip(event.ydata, -1, 1)
    # Moving the slider fires update(), which refits the spline and redraws.
    sliders[pind].set_val(yvals[pind])
    fig.canvas.draw_idle()


# ---- plot ----
X = np.arange(0, xmax, 0.1)   # dense grid for drawing the spline
ax1.plot(X, func(X), 'k--', label='original')
l, = ax1.plot(x, yvals, color='k', linestyle='none', marker='o', markersize=8)
m, = ax1.plot(X, spline(X), 'r-', label='spline')

ax1.set_yscale('linear')
ax1.set_xlim(0, xmax + 1)
ax1.set_ylim(-1.05, 1.05)
ax1.set_xlabel('dt')
ax1.set_ylabel('p')
ax1.grid(True)
ax1.yaxis.grid(True, which='minor', linestyle='--')
ax1.legend(loc=2, prop={'size': 8})

# One slider per control point, stacked down the right margin, labelled
# month/day. With 0.03 spacing from 0.95, only about 31 rows (sliders plus
# the button row) fit inside the figure.
sliders = []
for i, date_i in enumerate(datelst):
    axamp = plt.axes([0.84, 0.95 - i * 0.03, 0.12, 0.02])
    s = Slider(axamp, '{}/{}'.format(date_i.month, date_i.day), -1, 1,
               valinit=yvals[i])
    s.on_changed(update)
    sliders.append(s)

# Buttons sit one row below the last slider. Keep references to the widgets
# so they stay responsive.
axres = plt.axes([0.84, 0.95 - N * 0.03, 0.06, 0.02])
bres = Button(axres, 'Reset')
bres.on_clicked(reset)
axsave = plt.axes([0.84 + 0.08, 0.95 - N * 0.03, 0.06, 0.02])
bres2 = Button(axsave, 'Save')
bres2.on_clicked(save_p)

fig.canvas.mpl_connect('button_press_event', button_press_callback)
fig.canvas.mpl_connect('button_release_event', button_release_callback)
fig.canvas.mpl_connect('motion_notify_event', motion_notify_callback)

plt.show()
标签:ax1,Python,global,Matplotlib,折线,file,np,yvals,event From: https://www.cnblogs.com/JiangOil/p/17321432.html