heatmap

Short Description

The sm.pl.heatmap function draws a group-by-marker heatmap. For each group in a column of adata.obs (for example, a cluster or phenotype label), it computes the mean of every selected marker across the cells in that group and displays the resulting matrix. Optional hierarchical clustering, manual row and column ordering, z-score scaling, and per-group cell counts support exploratory comparison of clusters, conditions, or phenotypes in spatial datasets, condensing high-dimensional single-cell data into a compact summary.
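As a mental model, each heatmap cell is a per-group mean. The sketch below reproduces that aggregation step, together with the per-group cell counts that showPrevalence displays, using plain NumPy on a toy array. `group_means` is a hypothetical helper for illustration only (not part of scimap), and it assumes a dense cells-by-markers array.

```python
import numpy as np


def group_means(data, categories):
    """Hypothetical helper mirroring the aggregation step of sm.pl.heatmap.

    Returns (sorted group labels, cells per group, group-by-marker mean matrix).
    """
    data = np.asarray(data, dtype=float)
    categories = np.asarray(categories)
    # np.unique sorts the labels and also counts the cells carrying each label
    labels, counts = np.unique(categories, return_counts=True)
    # one row per group, one column per marker
    means = np.vstack([data[categories == label].mean(axis=0) for label in labels])
    return labels, counts, means


# toy data: 4 cells x 2 markers, split into clusters '0' and '1'
cells = np.array([[1.0, 0.0], [3.0, 2.0], [10.0, 4.0], [12.0, 6.0]])
clusters = np.array(['1', '0', '1', '0'])
labels, counts, means = group_means(cells, clusters)
print(labels, counts)  # -> ['0' '1'] [2 2]
print(means)           # -> rows [7.5, 4.0] and [5.5, 2.0]
```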

Function

heatmap(adata, groupBy, layer=None, subsetMarkers=None, subsetGroups=None, clusterRows=True, clusterColumns=True, standardScale=None, orderRow=None, orderColumn=None, showPrevalence=False, cmap='vlag', figsize=None, saveDir=None, fileName=None, verbose=True, **kwargs)

Parameters:

Name Type Description Default
adata AnnData or str

An AnnData object, or a path to an .h5ad file containing one. It should have features (markers) as variables and observations (cells) as rows.

required
groupBy str

The key in adata.obs on which to group observations, typically a clustering or phenotype label such as 'leiden' or 'phenotype'.

required
layer str

Specifies the layer of adata to use for the heatmap. If None, the .X attribute is used. To plot the raw data, pass 'raw' (uses adata.raw.X).

None
subsetMarkers list of str

A list of marker genes or features to include in the heatmap. If None, all markers are used.

None
subsetGroups list of str

A list of group labels to include in the heatmap. Useful for focusing on specific clusters or conditions.

None
clusterRows bool

Whether to hierarchically cluster the rows (groups). Must be False to use orderRow or orderColumn.

True
clusterColumns bool

Whether to hierarchically cluster the columns (features). Must be False to use orderRow or orderColumn.

True
standardScale str

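Whether and how to z-score the group-by-marker mean matrix (via scikit-learn's StandardScaler): 'row' standardizes each marker across the groups (rows), 'column' standardizes each group across the markers (columns), and None applies no scaling.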

None
orderRow list of str

A custom top-to-bottom order of group labels; groups not listed are omitted. Requires clusterRows=False and clusterColumns=False.

None
orderColumn list of str

A custom left-to-right order of feature names; features not listed are omitted. Requires clusterRows=False and clusterColumns=False.

None
showPrevalence bool

If True, writes the number of cells in each group (n=...) to the right of its row.

False
cmap str

The colormap for the heatmap.

'vlag'
figsize tuple of float

The figure size in inches, as (width, height). If None, it is inferred as (max(10, 0.5 * number of markers), max(8, 0.5 * number of groups)).

None
saveDir str

Directory in which to save the heatmap; it is created if it does not exist. Must be provided together with fileName. If None, the heatmap is displayed instead of saved.

None
fileName str

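Name of the output file, including its extension (for example 'custom_heatmap.pdf'), which determines the file format. Must be provided together with saveDir; supplying only one of the two raises a ValueError.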

None
verbose bool

If True, print a message with the output path when the heatmap is saved.

True
**kwargs

Additional keyword arguments passed to matplotlib's imshow. If given, vmin and vmax set the limits of the color scale; otherwise the minimum and maximum of the plotted matrix are used.

{}
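Reading the source code, standardScale is a z-score (zero mean, unit variance) computed with scikit-learn's StandardScaler on the group-by-marker mean matrix, rather than a 0 to 1 min-max rescale. The sketch below re-expresses that step in NumPy under this reading; `zscore_means` is a hypothetical name, and the zero-variance handling is an assumption meant to mirror StandardScaler leaving constant features unscaled.

```python
import numpy as np


def zscore_means(mean_data, standard_scale=None):
    """Hypothetical NumPy equivalent of the standardScale step.

    'row'    -> each marker (column) is standardized across the groups (rows),
                like StandardScaler().fit_transform(mean_data)
    'column' -> each group (row) is standardized across the markers (columns),
                like StandardScaler().fit_transform(mean_data.T).T
    """
    if standard_scale not in (None, 'row', 'column'):
        raise ValueError("standardScale must be 'row', 'column', or None.")
    m = np.asarray(mean_data, dtype=float)
    if standard_scale is None:
        return m
    axis = 0 if standard_scale == 'row' else 1
    mu = m.mean(axis=axis, keepdims=True)
    sd = m.std(axis=axis, keepdims=True)  # population std (ddof=0), as StandardScaler uses
    sd[sd == 0] = 1.0  # assumption: constant features are left unscaled
    return (m - mu) / sd


means = np.array([[7.5, 4.0], [5.5, 2.0]])
print(zscore_means(means, 'row'))     # -> [[1, 1], [-1, -1]]
print(zscore_means(means, 'column'))  # -> [[1, -1], [1, -1]]
```

Because the scaled values are unbounded, fixed limits such as vmin=0 and vmax=1 saturate everything outside that interval at the ends of the colormap.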

Returns:

Name Type Description
None

The function returns None. The figure is displayed with plt.show(), or, if saveDir and fileName are provided, saved at 300 dpi to saveDir/fileName and then closed.
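Whether the figure is shown or saved depends on saveDir and fileName being passed together. A minimal sketch of that argument check and path handling follows; `resolve_save_path` is a hypothetical helper (not part of scimap) that mirrors the checks visible in the source code.

```python
import os
import tempfile


def resolve_save_path(save_dir=None, file_name=None):
    """Hypothetical sketch of how saveDir and fileName are combined.

    Returns None when the figure should only be shown; raises ValueError
    when just one of the two arguments is supplied.
    """
    if (save_dir is None) != (file_name is None):
        raise ValueError(
            "Both 'saveDir' and 'fileName' must be provided together or not at all."
        )
    if save_dir is None:
        return None
    os.makedirs(save_dir, exist_ok=True)  # the output directory is created if missing
    return os.path.join(save_dir, file_name)


# usage: a throwaway directory so the example does not touch real paths
out_dir = os.path.join(tempfile.mkdtemp(), 'figures')
path = resolve_save_path(out_dir, 'custom_heatmap.pdf')
print(path)
```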

Example
import scimap as sm

# Example 1: Basic usage with clustering and standard scaling by column.
# (`adata` is assumed to be an AnnData object with a 'leiden' column in adata.obs.)

sm.pl.heatmap(adata, groupBy='leiden', standardScale='column')

# Example 2: Subset markers and groups, set a manual row and column order, and save to file.
# Manual ordering requires clusterRows=False and clusterColumns=False.

subsetMarkers = ['ELANE', 'CD57', 'CD45', 'CD11B', 'SMA', 'CD16', 'ECAD']
subsetGroups = ['0', '1', '3', '6']
orderRow = ['6', '3', '0', '1']
orderColumn = ['SMA', 'CD16', 'ECAD', 'ELANE', 'CD57', 'CD45', 'CD11B']
saveDir = '/path/to/save'
fileName = 'custom_heatmap.pdf'

sm.pl.heatmap(
    adata, groupBy='leiden', subsetMarkers=subsetMarkers, subsetGroups=subsetGroups,
    clusterRows=False, clusterColumns=False, standardScale='column',
    orderRow=orderRow, orderColumn=orderColumn, showPrevalence=True,
    figsize=(10, 5), saveDir=saveDir, fileName=fileName,
    vmin=0, vmax=1,  # fixed color range; z-scores outside [0, 1] saturate
)
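With both clusterRows=False and clusterColumns=False, orderRow and orderColumn reorder the mean matrix by looking up each requested label's position. The sketch below isolates that lookup with a hypothetical `reorder` helper and toy labels. As in the source code, labels missing from the order list are dropped, and an unknown label raises ValueError (from list.index).

```python
import numpy as np


def reorder(mean_data, row_labels, col_labels, order_row=None, order_column=None):
    """Hypothetical sketch of the manual ordering step.

    Rows and columns are selected in the order given; unlisted labels are dropped.
    """
    mean_data = np.asarray(mean_data)
    row_labels, col_labels = list(row_labels), list(col_labels)
    if order_row is not None:
        idx = [row_labels.index(r) for r in order_row]
        mean_data, row_labels = mean_data[idx, :], [row_labels[i] for i in idx]
    if order_column is not None:
        idx = [col_labels.index(c) for c in order_column]
        mean_data, col_labels = mean_data[:, idx], [col_labels[i] for i in idx]
    return mean_data, row_labels, col_labels


m = np.arange(6).reshape(3, 2)  # rows '0', '1', '3'; columns 'CD45', 'SMA'
out, rows, cols = reorder(m, ['0', '1', '3'], ['CD45', 'SMA'],
                          order_row=['3', '0'], order_column=['SMA', 'CD45'])
print(rows, cols)  # -> ['3', '0'] ['SMA', 'CD45']
print(out)         # -> [[5, 4], [1, 0]]; group '1' was not listed, so it is dropped
```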
Source code in scimap/plotting/heatmap.py
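# Imports required by the function below (the module header is not shown in this excerpt).
import os

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import Normalize
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist
from sklearn.preprocessing import StandardScaler


def heatmap(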
    adata,
    groupBy,
    layer=None,
    subsetMarkers=None,
    subsetGroups=None,
    clusterRows=True,
    clusterColumns=True,
    standardScale=None,
    orderRow=None,
    orderColumn=None,
    showPrevalence=False,
    cmap='vlag',
    figsize=None,
    saveDir=None,
    fileName=None,
    verbose=True,
    **kwargs,
):
    """

    Parameters:
        adata (AnnData):
            An AnnData object, or a `path` to an `.h5ad` file containing one. It should have features (markers) as variables and observations (cells) as rows.

        groupBy (str):
            The key in `adata.obs` on which to group observations, typically a clustering or phenotype label such as 'leiden' or 'phenotype'.

        layer (str, optional):
            Specifies the layer of `adata` to use for the heatmap. If None, the `.X` attribute is used. To plot the raw data, pass `'raw'` (uses `adata.raw.X`).

        subsetMarkers (list of str, optional):
            A list of marker genes or features to include in the heatmap. If None, all markers are used.

        subsetGroups (list of str, optional):
            A list of group labels to include in the heatmap. Useful for focusing on specific clusters or conditions.

        clusterRows (bool):
            Whether to hierarchically cluster the rows (groups). Must be False to use `orderRow` or `orderColumn`.

        clusterColumns (bool):
            Whether to hierarchically cluster the columns (features). Must be False to use `orderRow` or `orderColumn`.

        standardScale (str, optional):
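            Whether and how to z-score the group-by-marker mean matrix (via `StandardScaler`): 'row' standardizes each marker across the groups (rows), 'column' standardizes each group across the markers (columns), and None applies no scaling.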

        orderRow (list of str, optional):
            A custom top-to-bottom order of group labels; groups not listed are omitted. Requires `clusterRows=False` and `clusterColumns=False`.

        orderColumn (list of str, optional):
            A custom left-to-right order of feature names; features not listed are omitted. Requires `clusterRows=False` and `clusterColumns=False`.

        showPrevalence (bool):
            If True, writes the number of cells in each group (`n=...`) to the right of its row.

        cmap (str):
            The colormap for the heatmap.

        figsize (tuple of float, optional):
            The size of the figure to create. If None, the size is inferred.

        saveDir (str, optional):
            Directory in which to save the heatmap (created if it does not exist). If None, the heatmap is displayed instead of saved.

        fileName (str, optional):
            Name of the file to save the heatmap. Relevant only if `saveDir` is not None.

        verbose (bool):
            If True, print a message with the output path when the heatmap is saved.

        **kwargs:
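            Additional keyword arguments passed to matplotlib's `imshow`. If given, `vmin` and `vmax` set the limits of the color scale; otherwise the minimum and maximum of the plotted matrix are used.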

    Returns:
        None:
            The figure is displayed with `plt.show()`, or, if `saveDir` and `fileName` are provided, saved at 300 dpi to that location and then closed.

    Example:
            ```python

            # Example 1: Basic usage with clustering and standard scale by column.

            sm.pl.heatmap(adata, groupBy='leiden', standardScale='column')

            # Example 2: Subset markers and groups, set a manual row and column order, and save to file.

            subsetMarkers = ['ELANE', 'CD57', 'CD45', 'CD11B', 'SMA', 'CD16', 'ECAD']
            subsetGroups = ['0', '1', '3', '6']
            orderRow = ['6', '3', '0', '1']
            orderColumn = ['SMA', 'CD16', 'ECAD', 'ELANE', 'CD57', 'CD45', 'CD11B']
            saveDir = '/path/to/save'
            fileName = 'custom_heatmap.pdf'

            sm.pl.heatmap(adata, groupBy='leiden', subsetMarkers=subsetMarkers, subsetGroups=subsetGroups, clusterRows=False, clusterColumns=False, standardScale='column', orderRow=orderRow, orderColumn=orderColumn, showPrevalence=True, figsize=(10, 5), saveDir=saveDir, fileName=fileName, vmin=0, vmax=1)

            ```
    """

    # load adata
    if isinstance(adata, str):
        adata = ad.read_h5ad(adata)

    # check if the location is provided if the user wishes to save the image
    if (saveDir is None and fileName is not None) or (
        saveDir is not None and fileName is None
    ):
        raise ValueError(
            "Both 'saveDir' and 'fileName' must be provided together or not at all."
        )

    # subset the observations if the user requests specific groups
    subsetadata = None  # initialize subsetted data
    if subsetGroups:
        subsetGroups = (
            [subsetGroups] if isinstance(subsetGroups, str) else subsetGroups
        )  # convert to list
        subsetadata = adata[adata.obs[groupBy].isin(subsetGroups)]

    # use the same cells for the data matrix and the category labels
    # (previously, subsetGroups without subsetMarkers paired all cells with subset labels)
    source = subsetadata if subsetadata is not None else adata
    # also identify the categories to be plotted
    categories = source.obs[groupBy].values

    # initialize the markers to be plotted (all markers if no subset is requested)
    if subsetMarkers:
        subsetMarkers = (
            [subsetMarkers] if isinstance(subsetMarkers, str) else list(subsetMarkers)
        )  # convert to list
    elif layer == 'raw':
        subsetMarkers = source.raw.var_names.tolist()
    else:
        subsetMarkers = source.var_names.tolist()

    # isolate the data; .raw is indexed directly because slicing the
    # AnnData object along variables does not subset .raw
    if layer == 'raw':
        data = source.raw[:, subsetMarkers].X
    elif layer is None:
        data = source[:, subsetMarkers].X
    else:
        data = source[:, subsetMarkers].layers[layer]

    # densify sparse matrices so the NumPy operations below behave as expected
    if hasattr(data, 'toarray'):
        data = data.toarray()
    data = np.asarray(data)

    # The actual plotting function
    def plot_category_heatmap_vectorized(
        data,
        marker_names,
        categories,
        clusterRows,
        clusterColumns,
        standardScale,
        orderRow,
        orderColumn,
        showPrevalence,
        cmap,
        figsize,
        saveDir,
        fileName,
        **kwargs,
    ):
        # Validate clustering and ordering options
        if (clusterRows or clusterColumns) and (
            orderRow is not None or orderColumn is not None
        ):
            raise ValueError(
                "Cannot use clustering and manual ordering together. Please choose one or the other."
            )

        if standardScale not in [None, 'row', 'column']:
            raise ValueError("standardScale must be 'row', 'column', or None.")

        # Convert marker_names to list if it's a pandas Index
        # if isinstance(marker_names, pd.Index):
        #    marker_names = marker_names.tolist()

        # Data preprocessing
        sorted_indices = np.argsort(categories)
        data = data[sorted_indices, :]
        categories = categories[sorted_indices]
        unique_categories, category_counts = np.unique(categories, return_counts=True)

        # Compute mean values for each category
        mean_data = np.array(
            [
                np.mean(data[categories == category, :], axis=0)
                for category in unique_categories
            ]
        )

        # Apply standard scaling if specified
        if standardScale == 'row':
            scaler = StandardScaler()
            mean_data = scaler.fit_transform(mean_data)
        elif standardScale == 'column':
            scaler = StandardScaler()
            mean_data = scaler.fit_transform(mean_data.T).T

        # Apply manual ordering if specified
        # (compare against None: the truth value of a pandas Index is ambiguous)
        if orderRow is not None:
            # Ensure orderRow is a list
            orderRow = [orderRow] if isinstance(orderRow, str) else list(orderRow)
            row_order = [unique_categories.tolist().index(r) for r in orderRow]
            mean_data = mean_data[row_order, :]
            unique_categories = [unique_categories[i] for i in row_order]
            category_counts = [category_counts[i] for i in row_order]

        if orderColumn is not None:
            # Ensure orderColumn is a list
            orderColumn = (
                [orderColumn] if isinstance(orderColumn, str) else list(orderColumn)
            )
            col_order = [marker_names.index(c) for c in orderColumn]
            mean_data = mean_data[:, col_order]
            marker_names = [marker_names[i] for i in col_order]

        # Clustering
        if clusterRows:
            # Perform hierarchical clustering
            row_linkage = linkage(pdist(mean_data), method='average')
            # Reorder data according to the clustering
            row_order = dendrogram(row_linkage, no_plot=True)['leaves']
            mean_data = mean_data[row_order, :]
            unique_categories = unique_categories[row_order]
            category_counts = category_counts[row_order]

        if clusterColumns:
            # Perform hierarchical clustering
            col_linkage = linkage(pdist(mean_data.T), method='average')
            # Reorder data according to the clustering
            col_order = dendrogram(col_linkage, no_plot=True)['leaves']
            mean_data = mean_data[:, col_order]
            marker_names = [marker_names[i] for i in col_order]

        # Plotting
        # Dynamic figsize calculation
        if figsize is None:
            base_size = 0.5  # Base size for each cell in inches
            figsize_width = max(10, len(marker_names) * base_size)
            figsize_height = max(8, len(unique_categories) * base_size)
            figsize = (figsize_width, figsize_height)

        fig, ax = plt.subplots(figsize=figsize, constrained_layout=True)

        # Heatmap
        # Extract vmin and vmax from kwargs if present, else default to min and max of mean_data
        vmin = kwargs.pop('vmin', np.min(mean_data))
        vmax = kwargs.pop('vmax', np.max(mean_data))

        # Create the Normalize instance with vmin and vmax
        norm = Normalize(vmin=vmin, vmax=vmax)

        c = ax.imshow(mean_data, aspect='auto', cmap=cmap, norm=norm, **kwargs)

        # Prevalence text
        if showPrevalence:
            # Calculate text offset from the last column of the heatmap
            text_offset = (
                mean_data.shape[1] * 0.001
            )  # Small offset from the right edge of the heatmap

            for index, count in enumerate(category_counts):
                # Position text immediately to the right of the heatmap
                ax.text(
                    mean_data.shape[1] + text_offset,
                    index,
                    f"n={count}",
                    va='center',
                    ha='left',
                )

        # Setting the tick labels
        ax.set_xticks(np.arange(mean_data.shape[1]))
        ax.set_xticklabels(marker_names, rotation=90, ha="right")
        ax.set_yticks(np.arange(mean_data.shape[0]))
        ax.set_yticklabels(unique_categories)

        # Move the colorbar to the top left corner
        # cbar_ax = fig.add_axes([0.125, 0.92, 0.2, 0.02])  # x, y, width, height
        cbar_ax = ax.inset_axes([-0.5, -1.5, 4, 0.5], transform=ax.transData)
        cbar = plt.colorbar(c, cax=cbar_ax, orientation='horizontal')
        cbar_ax.xaxis.set_ticks_position('top')
        cbar_ax.xaxis.set_label_position('top')
        cbar.set_label('Mean expression in group')

        ax.set_xlabel('Markers')
        ax.set_ylabel('Categories')

        # plt.tight_layout(rect=[0, 0, 0.9, 0.9]) # Adjust the layout

        # Saving the figure if saveDir and fileName are provided
        if saveDir:
            # create the output directory if needed
            os.makedirs(saveDir, exist_ok=True)
            full_path = os.path.join(saveDir, fileName)
            fig.savefig(full_path, dpi=300)
            plt.close(fig)
            if verbose:
                print(f"Saved heatmap to {full_path}")
        else:
            plt.show()

    # call the plotting function
    plot_category_heatmap_vectorized(
        data=data,
        marker_names=subsetMarkers,
        categories=categories,
        clusterRows=clusterRows,
        clusterColumns=clusterColumns,
        standardScale=standardScale,
        orderRow=orderRow,
        orderColumn=orderColumn,
        showPrevalence=showPrevalence,
        cmap=cmap,
        figsize=figsize,
        saveDir=saveDir,
        fileName=fileName,
        **kwargs,
    )
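When clusterRows or clusterColumns is enabled, the source code orders rows and columns by the leaf order of an average-linkage dendrogram built on Euclidean distances between mean profiles (linkage(pdist(...), method='average')). The standalone sketch below isolates that step with SciPy; `cluster_order` is a hypothetical helper name, and the toy matrix is made up for illustration.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist


def cluster_order(mean_data, axis=0):
    """Hypothetical helper: leaf order of an average-linkage dendrogram.

    axis=0 orders the rows (groups); axis=1 orders the columns (markers).
    """
    m = np.asarray(mean_data, dtype=float)
    if axis == 1:
        m = m.T
    # condensed Euclidean distance matrix -> average-linkage tree
    tree = linkage(pdist(m), method='average')
    # no_plot=True computes the leaf layout without drawing anything
    return dendrogram(tree, no_plot=True)['leaves']


means = np.array([[0.0, 0.0], [10.0, 10.0], [0.1, 0.0], [10.1, 10.0]])
order = cluster_order(means)
print(order, means[order])  # similar groups (0 and 2, 1 and 3) end up adjacent
```

Since linkage needs at least two observations, clustering a single group (for example, subsetGroups with one label) would likely raise an error; disabling clusterRows avoids that case.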