
groupCorrelation

Short Description

The sm.pl.groupCorrelation function computes pairwise correlations between group abundances (cell counts per group) across the conditions in an AnnData object and displays them as a heatmap. Optional z-score normalization, hierarchical clustering, and manual ordering of rows and columns are supported.
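
To make the computation concrete, here is a minimal sketch of the counting and correlation step using pandas alone. The toy `obs` table and its values are made up for illustration; the column names `patient_id` and `cell_type` follow the examples on this page.

```python
import pandas as pd

# Hypothetical stand-in for adata.obs: one row per cell.
obs = pd.DataFrame({
    "patient_id": ["P1"] * 6 + ["P2"] * 6 + ["P3"] * 6,
    "cell_type": (
        ["B cells"] * 1 + ["T cells"] * 2 + ["Tumor"] * 3      # patient P1
        + ["B cells"] * 2 + ["T cells"] * 3 + ["Tumor"] * 1    # patient P2
        + ["B cells"] * 3 + ["T cells"] * 1 + ["Tumor"] * 2    # patient P3
    ),
})

# Rows are conditions, columns are groups, values are cell counts.
group_counts = obs.groupby(["patient_id", "cell_type"]).size().unstack(fill_value=0)

# Pairwise correlation (pandas defaults to Pearson) between group columns.
corr_matrix = group_counts.corr()
print(corr_matrix.round(2))
```

Each cell of `corr_matrix` answers the question "when one group is abundant in a condition, does the other tend to be abundant as well?". With only a handful of conditions the coefficients are noisy, so they are best read as descriptive.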

Function

groupCorrelation(adata, groupBy, condition, normalize=False, subsetGroups=None, orderRow=None, orderColumn=None, clusterRows=True, clusterColumns=True, cmap='vlag', figsize=None, overlayValues=False, fontSize=10, fontColor='black', fileName='groupCorrelation.pdf', saveDir=None, **kwargs)

Parameters:

adata (AnnData or str, required):
    An AnnData object containing the dataset, or a string path to an AnnData file to be loaded.

groupBy (str, required):
    The column in adata.obs used for defining groups.

condition (str, required):
    The column in adata.obs that distinguishes different conditions or samples.

normalize (bool, default False):
    If True, apply z-score normalization to the group counts across conditions.

subsetGroups (list of str, default None):
    A list specifying a subset of groups to include in the analysis. If None, all groups are included.

orderRow (list of str, default None):
    Custom order for the rows in the heatmap. If None, the order is determined by clustering or the original group order.

orderColumn (list of str, default None):
    Custom order for the columns in the heatmap.

clusterRows (bool, default True):
    Whether to apply hierarchical clustering to rows.

clusterColumns (bool, default True):
    Whether to apply hierarchical clustering to columns.

cmap (str, default 'vlag'):
    The colormap for the heatmap.

figsize (tuple of float, default None):
    The size of the figure to create (width, height). If None, the size is inferred.

overlayValues (bool, default False):
    If True, overlays the correlation coefficient values on the heatmap.

fontSize (int, default 10):
    Font size for overlay values.

fontColor (str, default 'black'):
    Color of the font used for overlay values.

fileName (str, default 'groupCorrelation.pdf'):
    Name of the file to save the heatmap. Relevant only if saveDir is specified.

saveDir (str, default None):
    Directory to save the generated heatmap. If None, the heatmap is not saved.
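
One point about `normalize` is worth checking before relying on it. Judging from the source listed on this page, z-scoring is applied to each group column (`axis=0`) and the columns are then correlated with pandas' default Pearson method. Pearson correlation does not change when a variable is shifted and rescaled by a positive factor, so under that reading the correlation matrix should be the same with or without `normalize=True`. The sketch below checks this on made-up counts; the table values are hypothetical.

```python
import numpy as np
import pandas as pd
from scipy.stats import zscore

# Hypothetical condition-by-group count table (rows: patients, columns: groups).
group_counts = pd.DataFrame(
    {
        "B cells": [10, 25, 40, 5],
        "T cells": [30, 20, 55, 12],
        "Tumor": [80, 15, 33, 60],
    },
    index=["P1", "P2", "P3", "P4"],
)

# Correlation of the raw counts.
raw_corr = group_counts.corr()

# Same call pattern as the source: z-score every column, then correlate.
z_counts = group_counts.apply(zscore, axis=0)
z_corr = z_counts.corr()

# Compare the two matrices element by element.
same = np.allclose(raw_corr.to_numpy(), z_corr.to_numpy())
print("Correlation unchanged by z-scoring:", same)
```

If the goal is to correct for conditions that differ in total cell number, converting each row to proportions before correlating would be a different normalization; that is an alternative suggested here, not something the function is documented to do.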

Returns:

plot (matplotlib):
    Displays a heatmap of the correlation between the specified groups, or saves it when saveDir is set. The source listed below has no return statement, so no object is handed back to the caller.

Example
import scimap as sm  # adata is assumed to be an already loaded AnnData object

# Basic usage: all groups in 'cell_type', counted per 'patient_id'
sm.pl.groupCorrelation(adata, groupBy='cell_type', condition='patient_id')

# Normalized group counts with specific groups and custom clustering disabled
sm.pl.groupCorrelation(adata, groupBy='cell_type', condition='patient_id', normalize=True,
                 subsetGroups=['B cells', 'T cells'], clusterRows=False, clusterColumns=False)

# Using custom ordering and overlaying values with specified font size and color
sm.pl.groupCorrelation(adata, groupBy='cell_type', condition='patient_id', overlayValues=True,
                 orderRow=['T cells', 'B cells'], fontSize=12, fontColor='blue',
                 saveDir='/path/to/results', fileName='customGroupCorrelation.pdf')
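
The custom ordering in the last example maps each name in `orderRow` to its position among the group labels, as the source listed on this page does with `list.index`. Two consequences follow from that reading: a name that is not among the groups raises a `ValueError`, and listing only some of the groups shows only those rows. A minimal sketch, where `resolve_order` is a hypothetical helper and not part of scimap:

```python
def resolve_order(var_names, order):
    """Map group names to their positions, with a readable error for unknown names."""
    missing = [name for name in order if name not in var_names]
    if missing:
        raise ValueError(f"Unknown group(s) in custom order: {missing}")
    return [var_names.index(name) for name in order]


# Group labels as they would appear in the columns of the count table.
var_names = ["B cells", "T cells", "Tumor"]

# Only two of the three groups are listed, so only two rows would be drawn.
row_order = resolve_order(var_names, ["T cells", "B cells"])
print(row_order)  # [1, 0]
```

Checking names up front like this gives a clearer message than the bare `ValueError` from `list.index`, which is useful when labels differ only by spelling or case.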
Source code in scimap/plotting/groupCorrelation.py
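# Imports required by the function body
import os
import warnings

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist
from scipy.stats import zscore


def groupCorrelation(adata, 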
                     groupBy, 
                     condition,
                     normalize=False,
                     subsetGroups=None, 
                     orderRow=None, 
                     orderColumn=None, 
                     clusterRows=True, 
                     clusterColumns=True, 
                     cmap='vlag', 
                     figsize=None, 
                     overlayValues=False, 
                     fontSize=10,
                     fontColor='black',
                     fileName='groupCorrelation.pdf',
                     saveDir=None,
                     **kwargs):

    """
Parameters:
    adata (AnnData or str): 
        An AnnData object containing the dataset, or a string path to an AnnData file to be loaded.

    groupBy (str): 
        The column in `adata.obs` used for defining groups.

    condition (str): 
        The column in `adata.obs` that distinguishes different conditions or samples.

    normalize (bool, optional): 
        If True, apply z-score normalization to the group counts across conditions. 

    subsetGroups (list of str, optional): 
        A list specifying a subset of groups to include in the analysis. If None, all groups are included.

    orderRow (list of str, optional): 
        Custom order for the rows in the heatmap. If None, the order is determined by clustering or the original group order.

    orderColumn (list of str, optional): 
        Custom order for the columns in the heatmap.

    clusterRows (bool, optional): 
        Whether to apply hierarchical clustering to rows. 

    clusterColumns (bool, optional): 
        Whether to apply hierarchical clustering to columns. 

    cmap (str, optional): 
        The colormap for the heatmap. 

    figsize (tuple of float, optional): 
        The size of the figure to create (width, height). If None, the size is inferred.

    overlayValues (bool, optional): 
        If True, overlays the correlation coefficient values on the heatmap. 

    fontSize (int, optional): 
        Font size for overlay values. 

    fontColor (str, optional): 
        Color of the font used for overlay values. 

    fileName (str, optional): 
        Name of the file to save the heatmap. Relevant only if `saveDir` is specified. 

    saveDir (str, optional): 
        Directory to save the generated heatmap. If None, the heatmap is not saved.

Returns:
        plot (matplotlib): 
            Displays or saves a heatmap visualizing the correlation between specified groups.

Example:
    ```python

    # Basic usage with auto-detected conditions and groups
    sm.pl.groupCorrelation(adata, groupBy='cell_type', condition='patient_id')

    # Normalized group counts with specific groups and custom clustering disabled
    sm.pl.groupCorrelation(adata, groupBy='cell_type', condition='patient_id', normalize=True,
                     subsetGroups=['B cells', 'T cells'], clusterRows=False, clusterColumns=False)

    # Using custom ordering and overlaying values with specified font size and color
    sm.pl.groupCorrelation(adata, groupBy='cell_type', condition='patient_id', overlayValues=True,
                     orderRow=['T cells', 'B cells'], fontSize=12, fontColor='blue',
                     saveDir='/path/to/results', fileName='customGroupCorrelation.pdf')
    ```

"""

    # Load adata if a path is provided
    if isinstance(adata, str):
        adata = ad.read_h5ad(adata)

    # Calculate group counts
    group_counts = adata.obs.groupby([condition, groupBy]).size().unstack(fill_value=0)

    # Subset groups if needed
    if subsetGroups:
        group_counts = group_counts[subsetGroups]

    # Normalize if requested
    if normalize:
        group_counts = group_counts.apply(zscore, axis=0)

    # Calculate correlation
    corr_matrix = group_counts.corr()

    # var_names for axis labels, directly from group_counts columns
    var_names = group_counts.columns.tolist()

    # Manual ordering takes precedence over clustering
    if orderRow and clusterRows:
        warnings.warn("Both orderRow and clusterRows were provided. Proceeding with orderRow and ignoring clusterRows.")
        clusterRows = False
    if orderColumn and clusterColumns:
        warnings.warn("Both orderColumn and clusterColumns were provided. Proceeding with orderColumn and ignoring clusterColumns.")
        clusterColumns = False

    # Apply manual ordering or clustering
    if orderRow:
        row_order = [var_names.index(name) for name in orderRow]
    else:
        row_order = range(len(var_names))  # Default order if no manual ordering
        if clusterRows:
            linkage_row = linkage(pdist(corr_matrix, 'euclidean'), method='average')
            row_order = dendrogram(linkage_row, no_plot=True)['leaves']

    if orderColumn:
        col_order = [var_names.index(name) for name in orderColumn]
    else:
        col_order = range(len(var_names))  # Default order if no manual ordering
        if clusterColumns:
            linkage_col = linkage(pdist(corr_matrix.T, 'euclidean'), method='average')
            col_order = dendrogram(linkage_col, no_plot=True)['leaves']

    # Reorder the matrix based on row_order and col_order
    corr_matrix = corr_matrix.iloc[row_order, col_order]

    # Plotting
    if figsize is None:
        figsize_width = max(10, len(corr_matrix.columns) * 0.5)
        figsize_height = max(8, len(corr_matrix.index) * 0.5)
        figsize = (figsize_width, figsize_height)

    plt.figure(figsize=figsize)
    im = plt.imshow(corr_matrix, cmap=cmap, aspect='auto', **kwargs)
    plt.colorbar(im)

    if overlayValues:
        for i in range(len(row_order)):
            for j in range(len(col_order)):
                plt.text(j, i, f"{corr_matrix.iloc[i, j]:.2f}", ha="center", va="center", color=fontColor,fontsize=fontSize)

    # Set tick labels
    plt.xticks(ticks=np.arange(len(col_order)), labels=[var_names[i] for i in col_order], rotation=90)
    plt.yticks(ticks=np.arange(len(row_order)), labels=[var_names[i] for i in row_order])

    plt.tight_layout()

    # Save or show the figure
    if saveDir and fileName:
        # Create the output directory (if needed) before writing the file
        os.makedirs(saveDir, exist_ok=True)
        full_path = os.path.join(saveDir, fileName)
        plt.savefig(full_path, dpi=300)
        print(f"Saved heatmap to {full_path}")
    else:
        plt.show()
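
The clustering branch of the source above takes the row and column order from the leaves of a dendrogram. The sketch below isolates that step with SciPy on a small hand-written correlation matrix; the labels and values are illustrative only, while the calls (`pdist` with Euclidean distance, `linkage` with `method='average'`, `dendrogram` with `no_plot=True`) mirror those in the listing.

```python
import pandas as pd
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist

labels = ["B cells", "Tumor", "T cells", "Stroma"]

# Illustrative symmetric correlation matrix: B/T cells co-vary, Tumor/Stroma co-vary.
corr_matrix = pd.DataFrame(
    [
        [1.0, -0.6, 0.9, -0.5],
        [-0.6, 1.0, -0.7, 0.8],
        [0.9, -0.7, 1.0, -0.4],
        [-0.5, 0.8, -0.4, 1.0],
    ],
    index=labels,
    columns=labels,
)

# Treat each row as a profile, cluster the profiles, and read off the leaf order.
link = linkage(pdist(corr_matrix.to_numpy(), "euclidean"), method="average")
leaves = dendrogram(link, no_plot=True)["leaves"]

# Reorder rows and columns so that similar groups sit next to each other.
ordered = corr_matrix.iloc[leaves, leaves]
print(list(ordered.index))
```

Because a correlation matrix is symmetric, clustering its rows and clustering its transpose yield the same leaf order, so with both `clusterRows` and `clusterColumns` enabled the heatmap stays symmetric about its diagonal.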