Averaging detector data with Dask

We often want to average large detector data across trains, keeping the pulses within each train separate, so we have an average image for pulse 0, another for pulse 1, etc.

This data may be too big to load into memory at once, but using Dask we can work with it like a numpy array. Dask takes care of splitting the job up into smaller pieces and assembling the result.
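
As a minimal, generic illustration of the idea (the array and chunk sizes here are arbitrary, and not related to the detector data below), a Dask array is split into chunks and only evaluated when we ask for the result:

import dask.array as da

# Build a large array out of 2000 x 2000 chunks; nothing is computed yet
x = da.random.random((20000, 20000), chunks=(2000, 2000))
lazy_mean = x.mean(axis=0)     # still lazy: just a task graph
result = lazy_mean.compute()   # chunks are processed in parallel; the result is a numpy array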

[1]:
from karabo_data import open_run

import dask.array as da
from dask.distributed import Client, progress
from dask_jobqueue import SLURMCluster
import numpy as np

First, we use Dask-Jobqueue to talk to the Maxwell cluster.

[2]:
partition = 'exfel'  # For EuXFEL staff
#partition = 'upex'   # For users

cluster = SLURMCluster(
    queue=partition,
    # Resources per SLURM job (per node, the way SLURM is configured on Maxwell)
    # processes=16 runs 16 Dask workers in a job, so each worker has 1 core & 16 GB RAM.
    processes=16, cores=16, memory='256GB',
)

# Get a notebook widget showing the cluster state
cluster
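
If you’re curious what Dask-Jobqueue submits to SLURM for each of these jobs, you can print the generated batch script (optional):

# Show the sbatch script that dask_jobqueue generates for one job
print(cluster.job_script())
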
[3]:
# Submit 2 SLURM jobs, for 32 Dask workers
cluster.scale(32)

If the cluster is busy, you might need to wait a while for the jobs to start. The cluster widget above will update when they’re running.

Next, we’ll set Dask up to use those workers:

[4]:
client = Client(cluster)
print("Created dask client:", client)
Created dask client: <Client: scheduler='tcp://131.169.193.102:44986' processes=32 cores=32>
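
If you’d rather block until a given number of workers have actually connected, rather than watching the widget, the client can wait for them (optional, and the exact call may depend on your dask.distributed version; 16 here is just an example):

# Block until at least 16 workers have joined the cluster
client.wait_for_workers(n_workers=16)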

Now that Dask is ready, let’s open the run we’re going to operate on:

[5]:
run = open_run(proposal=2212, run=103)
run.info()
# of trains:    3299
Duration:       0:05:29.800000
First train ID: 517617973
Last train ID:  517621271

16 detector modules (SCS_DET_DSSC1M-1)
  e.g. module SCS_DET_DSSC1M-1 0 : 128 x 512 pixels
  75 frames per train, 247425 total frames

3 instrument sources (excluding detectors):
  - SA3_XTD10_XGM/XGM/DOOCS:output
  - SCS_BLU_XGM/XGM/DOOCS:output
  - SCS_UTC1_ADQ/ADC/1:network

20 control sources:
  - P_GATT
  - SA3_XTD10_MONO/ENC/GRATING_AX
  - SA3_XTD10_MONO/MDL/PHOTON_ENERGY
  - SA3_XTD10_MONO/MOTOR/GRATINGS_X
  - SA3_XTD10_MONO/MOTOR/GRATING_AX
  - SA3_XTD10_MONO/MOTOR/HE_PM_X
  - SA3_XTD10_MONO/MOTOR/LE_PM_X
  - SA3_XTD10_VAC/DCTRL/AR_MODE_OK
  - SA3_XTD10_VAC/DCTRL/D12_APERT_IN_OK
  - SA3_XTD10_VAC/DCTRL/D6_APERT_IN_OK
  - SA3_XTD10_VAC/DCTRL/N2_MODE_OK
  - SA3_XTD10_VAC/GAUGE/G30470D_IN
  - SA3_XTD10_VAC/GAUGE/G30480D_IN
  - SA3_XTD10_VAC/GAUGE/G30490D_IN
  - SA3_XTD10_VAC/GAUGE/G30510C
  - SA3_XTD10_XGM/XGM/DOOCS
  - SCS_BLU_XGM/XGM/DOOCS
  - SCS_RR_UTC/MDL/BUNCH_DECODER
  - SCS_RR_UTC/TSYS/TIMESERVER
  - SCS_UTC1_ADQ/ADC/1

We’re working with data from the DSSC detector. In this run, it recorded 75 frames per train, as shown in the run info above.
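
If you’d rather not hard-code the 75, you can check the frame counts per train from the data itself, using the same get_data_counts method that appears in the function below (here applied to the first DSSC module):

# Frames recorded per train for module 0; counts is a pandas Series indexed by train ID
counts = run.get_data_counts('SCS_DET_DSSC1M-1/DET/0CH0:xtdf', 'image.data')
print(counts.unique())   # every train in this run has 75 frames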

Now, we’ll define how we’re going to average over trains for each module:

[6]:
def average_module(modno, run, pulses_per_train=75):
    source = f'SCS_DET_DSSC1M-1/DET/{modno}CH0:xtdf'
    counts = run.get_data_counts(source, 'image.data')  # frames recorded per train (not used further here)

    arr = run.get_dask_array(source, 'image.data')
    # Make a new dimension for trains
    arr_trains = arr.reshape(-1, pulses_per_train, 128, 512)
    if modno == 0:
        print("array shape:", arr.shape)  # frames, dummy, 128, 512
        print("Reshaped to:", arr_trains.shape)

    return arr_trains.mean(axis=0, dtype=np.float32)
[7]:
mod_averages = [
    average_module(i, run, pulses_per_train=75)
    for i in range(16)
]

mod_averages
array shape: (247425, 1, 128, 512)
Reshaped to: (3299, 75, 128, 512)
[7]:
[dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>,
 dask.array<mean_agg-aggregate, shape=(75, 128, 512), dtype=float32, chunksize=(75, 128, 512)>]
[8]:
# Stack the averages into a single array
all_average = da.stack(mod_averages)
all_average
[8]:
        Array                Chunk
Bytes   314.57 MB            19.66 MB
Shape   (16, 75, 128, 512)   (1, 75, 128, 512)
Count   2560 Tasks           16 Chunks
Type    float32              numpy.ndarray

Dask shows us what shape the result array will be, but so far, no real computation has happened. Now that we’ve defined what we want, let’s tell Dask to compute it.

This will take a minute or two. If you’re running it, scroll up to the Dask cluster widget and click the status link to see what it’s doing.
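
Alternatively, the progress function imported at the top can display a progress bar in the notebook itself. A sketch of one way to do that (optional; the cell below just calls compute() directly):

# Start the computation on the workers in the background...
persisted = all_average.persist()
# ...and show a notebook progress bar while the tasks run
progress(persisted)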

[9]:
%%time
all_average_arr = all_average.compute()  # Get a concrete numpy array for the result
CPU times: user 20.8 s, sys: 2.6 s, total: 23.4 s
Wall time: 1min 42s

all_average_arr is a regular numpy array with our results. Here are the values from the corner of module 0, frame 0:

[10]:
print(all_average_arr[0, 0, :5, :5])
[[48.822674 50.983025 44.953014 44.08245  45.056988]
 [45.8251   49.183388 46.39982  43.371628 47.53501 ]
 [51.03395  46.02243  44.92058  50.966656 42.918762]
 [43.190662 49.961502 44.23007  43.252197 47.663536]
 [48.844803 51.489845 50.45438  46.305546 47.51258 ]]
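
Since it’s a plain numpy array, you can continue with the usual numpy tools, for example saving the averages to a file for later analysis (the filename is just an example):

np.save('dssc_run103_pulse_averages.npy', all_average_arr)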

Please shut down the cluster (or scale it down to 0 workers) if you won’t be using it for a while. This releases the resources for other people.

[11]:
client.close()
cluster.close()