
distPlot

Short Description

The `sm.pl.distPlot` function plots the distribution of marker intensities as kernel density estimate (KDE) curves, either one subplot per marker or all markers overlaid in a single plot.
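Each curve is a kernel density estimate. The source code below draws it with pandas' `Series.plot.kde`, which, to our understanding, delegates to SciPy's `gaussian_kde` with Scott's rule for the bandwidth. The standard-library sketch below shows what such a curve computes for a single marker; the helper `gaussian_kde_1d` and the sample intensities are hypothetical and not part of scimap.

```python
import math
import statistics


def gaussian_kde_1d(values, grid):
    """Evaluate a Gaussian KDE of `values` at every point of `grid`.

    Bandwidth follows Scott's rule for 1-D data: sample std * n ** (-1/5).
    """
    n = len(values)
    bandwidth = statistics.stdev(values) * n ** (-1 / 5)
    norm = 1.0 / (n * bandwidth * math.sqrt(2 * math.pi))
    # Sum one Gaussian bump per observation at every grid point
    return [
        norm * sum(math.exp(-0.5 * ((x - v) / bandwidth) ** 2) for v in values)
        for x in grid
    ]


# Hypothetical intensities for one marker (one column of the plotted matrix)
intensities = [0.1, 0.2, 0.25, 0.3, 0.8, 0.85, 0.9]
grid = [i / 100 for i in range(-100, 201)]  # -1.0 .. 2.0 in steps of 0.01
density = gaussian_kde_1d(intensities, grid)

# A KDE is a probability density, so the area under the curve is close to 1
area = sum(density) * 0.01
print(round(area, 3))
```

The two clusters of values produce two peaks, which is what makes these plots useful for spotting a gate between negative and positive cells.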

Function

distPlot(adata, layer=None, markers=None, subset=None, imageid='imageid', vline=None, plotGrid=True, ncols=None, color=None, xticks=None, figsize=(5, 5), fontsize=None, dpi=200, saveDir=None, fileName='scimapDistPlot.png')

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `adata` | AnnData | Annotated data object. | required |
| `layer` | str | Layer of data to plot. `None` uses `adata.X`, `'raw'` uses `adata.raw.X`, and any other value uses `adata.layers[layer]`. | `None` |
| `markers` | str or list | Marker(s) to plot. If `None`, all markers are plotted. | `None` |
| `subset` | str, list or None | `imageid` value(s) of one or more images to restrict the plot to. | `None` |
| `imageid` | str | Column in `adata.obs` that contains the image ID for each cell. | `'imageid'` |
| `vline` | float or 'auto' | x-coordinate of a vertical reference line. If `None`, no line is drawn; `'auto'` draws it at the midpoint between the minimum and maximum value of the plotted data. | `None` |
| `plotGrid` | bool | Whether to plot each marker in its own subplot. If `False` and multiple markers are passed via `markers`, all distributions are overlaid in a single plot. | `True` |
| `ncols` | int | Number of columns in the subplot grid when `plotGrid=True`. If `None`, a square grid large enough for all markers is used. | `None` |
| `color` | dict | Mapping of marker names to colors; markers missing from the mapping are drawn in grey. If `None`, matplotlib's default color cycle is used. | `None` |
| `xticks` | list of float | Custom x-axis tick values. | `None` |
| `figsize` | tuple | Figure size in inches, as (width, height). | `(5, 5)` |
| `fontsize` | int | Font size of the tick labels (and of the legend when `plotGrid=False`). | `None` |
| `dpi` | int | Resolution of the displayed figure in dots per inch. Files written to `saveDir` are currently saved at 300 DPI. | `200` |
| `saveDir` | str | Directory in which to save the plot; it is created if it does not exist. | `None` |
| `fileName` | str | Name of the output file. The suffix (e.g. `.png` or `.pdf`) sets the file format. | `'scimapDistPlot.png'` |
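The `layer` and `markers` arguments decide which matrix is plotted. In the source code below, `None` reads `adata.X`, `'raw'` reads `adata.raw.X`, any other string reads `adata.layers[layer]`, and a single marker string is wrapped in a list. The sketch below mirrors that lookup on a `SimpleNamespace` stand-in so it runs without AnnData; `select_matrix`, the `fake` object and its values are hypothetical, and reading the variable names from `raw` itself when `layer='raw'` is an assumption of this sketch.

```python
from types import SimpleNamespace


def select_matrix(adata, layer=None, markers=None):
    """Return {marker: column values} following distPlot's layer rules."""
    if layer is None:
        matrix, names = adata.X, adata.var_names
    elif layer == 'raw':
        # assumption: raw keeps its own variable names
        matrix, names = adata.raw.X, adata.raw.var_names
    else:
        matrix, names = adata.layers[layer], adata.var_names
    columns = {name: [row[j] for row in matrix] for j, name in enumerate(names)}
    if markers is not None:
        if isinstance(markers, str):  # a single marker may be passed as a string
            markers = [markers]
        columns = {m: columns[m] for m in markers}
    return columns


# Hypothetical two-cell, two-marker object standing in for an AnnData
fake = SimpleNamespace(
    X=[[0.1, 0.5], [0.2, 0.7]],
    var_names=['CD45', 'CD3D'],
    raw=SimpleNamespace(X=[[10, 50], [20, 70]], var_names=['CD45', 'CD3D']),
    layers={'log': [[1.0, 1.7], [1.3, 1.9]]},
)

print(select_matrix(fake, markers='CD3D'))                 # {'CD3D': [0.5, 0.7]}
print(select_matrix(fake, layer='raw', markers=['CD45']))  # {'CD45': [10, 20]}
```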

Returns:

| Name | Type | Description |
|---|---|---|
| Plot | image | If `saveDir` is provided, the plot is saved as `fileName` inside `saveDir`; otherwise it is displayed. |
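When `saveDir` is given, the grid branch of the source code below creates the directory if it is missing, joins it with `fileName`, and writes the figure at a fixed 300 DPI, with the file suffix choosing the format. The standard-library sketch below isolates that path handling and writes a placeholder file instead of calling `plt.savefig`, so it runs without matplotlib; `resolve_output_path` is a hypothetical helper, and `exist_ok=True` stands in for the explicit existence check in the source.

```python
import os
import tempfile


def resolve_output_path(saveDir, fileName='scimapDistPlot.png'):
    """Create saveDir (and any missing parents) and return the output file path."""
    os.makedirs(saveDir, exist_ok=True)
    return os.path.join(saveDir, fileName)


with tempfile.TemporaryDirectory() as tmp:
    target_dir = os.path.join(tmp, 'plots', 'distributions')  # does not exist yet
    path = resolve_output_path(target_dir, 'cd45.pdf')
    # Stand-in for plt.savefig(path, dpi=300); matplotlib infers the format from '.pdf'
    with open(path, 'wb') as fh:
        fh.write(b'placeholder')
    dir_created = os.path.isdir(target_dir)
    file_written = os.path.isfile(path)

print(os.path.basename(path), dir_created, file_written)
```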

Example
import scimap as sm

sm.pl.distPlot(adata,
               layer=None,
               markers=['CD45', 'CD3D', 'CD20'],
               plotGrid=True,
               ncols=5)
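With `plotGrid=True`, the source code below lays markers out row by row: marker `i` goes to row `i // num_cols` and column `i % num_cols`, and any unused axes are deleted. For the example above (three markers, `ncols=5`) that yields a 1 × 5 grid with two empty axes removed. The standard-library sketch below reproduces that arithmetic; `grid_layout` is a hypothetical helper modeled on the nested `calculate_grid_dimensions`.

```python
import math


def grid_layout(markers, ncols=None):
    """Return (num_rows, num_cols, placements, empty_slots) for a marker grid."""
    n = len(markers)
    if ncols is None:
        # no ncols given: smallest square grid that fits every marker
        num_cols = math.ceil(math.sqrt(n))
        num_rows = num_cols
    else:
        num_cols = ncols
        num_rows = math.ceil(n / ncols)
    placements = {m: (i // num_cols, i % num_cols) for i, m in enumerate(markers)}
    # slots past the last marker are the axes that get deleted
    empty_slots = [(i // num_cols, i % num_cols) for i in range(n, num_rows * num_cols)]
    return num_rows, num_cols, placements, empty_slots


rows, cols, where, empty = grid_layout(['CD45', 'CD3D', 'CD20'], ncols=5)
print(rows, cols, where, empty)
# 1 5 {'CD45': (0, 0), 'CD3D': (0, 1), 'CD20': (0, 2)} [(0, 3), (0, 4)]

print(grid_layout(['CD45', 'CD3D', 'CD20'])[:2])
# (2, 2): without ncols, one of the four axes stays empty and is removed
```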
Source code in scimap/plotting/distPlot.py
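import itertools
import math
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def distPlot(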
    adata,
    layer=None,
    markers=None,
    subset=None,
    imageid='imageid',
    vline=None,
    plotGrid=True,
    ncols=None,
    color=None,
    xticks=None,
    figsize=(5, 5),
    fontsize=None,
    dpi=200,
    saveDir=None,
    fileName='scimapDistPlot.png',
):
    """
    Parameters:
        adata (anndata.AnnData):
            Annotated data object.

        layer (str, optional):
            Layer of data to plot.

        markers (list, optional):
            List of marker genes to plot.

        subset (list or None, optional):
            `imageid` of a single or multiple images to be subsetted for plotting purposes.

        imageid (str, optional):
            The column name in `spatial feature table` that contains the image ID
            for each cell.

        vline (float or 'auto', optional):
            The x-coordinate of the vertical line to plot. If set to `None`, a vertical line is not plotted.
            Use 'auto' to draw a vline at the center point.

        plotGrid (bool, optional):
            Whether to plot each marker in it's own sub plot. If `False` and multiple markers
            are passed in via `markers`, all distributions will be plotted within a single plot.

        ncols (int, optional):
            The number of columns in the final plot when multiple variables are plotted.

        color (str, optional):
            Color of the distribution plot.

        xticks (list of float, optional):
            Custom x-axis tick values.

        figsize (tuple, optional):
            Figure size. Defaults to (5, 5).

        fontsize (int, optional):
            The size of the font of the axis labels.

        dpi (int, optional):
            The DPI of the figure. Use this to control the point size. Lower the dpi, larger the point size.

        saveDir (str, optional):
            The directory to save the output plot.

        fileName (str, optional):
            The name of the output file. Use desired file format as suffix (e.g. `.png` or `.pdf`).

    Returns:
        Plot (image):
            If `outputDir` is provided the plot will saved within the provided outputDir.

    Example:
            ```python

            sm.pl.distPlot(adata,
                         layer=None,
                         markers=['CD45','CD3D','CD20'],
                         plotGrid=True,
                         ncols=5)
            ```

    """

    # subset data if needed
    if subset is not None:
        if isinstance(subset, str):
            subset = [subset]
        # subsetting cells also subsets adata.raw, so one path serves every layer
        bdata = adata[adata.obs[imageid].isin(subset)].copy()
    else:
        # the data are only read below, so no copy is needed
        bdata = adata

    # isolate the data; raw keeps its own variable names, and sparse
    # matrices are densified so pandas can wrap them
    if layer is None:
        matrix, var_names = bdata.X, bdata.var.index
    elif layer == 'raw':
        matrix, var_names = bdata.raw.X, bdata.raw.var_names
    else:
        matrix, var_names = bdata.layers[layer], bdata.var.index
    if hasattr(matrix, 'toarray'):
        matrix = matrix.toarray()
    data = pd.DataFrame(matrix, index=bdata.obs.index, columns=var_names)

    # keep only columns that are required
    if markers is not None:
        if isinstance(markers, str):
            markers = [markers]
        # subset the list
        data = data[markers]

    # auto identify rows and columns in the grid plot
    def calculate_grid_dimensions(num_items, num_columns=None):
        """
        Calculates the number of rows and columns for a square grid
        based on the number of items.
        """
        if num_columns is None:
            num_rows_columns = int(math.ceil(math.sqrt(num_items)))
            return num_rows_columns, num_rows_columns
        else:
            num_rows = int(math.ceil(num_items / num_columns))
            return num_rows, num_columns

    if plotGrid is False:
        # Create a figure and axis object
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
        # a single color string is applied to every marker
        if isinstance(color, str):
            color = {column: color for column in data.columns}
        # Plot one KDE per marker; with a color dict, missing markers are grey
        for column in data.columns:
            if color is None:
                data[column].plot.kde(ax=ax, label=column)
            else:
                data[column].plot.kde(
                    ax=ax, label=column, color=color.get(column, 'grey')
                )
        ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), fontsize=fontsize)
        ax.tick_params(axis='both', which='major', width=1, labelsize=fontsize)
        if xticks is not None:
            ax.set_xticks(xticks)
            ax.set_xticklabels([str(x) for x in xticks])

        if vline == 'auto':
            # midpoint of the value range across all plotted markers
            ax.axvline((data.max().max() + data.min().min()) / 2, color='black')
        elif vline is not None:
            ax.axvline(vline, color='black')

        plt.tight_layout()

        # Save the figure to a file, or show it
        if saveDir:
            os.makedirs(saveDir, exist_ok=True)
            full_path = os.path.join(saveDir, fileName)
            plt.savefig(full_path, dpi=300)
            plt.close()
            print(f"Saved distribution plot to {full_path}")
        else:
            plt.show()

    else:
        # calculate the number of rows and columns
        num_rows, num_cols = calculate_grid_dimensions(
            len(data.columns), num_columns=ncols
        )

        # set colors
        if color is None:
            # Cycle through matplotlib's default color cycle, one color per marker
            color_cycle = itertools.cycle(
                plt.rcParams['axes.prop_cycle'].by_key()['color']
            )
            color = {col: next(color_cycle) for col in data.columns}
        elif isinstance(color, str):
            # a single color string is applied to every marker
            color = {col: color for col in data.columns}

        # Create the figure; squeeze=False always returns a 2-D array of axes,
        # so axes[row, col] also works for single-row or single-column grids
        fig, axes = plt.subplots(
            nrows=num_rows, ncols=num_cols, figsize=figsize, dpi=dpi, squeeze=False
        )
        # Set the spacing between subplots
        # fig.subplots_adjust(bottom=0.1, hspace=0.1)

        # Loop through each marker and plot its KDE in the corresponding subplot,
        # using its assigned color (grey for markers missing from a user-supplied dict)
        for i, column in enumerate(data.columns):
            c = color.get(column, 'grey')
            row_idx = i // num_cols
            col_idx = i % num_cols
            data[column].plot.kde(ax=axes[row_idx, col_idx], label=column, color=c)
            axes[row_idx, col_idx].set_title(column)
            axes[row_idx, col_idx].tick_params(
                axis='both', which='major', width=1, labelsize=fontsize
            )
            axes[row_idx, col_idx].set_ylabel('')

            if vline == 'auto':
                axes[row_idx, col_idx].axvline(
                    (data[column].max() + data[column].min()) / 2, color='black'
                )
            elif vline is None:
                pass
            else:
                axes[row_idx, col_idx].axvline(vline, color='black')

            if xticks is not None:
                axes[row_idx, col_idx].set_xticks(xticks)
                axes[row_idx, col_idx].set_xticklabels([str(x) for x in xticks])

        # Remove any empty subplots
        num_plots = len(data.columns)
        for i in range(num_plots, num_rows * num_cols):
            row_idx = i // num_cols
            col_idx = i % num_cols
            fig.delaxes(axes[row_idx, col_idx])

        # tick label sizes were already set per subplot above
        plt.tight_layout()

        # Save the figure to a file
        if saveDir:
            if not os.path.exists(saveDir):
                os.makedirs(saveDir)
            full_path = os.path.join(saveDir, fileName)
            plt.savefig(full_path, dpi=300)
            plt.close()
            print(f"Saved heatmap to {full_path}")
        else:
            plt.show()
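When `color` is `None`, the grid branch above gives each marker the next entry of matplotlib's default property cycle (`plt.rcParams['axes.prop_cycle']`), wrapping around once the cycle is exhausted; with a dict, markers missing from the mapping fall back to `'grey'`. `vline='auto'` is placed at `(max + min) / 2` of each marker's values. The standard-library sketch below reproduces both rules; the ten hex codes are assumed to match matplotlib's default `tab10` cycle, and `assign_colors` and `auto_vline` are hypothetical helpers.

```python
import itertools

# Assumed copy of matplotlib's default color cycle (tab10)
DEFAULT_CYCLE = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                 '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']


def assign_colors(markers, color=None):
    """Map every marker to a color the way the grid branch does."""
    if color is None:
        cycle = itertools.cycle(DEFAULT_CYCLE)  # wraps around after 10 markers
        color = {m: next(cycle) for m in markers}
    return {m: color.get(m, 'grey') for m in markers}


def auto_vline(values):
    """x-position used by vline='auto': midpoint of the value range."""
    return (max(values) + min(values)) / 2


markers = [f'M{i}' for i in range(12)]
colors = assign_colors(markers)
print(colors['M0'], colors['M10'])  # the 11th marker reuses the first color
print(assign_colors(['CD45', 'CD3D'], {'CD45': 'red'}))  # {'CD45': 'red', 'CD3D': 'grey'}
print(auto_vline([0.5, 2.0, 6.5]))  # 3.5
```

Because the cycle repeats, plots with more than ten markers will show repeated colors unless a `color` dict is supplied.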