class WidgetSimulFocus(widgets.HBox):
    """Interactive widget simulating how an X-ray beam is focused after an aperture.

    A uniform complex wavefront is cut by a circular or square aperture and,
    depending on the "Plot wavefront at" choice, either shown at the aperture
    or propagated in the far field over the focal distance and then in the
    near field by the defocus. The result is drawn as RGBA, amplitude or phase
    in a matplotlib figure shown next to the controls.

    The wavefront objects/operators used here (``Wavefront``, ``CircularMask``,
    ``PropagateFarField``, ``ImshowAbs``, ...) are defined outside this block
    (presumably PyNX) -- NOTE(review): confirm.
    """

    # h*c expressed in keV*m: wavelength [m] = _HC_KEV_M / energy [keV]
    _HC_KEV_M = 12.3984e-10

    def __init__(self):
        super().__init__()
        # Options shared by every slider / radio-button group below.
        slider_opts = dict(disabled=False, continuous_update=False,
                           orientation='horizontal', readout=True)
        radio_opts = dict(orientation='horizontal', disabled=False)

        focus_label = widgets.Label(value='Focal distance (cm):')
        self.focus = widgets.FloatSlider(value=10, min=1, max=50, step=0.1,
                                         readout_format='.1f', **slider_opts)
        defocus_label = widgets.Label(value='Defocus (µm):')
        self.defocus = widgets.FloatSlider(value=0, min=-2000, max=2000, step=50,
                                           readout_format='.1f', **slider_opts)
        self.display_type = widgets.RadioButtons(options=['RGBA', 'Amplitude', 'Phase'],
                                                 value='Amplitude', **radio_opts)
        aperture_label = widgets.Label(value='Aperture (µm):')
        self.aperture = widgets.FloatSlider(value=200, min=40, max=500, step=20,
                                            readout_format='.1f', **slider_opts)
        self.aperture_type = widgets.RadioButtons(options=['Circle', 'Square'],
                                                  value='Circle', **radio_opts)
        pixel_label = widgets.Label(value='Pixel size @aperture (µm):')
        self.pixel = widgets.FloatSlider(value=2, min=0.2, max=5, step=0.1,
                                         readout_format='.1f', **slider_opts)
        wsize_label = widgets.Label(value='Array size:')
        # Log slider in base 2: array size from 2**7=128 to 2**12=4096.
        self.wsize = widgets.FloatLogSlider(value=512, base=2, min=7, max=12, step=1,
                                            readout_format='.0f', **slider_opts)
        nrj_label = widgets.Label(value='X-ray energy (keV):')
        self.nrj_kev = widgets.FloatSlider(value=10, min=1, max=40, step=0.5,
                                           readout_format='.1f', **slider_opts)
        display_location_label = widgets.Label(value='Plot wavefront at:')
        self.display_location = widgets.RadioButtons(
            options=['aperture (before propagation)', 'focus/defocus'],
            value='focus/defocus', **radio_opts)

        vbox = widgets.VBox([focus_label, self.focus, defocus_label, self.defocus,
                             self.display_type,
                             aperture_label, self.aperture, self.aperture_type,
                             pixel_label, self.pixel, wsize_label, self.wsize,
                             nrj_label, self.nrj_kev,
                             display_location_label, self.display_location])

        # Create the figure without displaying it immediately; its canvas is
        # embedded in this HBox instead.
        with plt.ioff():
            self.fig = plt.figure(figsize=(12, 6))
        self.fig.canvas.header_visible = False  # Hide fig num

        # names='value': without it, observe() reports every trait change of
        # the widget, not only value changes, which triggers redundant redraws.
        # Controls that only affect masking, propagation or display: redraw.
        for control in (self.focus, self.defocus, self.display_type,
                        self.aperture, self.aperture_type, self.display_location):
            control.observe(self.plot, names='value')
        # Controls that change the wavefront sampling or wavelength: rebuild
        # the wavefront, then redraw.
        for control in (self.pixel, self.wsize, self.nrj_kev):
            control.observe(self._on_wavefront_param_change, names='value')

        self.children = [self.fig.canvas, vbox]
        # Control values used for the most recent redraw (set by plot()).
        self.last_plot_params = None
        # Build the initial wavefront without drawing: the figure stays empty
        # until a control is changed.
        self.init_wavefront(plot=False)

    def _on_wavefront_param_change(self, change):
        """Observer for pixel size, array size and energy: rebuild, then redraw."""
        self.init_wavefront(plot=True)

    def init_wavefront(self, plot=False):
        """(Re)create ``self.w`` as a uniform n x n complex64 wavefront.

        The array size, pixel size (µm converted to m) and wavelength (computed
        from the X-ray energy in keV) are read from the controls.

        Parameters
        ----------
        plot : bool
            If true, redraw the figure after rebuilding the wavefront.
        """
        n = int(self.wsize.value)
        self.w = Wavefront(d=np.ones((n, n), dtype=np.complex64),
                           pixel_size=self.pixel.value * 1e-6,
                           wavelength=self._HC_KEV_M / self.nrj_kev.value)
        if plot:
            self.plot()

    def plot(self, k=None, force_plot=False):
        """Recompute the wavefront from the current control values and redraw it.

        Parameters
        ----------
        k : dict or None
            traitlets change notification when called as an observer; only a
            change of the ``'value'`` trait triggers a redraw. ``None`` for a
            direct call, which always redraws.
        force_plot : bool
            Currently unused; kept for backward compatibility.
        """
        if k is not None and k['name'] != 'value':
            return

        n = int(self.wsize.value)
        plot_params = [n, self.focus.value, self.defocus.value, self.aperture.value,
                       self.aperture_type.value, self.pixel.value, self.nrj_kev.value,
                       self.display_type.value, self.display_location.value]

        # Start again from a uniform field at z=0 with the aperture sampling:
        # the operators below presumably modify self.w in place, and
        # propagation changes z and pixel size.
        w = self.w
        w.set(np.ones((n, n), dtype=np.complex64))
        w.z = 0
        w.pixel_size = self.pixel.value * 1e-6

        aperture_size = self.aperture.value * 1e-6  # µm -> m
        if self.aperture_type.value == 'Square':
            w = RectangularMask(width=aperture_size, height=aperture_size) * w
        else:
            w = CircularMask(radius=aperture_size / 2) * w

        if 'focus' in self.display_location.value:
            # Far-field propagation over the focal distance (cm -> m), then
            # near-field propagation by the defocus (µm -> m).
            w = PropagateFarField(self.focus.value * 1e-2, forward=False) * w
            w = PropagateNearField(self.defocus.value * 1e-6) * w
            title = "f=%6.2fcm defocus=%5.0fµm" % (self.focus.value, self.defocus.value)
        else:
            title = "Aperture: %s, size=%4.0fµm" % (self.aperture_type.value,
                                                   self.aperture.value)

        if self.display_type.value == 'RGBA':
            w = ImshowRGBA(title=title, fig_num=self.fig.number, colorwheel=False) * w
        elif self.display_type.value == 'Amplitude':
            w = ImshowAbs(title=title, fig_num=self.fig.number) * w
        elif self.display_type.value == 'Phase':
            w = ImshowAngle(title=title, fig_num=self.fig.number) * w

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        self.last_plot_params = plot_params
# Create the interactive focusing widget (controls + figure canvas).
# The bare `w` below is the last expression, presumably of a notebook cell,
# so Jupyter renders the widget; it has no effect in a plain script.
w = WidgetSimulFocus()
w