
NodeStats Modules

Perform Crossing Region Processing and Analysis.

ImageDict

Bases: TypedDict

Dictionary containing the image information.

Source code in topostats\tracing\nodestats.py
class ImageDict(TypedDict):
    """Dictionary containing the image information."""

    nodes: dict[str, dict[str, npt.NDArray[np.int32]]]
    grain: dict[str, npt.NDArray[np.int32] | dict[str, npt.NDArray[np.int32]]]
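
As a minimal sketch of how this structure might be populated, the snippet below rebuilds the `ImageDict` type from the listing above and fills it for a tiny grain. The 5x5 arrays are hypothetical and exist only for illustration. The three `grain` keys mirror those set in `nodeStats.__init__`, and `nodes` starts empty, as it does there.

```python
from __future__ import annotations

from typing import TypedDict

import numpy as np
import numpy.typing as npt


class ImageDict(TypedDict):
    """Dictionary containing the image information (copied from the listing above)."""

    nodes: dict[str, dict[str, npt.NDArray[np.int32]]]
    grain: dict[str, npt.NDArray[np.int32] | dict[str, npt.NDArray[np.int32]]]


# Hypothetical 5x5 grain: a 3x3 mask with a horizontal 3-pixel skeleton through its middle.
grain_image = np.zeros((5, 5), dtype=np.int32)
grain_mask = np.zeros((5, 5), dtype=np.int32)
grain_mask[1:4, 1:4] = 1
grain_skeleton = np.zeros((5, 5), dtype=np.int32)
grain_skeleton[2, 1:4] = 1

image_dict: ImageDict = {
    "nodes": {},  # filled per node later in the workflow
    "grain": {
        "grain_image": grain_image,
        "grain_mask": grain_mask,
        "grain_skeleton": grain_skeleton,
    },
}
```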

MatchedBranch

Bases: TypedDict

Dictionary containing the matched branches.

matched_branches: dict[int, dict[str, npt.NDArray[np.number]]]
    Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys:
    - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
    - "heights" : npt.NDArray[np.number]. Heights of the branch coordinates.
    - "distances" : npt.NDArray[np.number]. Distances of the branch coordinates.
    - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branch.
    - "angles" : np.float64. The initial direction angle of the branch, added in later steps.

Source code in topostats\tracing\nodestats.py
class MatchedBranch(TypedDict):
    """
    Dictionary containing the matched branches.

    matched_branches: dict[int, dict[str, npt.NDArray[np.number]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
        - "heights" : npt.NDArray[np.number]. Heights of the branch coordinates.
        - "distances" : npt.NDArray[np.number]. Distances of the branch coordinates.
        - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branch.
        - "angles" : np.float64. The initial direction angle of the branch, added in later steps.
    """

    ordered_coords: npt.NDArray[np.int32]
    heights: npt.NDArray[np.number]
    distances: npt.NDArray[np.number]
    fwhm: dict[str, np.float64 | tuple[np.float64]]
    angles: np.float64 | None
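
The following is an illustrative sketch of a single `MatchedBranch` entry, assuming a short hypothetical three-pixel branch. The cumulative path length used for `distances` and the `"fwhm"` key inside the `fwhm` dictionary are assumptions made for the example, not confirmed library behaviour. `angles` starts as `None` because, per the docstring, it is added in later steps.

```python
from __future__ import annotations

from typing import TypedDict

import numpy as np
import numpy.typing as npt


class MatchedBranch(TypedDict):
    """Field layout copied from the listing above."""

    ordered_coords: npt.NDArray[np.int32]
    heights: npt.NDArray[np.number]
    distances: npt.NDArray[np.number]
    fwhm: dict[str, np.float64 | tuple[np.float64]]
    angles: np.float64 | None


# Hypothetical branch: three pixels in a straight horizontal line.
coords = np.array([[2, 0], [2, 1], [2, 2]], dtype=np.int32)
heights = np.array([1.0, 2.0, 1.5])

# Assumption: distances are taken as the cumulative path length (in pixels) along the branch.
steps = np.linalg.norm(np.diff(coords, axis=0), axis=1)
distances = np.concatenate(([0.0], np.cumsum(steps)))

branch: MatchedBranch = {
    "ordered_coords": coords,
    "heights": heights,
    "distances": distances,
    "fwhm": {"fwhm": np.float64(1.2)},  # key name and value are placeholders
    "angles": None,  # not known yet
}

# A later step fills in the initial direction angle.
branch["angles"] = np.float64(90.0)
```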

NodeDict

Bases: TypedDict

Dictionary containing the node information.

Source code in topostats\tracing\nodestats.py
class NodeDict(TypedDict):
    """Dictionary containing the node information."""

    error: bool
    pixel_to_nm_scaling: np.float64
    branch_stats: dict[int, MatchedBranch] | None
    node_coords: npt.NDArray[np.int32] | None
    confidence: np.float64 | None
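
Because several `NodeDict` fields may be `None`, downstream code generally needs to guard against missing values. The sketch below uses hypothetical node names and values, and simplifies `branch_stats` to a plain `dict` so the example stays self-contained. It shows one way to keep only nodes that carry a confidence and average over them.

```python
from __future__ import annotations

from typing import TypedDict

import numpy as np
import numpy.typing as npt


class NodeDict(TypedDict):
    """Field layout copied from the listing above, with branch_stats simplified."""

    error: bool
    pixel_to_nm_scaling: np.float64
    branch_stats: dict[int, dict] | None  # dict[int, MatchedBranch] in the source
    node_coords: npt.NDArray[np.int32] | None
    confidence: np.float64 | None


# Hypothetical results for two nodes; the key names and values are illustrative only.
node_dicts: dict[str, NodeDict] = {
    "node_1": {
        "error": False,
        "pixel_to_nm_scaling": np.float64(0.5),
        "branch_stats": {},
        "node_coords": np.array([[2, 2]], dtype=np.int32),
        "confidence": np.float64(0.8),
    },
    "node_2": {
        "error": True,
        "pixel_to_nm_scaling": np.float64(0.5),
        "branch_stats": None,
        "node_coords": None,
        "confidence": None,
    },
}

# Keep nodes that did not error and that have a confidence value.
usable = {
    name: node
    for name, node in node_dicts.items()
    if not node["error"] and node["confidence"] is not None
}
mean_confidence = float(np.mean([node["confidence"] for node in usable.values()]))
```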

nodeStats

Class containing methods to find and analyse the nodes/crossings within a grain.
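
The docstrings in the source describe a labelled skeleton in which skeleton pixels are 1, endpoints are 2 and crossing points are 3. As a rough illustration of how such a labelling can be produced, the sketch below counts the 8-connected neighbours of each skeleton pixel with NumPy. It is not the library's `convolve_skeleton`; the function name `label_skeleton_pixels` and the neighbour thresholds are assumptions made for this example.

```python
import numpy as np


def label_skeleton_pixels(skeleton: np.ndarray) -> np.ndarray:
    """Label skeleton pixels as 1 (branch), 2 (endpoint) or 3 (crossing) by neighbour count."""
    skel = (skeleton > 0).astype(np.int32)
    height, width = skel.shape
    padded = np.pad(skel, 1)  # zero border so edge pixels have a full 3x3 neighbourhood

    # Sum the eight shifted copies of the skeleton to count neighbours per pixel.
    neighbours = sum(
        padded[1 + dy : 1 + dy + height, 1 + dx : 1 + dx + width]
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )

    labelled = skel.copy()
    labelled[(skel == 1) & (neighbours == 1)] = 2  # one neighbour: endpoint
    labelled[(skel == 1) & (neighbours > 2)] = 3  # more than two neighbours: crossing region
    return labelled


# Hypothetical plus-shaped skeleton: one horizontal and one vertical line crossing at (2, 2).
plus = np.zeros((5, 5), dtype=np.int32)
plus[2, :] = 1
plus[:, 2] = 1
labelled_plus = label_skeleton_pixels(plus)
```

With 8-connectivity the four pixels adjacent to the centre also have more than two neighbours, so the crossing appears as a small region rather than a single pixel. This is one reason the class later reduces each node region to a single centre.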

Parameters:

Name Type Description Default
filename str

The name of the file being processed. For logging purposes.

required
image NDArray

The array of pixels.

required
mask NDArray

The binary segmentation mask.

required
smoothed_mask NDArray

A smoothed version of the binary segmentation mask.

required
skeleton NDArray

A binary single-pixel wide mask of objects in the 'image'.

required
pixel_to_nm_scaling float64

The pixel to nm scaling factor.

required
n_grain int

The grain number.

required
node_joining_length float

The length over which to join skeletal intersections to be counted as one crossing.

required
node_extend_dist float

The distance under which to join odd-branched node regions.

required
branch_pairing_length float

The length from the crossing point to pair and trace, obtaining FWHMs.

required
pair_odd_branches bool

Whether to try and pair odd-branched nodes.

required
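
The length parameters above are divided by `pixel_to_nm_scaling` in the source before being compared with pixel counts (for example `node_extend_dist / self.pixel_to_nm_scaling`). A minimal sketch of that conversion follows, assuming the scaling factor is expressed in nanometres per pixel; the helper name `nm_to_pixels` is hypothetical and does not exist in the library.

```python
def nm_to_pixels(length_nm: float, pixel_to_nm_scaling: float) -> float:
    """Convert a length in nanometres to pixels, given a scaling in nm per pixel."""
    if pixel_to_nm_scaling <= 0:
        raise ValueError("pixel_to_nm_scaling must be positive")
    return length_nm / pixel_to_nm_scaling


# Example: a 20 nm branch pairing length on an image sampled at 0.5 nm per pixel.
branch_pairing_length_px = nm_to_pixels(20.0, 0.5)
```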
Source code in topostats\tracing\nodestats.py
class nodeStats:
    """
    Class containing methods to find and analyse the nodes/crossings within a grain.

    Parameters
    ----------
    filename : str
        The name of the file being processed. For logging purposes.
    image : npt.NDArray
        The array of pixels.
    mask : npt.NDArray
        The binary segmentation mask.
    smoothed_mask : npt.NDArray
        A smoothed version of the binary segmentation mask.
    skeleton : npt.NDArray
        A binary single-pixel wide mask of objects in the 'image'.
    pixel_to_nm_scaling : np.float64
        The pixel to nm scaling factor.
    n_grain : int
        The grain number.
    node_joining_length : float
        The length over which to join skeletal intersections to be counted as one crossing.
    node_extend_dist : float
        The distance under which to join odd-branched node regions.
    branch_pairing_length : float
        The length from the crossing point to pair and trace, obtaining FWHMs.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    """

    def __init__(
        self,
        filename: str,
        image: npt.NDArray,
        mask: npt.NDArray,
        smoothed_mask: npt.NDArray,
        skeleton: npt.NDArray,
        pixel_to_nm_scaling: np.float64,
        n_grain: int,
        node_joining_length: float,
        node_extend_dist: float,
        branch_pairing_length: float,
        pair_odd_branches: bool,
    ) -> None:
        """
        Initialise the nodeStats class.

        Parameters
        ----------
        filename : str
            The name of the file being processed. For logging purposes.
        image : npt.NDArray
            The array of pixels.
        mask : npt.NDArray
            The binary segmentation mask.
        smoothed_mask : npt.NDArray
            A smoothed version of the binary segmentation mask.
        skeleton : npt.NDArray
            A binary single-pixel wide mask of objects in the 'image'.
        pixel_to_nm_scaling : np.float64
            The pixel to nm scaling factor.
        n_grain : int
            The grain number.
        node_joining_length : float
            The length over which to join skeletal intersections to be counted as one crossing.
        node_extend_dist : float
            The distance under which to join odd-branched node regions.
        branch_pairing_length : float
            The length from the crossing point to pair and trace, obtaining FWHMs.
        pair_odd_branches : bool
            Whether to try and pair odd-branched nodes.
        """
        self.filename = filename
        self.image = image
        self.mask = mask
        self.smoothed_mask = smoothed_mask  # only used to average traces
        self.skeleton = skeleton
        self.pixel_to_nm_scaling = pixel_to_nm_scaling
        self.n_grain = n_grain
        self.node_joining_length = node_joining_length
        self.node_extend_dist = node_extend_dist / self.pixel_to_nm_scaling
        self.branch_pairing_length = branch_pairing_length
        self.pair_odd_branches = pair_odd_branches

        self.conv_skelly = np.zeros_like(self.skeleton)
        self.connected_nodes = np.zeros_like(self.skeleton)
        self.all_connected_nodes = np.zeros_like(self.skeleton)
        self.whole_skel_graph: nx.classes.graph.Graph | None = None
        self.node_centre_mask = np.zeros_like(self.skeleton)

        self.metrics = {
            "num_crossings": np.int64(0),
            "avg_crossing_confidence": None,
            "min_crossing_confidence": None,
        }

        self.node_dicts: dict[str, NodeDict] = {}
        self.image_dict: ImageDict = {
            "nodes": {},
            "grain": {
                "grain_image": self.image,
                "grain_mask": self.mask,
                "grain_skeleton": self.skeleton,
            },
        }

        self.full_dict = {}
        self.mol_coords = {}
        self.visuals = {}
        self.all_visuals_img = None

    def get_node_stats(self) -> tuple[dict, dict]:
        """
        Run the workflow to obtain the node statistics.

        .. code-block:: RST

            node_dict key structure:  <grain_number>
                                        └-> <node_number>
                                            |-> 'error'
                                            |-> 'node_coords'
                                            └-> 'branch_stats'
                                                └-> <branch_number>
                                                    |-> 'ordered_coords'
                                                    |-> 'heights'
                                                    |-> 'gaussian_fit'
                                                    |-> 'fwhm'
                                                    └-> 'angles'

            image_dict key structure:  'nodes'
                                            <node_number>
                                                |-> 'node_area_skeleton'
                                                |-> 'node_branch_mask'
                                                └-> 'node_avg_mask'
                                        'grain'
                                            |-> 'grain_image'
                                            |-> 'grain_mask'
                                            └-> 'grain_skeleton'

        Returns
        -------
        tuple[dict, dict]
            Dictionaries of the node_information and images.
        """
        LOGGER.debug(f"Node Stats - Processing Grain: {self.n_grain}")
        self.conv_skelly = convolve_skeleton(self.skeleton)
        if len(self.conv_skelly[self.conv_skelly == 3]) != 0:  # check if any nodes
            LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} contains crossings.")
            # convolve to see crossing and end points
            # self.conv_skelly = self.tidy_branches(self.conv_skelly, self.image)
            # reset skeleton var as tidy branches may have modified it
            self.skeleton = np.where(self.conv_skelly != 0, 1, 0)
            self.image_dict["grain"]["grain_skeleton"] = self.skeleton
            # get graph of skeleton
            self.whole_skel_graph = self.skeleton_image_to_graph(self.skeleton)
            # connect the close nodes
            LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} connecting close nodes.")
            self.connected_nodes = self.connect_close_nodes(self.conv_skelly, node_width=self.node_joining_length)
            # connect the odd-branch nodes
            self.connected_nodes = self.connect_extended_nodes_nearest(
                self.connected_nodes, node_extend_dist=self.node_extend_dist
            )
            # obtain a mask of node centers and their count
            self.node_centre_mask = self.highlight_node_centres(self.connected_nodes)
            # Begin the hefty crossing analysis
            LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} analysing found crossings.")
            self.analyse_nodes(max_branch_length=self.branch_pairing_length)
            self.compile_metrics()
        else:
            LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} has no crossings.")
        return self.node_dicts, self.image_dict
        # self.all_visuals_img = dnaTrace.concat_images_in_dict(self.image.shape, self.visuals)

    @staticmethod
    def skeleton_image_to_graph(skeleton: npt.NDArray) -> nx.classes.graph.Graph:
        """
        Convert a skeletonised mask into a Graph representation.

        Graphs conserve the coordinates via the node label.

        Parameters
        ----------
        skeleton : npt.NDArray
            A binary single-pixel wide mask, or result from conv_skelly().

        Returns
        -------
        nx.classes.graph.Graph
            A networkX graph connecting the pixels in the skeleton to their neighbours.
        """
        skeImPos = np.argwhere(skeleton).T
        g = nx.Graph()
        neigh = np.array([[0, 1], [0, -1], [1, 0], [-1, 0], [1, 1], [1, -1], [-1, 1], [-1, -1]])

        for idx in range(skeImPos[0].shape[0]):
            for neighIdx in range(neigh.shape[0]):
                curNeighPos = skeImPos[:, idx] + neigh[neighIdx]
                if np.any(curNeighPos < 0) or np.any(curNeighPos >= skeleton.shape):
                    continue
                if skeleton[curNeighPos[0], curNeighPos[1]] > 0:
                    idx_coord = skeImPos[0, idx], skeImPos[1, idx]
                    curNeigh_coord = curNeighPos[0], curNeighPos[1]
                    # assign lower weight to nodes if not a binary image
                    if skeleton[idx_coord] == 3 and skeleton[curNeigh_coord] == 3:
                        weight = 0
                    else:
                        weight = 1
                    g.add_edge(idx_coord, curNeigh_coord, weight=weight)
        g.graph["physicalPos"] = skeImPos.T
        return g

    @staticmethod
    def graph_to_skeleton_image(g: nx.Graph, im_shape: tuple[int]) -> npt.NDArray:
        """
        Convert the skeleton graph back to a binary image.

        Parameters
        ----------
        g : nx.Graph
            Graph with coordinates as node labels.
        im_shape : tuple[int]
            The shape of the image to dump.

        Returns
        -------
        npt.NDArray
            Skeleton binary image from the graph representation.
        """
        im = np.zeros(im_shape)
        for node in g:
            im[node] = 1

        return im

    def tidy_branches(self, connect_node_mask: npt.NDArray, image: npt.NDArray) -> npt.NDArray:
        """
        Wrangle distant connected nodes back towards the main cluster.

        Works by filling and reskeletonising solely the node areas.

        Parameters
        ----------
        connect_node_mask : npt.NDArray
            The connected node mask - a skeleton where node regions = 3, endpoints = 2, and skeleton = 1.
        image : npt.NDArray
            The intensity image.

        Returns
        -------
        npt.NDArray
            The wrangled connected_node_mask.
        """
        new_skeleton = np.where(connect_node_mask != 0, 1, 0)
        labeled_nodes = label(np.where(connect_node_mask == 3, 1, 0))
        for node_num in range(1, labeled_nodes.max() + 1):
            solo_node = np.where(labeled_nodes == node_num, 1, 0)
            coords = np.argwhere(solo_node == 1)
            node_centre = coords.mean(axis=0).astype(np.int32)
            node_wid = coords[:, 0].max() - coords[:, 0].min() + 2  # +2 so always 2 by default
            node_len = coords[:, 1].max() - coords[:, 1].min() + 2  # +2 so always 2 by default
            overflow = int(10 / self.pixel_to_nm_scaling) if int(10 / self.pixel_to_nm_scaling) != 0 else 1
            # grain mask fill
            new_skeleton[
                node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow,
                node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow,
            ] = self.mask[
                node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow,
                node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow,
            ]
        # remove any artifacts of the grain caught in the overflow areas
        new_skeleton = self.keep_biggest_object(new_skeleton)
        # Re-skeletonise
        new_skeleton = getSkeleton(image, new_skeleton, method="topostats", height_bias=0.6).get_skeleton()
        # new_skeleton = pruneSkeleton(image, new_skeleton).prune_skeleton(
        #     {"method": "topostats", "max_length": -1}
        # )
        new_skeleton = prune_skeleton(
            image, new_skeleton, self.pixel_to_nm_scaling, **{"method": "topostats", "max_length": -1}
        )
        # cleanup around nibs
        new_skeleton = getSkeleton(image, new_skeleton, method="zhang").get_skeleton()
        # might also need to remove segments that have squares connected

        return convolve_skeleton(new_skeleton)

    @staticmethod
    def keep_biggest_object(mask: npt.NDArray) -> npt.NDArray:
        """
        Retain the largest object in a binary mask.

        Parameters
        ----------
        mask : npt.NDArray
            Binary mask.

        Returns
        -------
        npt.NDArray
            A binary mask with only one object.
        """
        labelled_mask = label(mask)
        # count pixels per label (not per binary value) so the largest object can be identified
        idxs, counts = np.unique(labelled_mask, return_counts=True)
        try:
            max_idx = idxs[np.argmax(counts[1:]) + 1]
            return np.where(labelled_mask == max_idx, 1, 0)
        except ValueError as e:
            LOGGER.debug(f"{e}: mask is empty.")
            return mask

    def connect_close_nodes(self, conv_skelly: npt.NDArray, node_width: float = 2.85) -> npt.NDArray:
        """
        Connect nodes within the 'node_width' boundary distance.

        This labels them as part of the same node.

        Parameters
        ----------
        conv_skelly : npt.NDArray
            A labeled skeleton image with skeleton = 1, endpoints = 2, crossing points = 3.
        node_width : float
            The width of the DNA in the grain, used to connect close nodes.

        Returns
        -------
        npt.NDArray
            The skeleton (label=1) with close nodes connected (label=3).
        """
        self.connected_nodes = conv_skelly.copy()
        nodeless = conv_skelly.copy()
        nodeless[(nodeless == 3) | (nodeless == 2)] = 0  # remove node & termini points
        nodeless_labels = label(nodeless)
        for i in range(1, nodeless_labels.max() + 1):
            if nodeless[nodeless_labels == i].size < (node_width / self.pixel_to_nm_scaling):
                # maybe also need to select based on height? and also ensure small branches classified
                self.connected_nodes[nodeless_labels == i] = 3

        return self.connected_nodes

    def highlight_node_centres(self, mask: npt.NDArray) -> npt.NDArray:
        """
        Calculate the node centres based on height and re-plot on the mask.

        Parameters
        ----------
        mask : npt.NDArray
            2D array with background = 0, skeleton = 1, endpoints = 2, node_centres = 3.

        Returns
        -------
        npt.NDArray
            2D array with the highest node coordinate for each node labeled as 3.
        """
        small_node_mask = mask.copy()
        small_node_mask[mask == 3] = 1  # remap nodes to skeleton
        big_nodes = mask.copy()
        big_nodes = np.where(mask == 3, 1, 0)  # remove non-nodes & set nodes to 1
        big_node_mask = label(big_nodes)

        for i in np.delete(np.unique(big_node_mask), 0):  # get node indices
            centre = np.unravel_index((self.image * (big_node_mask == i).astype(int)).argmax(), self.image.shape)
            small_node_mask[centre] = 3

        return small_node_mask

    def connect_extended_nodes_nearest(
        self, connected_nodes: npt.NDArray, node_extend_dist: float = -1
    ) -> npt.NDArray[np.int32]:
        """
        Extend the odd-branched nodes to other odd-branched nodes within the 'node_extend_dist' threshold.

        Parameters
        ----------
        connected_nodes : npt.NDArray
            A 2D array representing the network with background = 0, skeleton = 1, endpoints = 2,
            node_centres = 3.
        node_extend_dist : int | float, optional
            The distance over which to connect odd-branched nodes, by default -1 for no-limit.

        Returns
        -------
        npt.NDArray[np.int32]
            Connected nodes array with odd-branched nodes connected.
        """
        just_nodes = np.where(connected_nodes == 3, 1, 0)  # remove branches & termini points
        labelled_nodes = label(just_nodes)

        just_branches = np.where(connected_nodes == 1, 1, 0)  # remove node & termini points
        just_branches[connected_nodes == 1] = labelled_nodes.max() + 1
        labelled_branches = label(just_branches)

        nodes_with_branch_starting_coords = find_branches_for_nodes(
            network_array_representation=connected_nodes,
            labelled_nodes=labelled_nodes,
            labelled_branches=labelled_branches,
        )

        # If there is only one node, then there is no need to connect the nodes since there is nothing to
        # connect it to. Return the original connected_nodes instead.
        if len(nodes_with_branch_starting_coords) <= 1:
            self.connected_nodes = connected_nodes
            return self.connected_nodes

        assert self.whole_skel_graph is not None, "Whole skeleton graph is not defined."  # for type safety
        shortest_node_dists, shortest_dists_branch_idxs, _shortest_dist_coords = calculate_shortest_branch_distances(
            nodes_with_branch_starting_coords=nodes_with_branch_starting_coords,
            whole_skeleton_graph=self.whole_skel_graph,
        )

        # Matches is an Nx2 numpy array of indexes of the best matching nodes.
        # Eg: np.array([[1, 0], [2, 3]]) means that the best matching nodes are
        # node 1 and node 0, and node 2 and node 3.
        matches: npt.NDArray[np.int32] = self.best_matches(shortest_node_dists, max_weight_matching=False)

        # Connect the nodes by their best matches, using the shortest distances between their branch starts.
        connected_nodes = connect_best_matches(
            network_array_representation=connected_nodes,
            whole_skeleton_graph=self.whole_skel_graph,
            match_indexes=matches,
            shortest_distances_between_nodes=shortest_node_dists,
            shortest_distances_branch_indexes=shortest_dists_branch_idxs,
            emanating_branch_starts_by_node=nodes_with_branch_starting_coords,
            extend_distance=node_extend_dist,
        )

        self.connected_nodes = connected_nodes
        return self.connected_nodes

    @staticmethod
    def find_branch_starts(reduced_node_image: npt.NDArray) -> npt.NDArray:
        """
        Find the coordinates where the branches connect to the node region through binary dilation of the node.

        Parameters
        ----------
        reduced_node_image : npt.NDArray
            A 2D numpy array containing a single node region (=3) and its connected branches (=1).

        Returns
        -------
        npt.NDArray
            Coordinate array of pixels next to crossing points (=3 in input).
        """
        node = np.where(reduced_node_image == 3, 1, 0)
        nodeless = np.where(reduced_node_image == 1, 1, 0)
        thick_node = binary_dilation(node, structure=np.ones((3, 3)))

        return np.argwhere(thick_node * nodeless == 1)

    # pylint: disable=too-many-locals
    def analyse_nodes(self, max_branch_length: float = 20) -> None:
        """
        Obtain the main analyses for the nodes of a single molecule along the 'max_branch_length' (nm) from the node.

        Parameters
        ----------
        max_branch_length : float
            The side length of the box around the node to analyse (in nm).
        """
        # Get coordinates of nodes
        # This is a numpy array of coords, shape Nx2
        assert self.node_centre_mask is not None, "Node centre mask is not defined."
        node_coords: npt.NDArray[np.int32] = np.argwhere(self.node_centre_mask.copy() == 3)

        # Check whether average trace resides inside the grain mask
        # If the skeleton is dilated twice, all of its pixels should still fit inside the grain mask
        dilate = binary_dilation(self.skeleton, iterations=2)
        # This flag determines whether to use average of 3 traces in calculation of FWHM
        average_trace_advised = dilate[self.smoothed_mask == 1].sum() == dilate.sum()
        LOGGER.debug(f"[{self.filename}] : Branch height traces will be averaged: {average_trace_advised}")

        # Iterate over the nodes and analyse the branches
        matched_branches = None
        branch_image = None
        avg_image = np.zeros_like(self.image)
        real_node_count = 0
        for node_no, (node_x, node_y) in enumerate(node_coords):
            unmatched_branches = {}
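            error = False
            # Reset per node so that a failed analysis cannot reuse a stale (or undefined) value
            confidence = None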

            # Get branches relevant to the node
            max_length_px = max_branch_length / self.pixel_to_nm_scaling
            reduced_node_area: npt.NDArray[np.int32] = nodeStats.only_centre_branches(
                self.connected_nodes, np.array([node_x, node_y])
            )
            # Reduced skel graph is a networkx graph of the reduced node area.
            reduced_skel_graph: nx.classes.graph.Graph = nodeStats.skeleton_image_to_graph(reduced_node_area)

            # Binarise the reduced node area
            branch_mask = reduced_node_area.copy()
            branch_mask[branch_mask == 3] = 0
            branch_mask[branch_mask == 2] = 1
            node_coords = np.argwhere(reduced_node_area == 3)

            # Find the starting coordinates of any branches connected to the node
            branch_start_coords = self.find_branch_starts(reduced_node_area)

            # Stop processing if nib (node has 2 or fewer branches)
            if branch_start_coords.shape[0] <= 2:
                LOGGER.debug(
                    f"node {node_no} has only two branches - skipped & nodes removed. {len(node_coords)} "
                    "pixels in nib node."
                )
            else:
                try:
                    real_node_count += 1
                    LOGGER.debug(f"Node: {real_node_count}")

                    # Analyse the node branches
                    (
                        pairs,
                        matched_branches,
                        ordered_branches,
                        masked_image,
                        branch_under_over_order,
                        confidence,
                        singlet_branch_vectors,
                    ) = nodeStats.analyse_node_branches(
                        p_to_nm=self.pixel_to_nm_scaling,
                        reduced_node_area=reduced_node_area,
                        branch_start_coords=branch_start_coords,
                        max_length_px=max_length_px,
                        reduced_skeleton_graph=reduced_skel_graph,
                        image=self.image,
                        average_trace_advised=average_trace_advised,
                        node_coord=(node_x, node_y),
                        pair_odd_branches=self.pair_odd_branches,
                        filename=self.filename,
                        resolution_threshold=np.float64(1000 / 512),
                    )

                    # Add the analysed branches to the labelled image
                    branch_image, avg_image = nodeStats.add_branches_to_labelled_image(
                        branch_under_over_order=branch_under_over_order,
                        matched_branches=matched_branches,
                        masked_image=masked_image,
                        branch_start_coords=branch_start_coords,
                        ordered_branches=ordered_branches,
                        pairs=pairs,
                        average_trace_advised=average_trace_advised,
                        image_shape=(self.image.shape[0], self.image.shape[1]),
                    )

                    # Calculate crossing angles of unpaired branches and add to stats dict
                    nodestats_calc_singlet_angles_result = nodeStats.calc_angles(np.asarray(singlet_branch_vectors))
                    angles_between_singlet_branch_vectors: npt.NDArray[np.float64] = (
                        nodestats_calc_singlet_angles_result[0]
                    )

                    for branch_index, angle in enumerate(angles_between_singlet_branch_vectors):
                        unmatched_branches[branch_index] = {"angles": angle}

                    # Get the vector of each branch based on ordered_coords. Ordered_coords only covers the first
                    # N nm of the branch, so this is only a rough indication of the direction the branch takes.
                    if len(branch_start_coords) % 2 == 0 or self.pair_odd_branches:
                        vectors: list[npt.NDArray[np.float64]] = []
                        for _, values in matched_branches.items():
                            vectors.append(nodeStats.get_vector(values["ordered_coords"], np.array([node_x, node_y])))
                        # Calculate angles between the vectors
                        nodestats_calc_angles_result = nodeStats.calc_angles(np.asarray(vectors))
                        angles_between_vectors_along_branch: npt.NDArray[np.float64] = nodestats_calc_angles_result[0]
                        for branch_index, angle in enumerate(angles_between_vectors_along_branch):
                            matched_branches[branch_index]["angles"] = angle
                    else:
                        self.image_dict["grain"]["grain_skeleton"][node_coords[:, 0], node_coords[:, 1]] = 0

                    # Eg: length 2 array: [array([ nan, 79.00]), array([79.00, 0.0])]
                    # angles_between_vectors_along_branch

                except ResolutionError:
                    LOGGER.debug(f"Node stats skipped as resolution too low: {self.pixel_to_nm_scaling}nm per pixel")
                    error = True

                self.node_dicts[f"node_{real_node_count}"] = {
                    "error": error,
                    "pixel_to_nm_scaling": self.pixel_to_nm_scaling,
                    "branch_stats": matched_branches,
                    "unmatched_branch_stats": unmatched_branches,
                    "node_coords": node_coords,
                    "confidence": confidence,
                }

                assert reduced_node_area is not None, "Reduced node area is not defined."
                assert branch_image is not None, "Branch image is not defined."
                assert avg_image is not None, "Average image is not defined."
                node_images_dict: dict[str, npt.NDArray[np.int32]] = {
                    "node_area_skeleton": reduced_node_area,
                    "node_branch_mask": branch_image,
                    "node_avg_mask": avg_image,
                }
                self.image_dict["nodes"][f"node_{real_node_count}"] = node_images_dict

            self.all_connected_nodes[self.connected_nodes != 0] = self.connected_nodes[self.connected_nodes != 0]

    # pylint: disable=too-many-arguments
    @staticmethod
    def add_branches_to_labelled_image(
        branch_under_over_order: npt.NDArray[np.int32],
        matched_branches: dict[int, MatchedBranch],
        masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]],
        branch_start_coords: npt.NDArray[np.int32],
        ordered_branches: list[npt.NDArray[np.int32]],
        pairs: npt.NDArray[np.int32],
        average_trace_advised: bool,
        image_shape: tuple[int, int],
    ) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]:
        """
        Add branches to a labelled image.

        Parameters
        ----------
        branch_under_over_order : npt.NDArray[np.int32]
            The order of the branches.
        matched_branches : dict[int, MatchedBranch]
            Dictionary where the key is the index of the pair and the value is a dictionary containing the following
            keys:
            - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
            - "heights" : npt.NDArray[np.number]. Heights of the branches.
            - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch.
            - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
        masked_image : dict[int, dict[str, npt.NDArray[np.bool_]]]
            Dictionary where the key is the index of the pair and the value is a dictionary containing the following
            keys:
            - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
        branch_start_coords : npt.NDArray[np.int32]
            An Nx2 numpy array of the coordinates of the branches connected to the node.
        ordered_branches : list[npt.NDArray[np.int32]]
            List of numpy arrays of ordered branch coordinates.
        pairs : npt.NDArray[np.int32]
            Nx2 numpy array of pairs of branches that are matched through a node.
        average_trace_advised : bool
            Flag to determine whether to use the average trace.
        image_shape : tuple[int, int]
            The shape of the image, to create a mask from.

        Returns
        -------
        tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]
            The branch image and the average image.
        """
        branch_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32)
        avg_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32)

        for i, branch_index in enumerate(branch_under_over_order):
            branch_coords = matched_branches[branch_index]["ordered_coords"]

            # Add the matched branch to the image, starting at index 1
            branch_image[branch_coords[:, 0], branch_coords[:, 1]] = i + 1
            if average_trace_advised:
                # Label the averaged-trace mask of this branch with the same index as the branch itself.
                avg_image[masked_image[branch_index]["avg_mask"] != 0] = i + 1

        # Determine branches that were not able to be paired
        unpaired_branches = np.delete(np.arange(0, branch_start_coords.shape[0]), pairs.flatten())
        LOGGER.debug(f"Unpaired branches: {unpaired_branches}")
        # Unpaired branches are labelled from I + 1, where I is the highest label used by the paired branches.
        branch_label = branch_image.max()
        # Add the unpaired branches back to the branch image
        for i in unpaired_branches:
            branch_label += 1
            branch_image[ordered_branches[i][:, 0], ordered_branches[i][:, 1]] = branch_label

        return branch_image, avg_image

    @staticmethod
    def analyse_node_branches(
        p_to_nm: np.float64,
        reduced_node_area: npt.NDArray[np.int32],
        branch_start_coords: npt.NDArray[np.int32],
        max_length_px: np.float64,
        reduced_skeleton_graph: nx.classes.graph.Graph,
        image: npt.NDArray[np.number],
        average_trace_advised: bool,
        node_coord: tuple[np.int32, np.int32],
        pair_odd_branches: bool,
        filename: str,
        resolution_threshold: np.float64,
    ) -> tuple[
        npt.NDArray[np.int32],
        dict[int, MatchedBranch],
        list[npt.NDArray[np.int32]],
        dict[int, dict[str, npt.NDArray[np.bool_]]],
        npt.NDArray[np.int32],
        np.float64 | None,
        list[npt.NDArray[np.float64]],
    ]:
        """
        Analyse the branches of a single node.

        Parameters
        ----------
        p_to_nm : np.float64
            The pixel to nm scaling factor.
        reduced_node_area : npt.NDArray[np.int32]
            An NxM numpy array of the node in question and the branches connected to it.
            Node is marked by 3, and branches by 1.
        branch_start_coords : npt.NDArray[np.int32]
            An Nx2 numpy array of the coordinates of the branches connected to the node.
        max_length_px : np.float64
            The maximum length in pixels to traverse along while ordering.
        reduced_skeleton_graph : nx.classes.graph.Graph
            The graph representation of the reduced node area.
        image : npt.NDArray[np.number]
            The full image of the grain.
        average_trace_advised : bool
            Flag to determine whether to use the average trace.
        node_coord : tuple[np.int32, np.int32]
            The node coordinates.
        pair_odd_branches : bool
            Whether to try and pair odd-branched nodes.
        filename : str
            The filename of the image.
        resolution_threshold : np.float64
            The pixel to nm scaling above which (i.e. at coarser resolution) a debug message is logged that the
            node is difficult to analyse.

        Returns
        -------
        pairs: npt.NDArray[np.int32]
            Nx2 numpy array of pairs of branches that are matched through a node.
        matched_branches: dict[int, MatchedBranch]
            Dictionary where the key is the index of the pair and the value is a dictionary containing the following
            keys:
            - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
            - "heights" : npt.NDArray[np.number]. Heights of the branches.
            - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch.
            - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
            - "angles" : np.float64. The angle of the branch, added in later steps.
        ordered_branches: list[npt.NDArray[np.int32]]
            List of numpy arrays of ordered branch coordinates.
        masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
            Dictionary where the key is the index of the pair and the value is a dictionary containing the following
            keys:
            - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
        branch_under_over_order: npt.NDArray[np.int32]
            The order of the branches based on the FWHM.
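        confidence: np.float64 | None
            The confidence of the crossing, or None if fewer than two matched branches are available.
        singlet_branch_vectors: list[npt.NDArray[np.float64]]
            The normalised starting vector of each branch emanating from the node.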
        """
        if not p_to_nm <= resolution_threshold:
            LOGGER.debug(
                f"Resolution {p_to_nm} nm/px is coarser than the suggested {resolution_threshold} nm/px, "
                "node difficult to analyse."
            )

        # Pixel-wise order the branches coming from the node and calculate the starting vector for each branch
        ordered_branches, singlet_branch_vectors = nodeStats.get_ordered_branches_and_vectors(
            reduced_node_area, branch_start_coords, max_length_px
        )

        # Pair the singlet branch vectors based on their suitability using vector orientation.
        if len(branch_start_coords) % 2 == 0 or pair_odd_branches:
            pairs = nodeStats.pair_vectors(np.asarray(singlet_branch_vectors))
        else:
            pairs = np.array([], dtype=np.int32)

        # Match the branches up
        matched_branches, masked_image = nodeStats.join_matching_branches_through_node(
            pairs,
            ordered_branches,
            reduced_skeleton_graph,
            image,
            average_trace_advised,
            node_coord,
            filename,
        )

        # Redo the FWHMs using a shared half-maximum (the largest of the per-branch values) so that branches are
        # compared at the same height, for more accurate determination of under/overs.
        hms = []
        for _, values in matched_branches.items():
            hms.append(values["fwhm"]["half_maxs"][2])
        for _, values in matched_branches.items():
            values["fwhm"] = nodeStats.calculate_fwhm(values["heights"], values["distances"], hm=max(hms))

        # Get the confidence of the crossing
        crossing_fwhms = []
        for _, values in matched_branches.items():
            crossing_fwhms.append(values["fwhm"]["fwhm"])
        if len(crossing_fwhms) <= 1:
            confidence = None
        else:
            crossing_fwhm_combinations = list(combinations(crossing_fwhms, 2))
            confidence = np.float64(nodeStats.cross_confidence(crossing_fwhm_combinations))

        # Order the branch indexes based on the FWHM of the branches.
        branch_under_over_order = np.array(list(matched_branches.keys()))[np.argsort(np.array(crossing_fwhms))]

        return (
            pairs,
            matched_branches,
            ordered_branches,
            masked_image,
            branch_under_over_order,
            confidence,
            singlet_branch_vectors,
        )

    @staticmethod
    def join_matching_branches_through_node(
        pairs: npt.NDArray[np.int32],
        ordered_branches: list[npt.NDArray[np.int32]],
        reduced_skeleton_graph: nx.classes.graph.Graph,
        image: npt.NDArray[np.number],
        average_trace_advised: bool,
        node_coords: tuple[np.int32, np.int32],
        filename: str,
    ) -> tuple[dict[int, MatchedBranch], dict[int, dict[str, npt.NDArray[np.bool_]]]]:
        """
        Join branches that are matched through a node.

        Parameters
        ----------
        pairs : npt.NDArray[np.int32]
            Nx2 numpy array of pairs of branches that are matched through a node.
        ordered_branches : list[npt.NDArray[np.int32]]
            List of numpy arrays of ordered branch coordinates.
        reduced_skeleton_graph : nx.classes.graph.Graph
            Graph representation of the skeleton.
        image : npt.NDArray[np.number]
            The full image of the grain.
        average_trace_advised : bool
            Flag to determine whether to use the average trace.
        node_coords : tuple[np.int32, np.int32]
            The node coordinates.
        filename : str
            The filename of the image.

        Returns
        -------
        matched_branches: dict[int, MatchedBranch]
            Dictionary where the key is the index of the pair and the value is a dictionary containing the following
            keys:
            - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
            - "heights" : npt.NDArray[np.number]. Heights of the branches.
            - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch.
            - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
        masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
            Dictionary where the key is the index of the pair and the value is a dictionary containing the following
            keys:
            - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
        """
        matched_branches: dict[int, MatchedBranch] = {}
        # Masked image is a dictionary of pairs of branches
        masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] = {}
        for i, (branch_1, branch_2) in enumerate(pairs):
            matched_branches[i] = MatchedBranch(
                ordered_coords=np.array([], dtype=np.int32),
                heights=np.array([], dtype=np.float64),
                distances=np.array([], dtype=np.float64),
                fwhm={},
                angles=None,
            )
            masked_image[i] = {}
            # find close ends by rearranging branch coords
            branch_1_coords, branch_2_coords = nodeStats.order_branches(
                ordered_branches[branch_1], ordered_branches[branch_2]
            )
            # Get graphical shortest path between branch ends on the skeleton
            crossing = nx.shortest_path(
                reduced_skeleton_graph,
                source=tuple(branch_1_coords[-1]),
                target=tuple(branch_2_coords[0]),
                weight="weight",
            )
            crossing = np.asarray(crossing[1:-1])  # remove start and end points & turn into array
            # Branch coords and crossing
            if crossing.shape == (0,):
                branch_coords = np.vstack([branch_1_coords, branch_2_coords])
            else:
                branch_coords = np.vstack([branch_1_coords, crossing, branch_2_coords])
            # make images of single branch joined and multiple branches joined
            single_branch_img: npt.NDArray[np.bool_] = np.zeros_like(image).astype(bool)
            single_branch_img[branch_coords[:, 0], branch_coords[:, 1]] = True
            single_branch_coords = order_branch(single_branch_img.astype(bool), [0, 0])
            # calc image-wide coords
            matched_branches[i]["ordered_coords"] = single_branch_coords
            # get heights and trace distance of branch
            try:
                assert average_trace_advised
                distances, heights, mask, _ = nodeStats.average_height_trace(
                    image, single_branch_img, single_branch_coords, [node_coords[0], node_coords[1]]
                )
                masked_image[i]["avg_mask"] = mask
            except (
                AssertionError,
                IndexError,
            ) as e:  # Assertion - avg trace not advised, Index - wiggly branches
                LOGGER.debug(f"[{filename}] : avg trace failed with {e}, single trace only.")
                average_trace_advised = False
                distances = nodeStats.coord_dist_rad(single_branch_coords, np.array([node_coords[0], node_coords[1]]))
                # distances = self.coord_dist(single_branch_coords)
                zero_dist = distances[
                    np.argmin(
                        np.sqrt(
                            (single_branch_coords[:, 0] - node_coords[0]) ** 2
                            + (single_branch_coords[:, 1] - node_coords[1]) ** 2
                        )
                    )
                ]
                heights = image[single_branch_coords[:, 0], single_branch_coords[:, 1]]  # self.hess
                distances = distances - zero_dist
                distances, heights = nodeStats.average_uniques(
                    distances, heights
                )  # needs to be paired with coord_dist_rad
            matched_branches[i]["heights"] = heights
            matched_branches[i]["distances"] = distances
            # identify over/under
            matched_branches[i]["fwhm"] = nodeStats.calculate_fwhm(heights, distances)

        return matched_branches, masked_image

    @staticmethod
    def get_ordered_branches_and_vectors(
        reduced_node_area: npt.NDArray[np.int32],
        branch_start_coords: npt.NDArray[np.int32],
        max_length_px: np.float64,
    ) -> tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]:
        """
        Get ordered branches and vectors for a node.

        Branches are ordered so they are no longer just a disordered set of coordinates, and vectors are calculated to
        represent the general direction of the branch; this allows for alignment matching later on.

        Parameters
        ----------
        reduced_node_area : npt.NDArray[np.int32]
            An NxM numpy array of the node in question and the branches connected to it.
            Node is marked by 3, and branches by 1.
        branch_start_coords : npt.NDArray[np.int32]
            A Px2 numpy array of coordinates representing the start of branches where P is the number of branches.
        max_length_px : np.float64
            The maximum length in pixels to traverse along while ordering.

        Returns
        -------
        tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]
            A tuple containing a list of ordered branches and a list of vectors.
        """
        ordered_branches = []
        vectors = []
        nodeless = np.where(reduced_node_area == 1, 1, 0)
        for branch_start_coord in branch_start_coords:
            # Order the branch coordinates so they're no longer just a disordered set of coordinates
            ordered_branch = order_branch_from_start(nodeless.copy(), branch_start_coord, max_length=max_length_px)
            ordered_branches.append(ordered_branch)

            # Calculate vector to represent the general direction tendency of the branch (for alignment matching)
            vector = nodeStats.get_vector(ordered_branch, branch_start_coord)
            vectors.append(vector)

        return ordered_branches, vectors

    @staticmethod
    def cross_confidence(pair_combinations: list) -> float:
        """
        Obtain the average confidence of the combinations using a reciprocal function.

        Parameters
        ----------
        pair_combinations : list
            List of length 2 combinations of FWHM values.

        Returns
        -------
        float
            The average crossing confidence.
        """
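        # Illustrative example (hand-worked from recip() below, values are made up): for FWHMs 2, 4 and 8
        # the pairwise combinations are (2, 4), (2, 8) and (4, 8), which score 1 - 2/4 = 0.5,
        # 1 - 2/8 = 0.75 and 1 - 4/8 = 0.5, so the average confidence is 1.75 / 3, approximately 0.583.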
        c = 0
        for pair in pair_combinations:
            c += nodeStats.recip(pair)
        return c / len(pair_combinations)

    @staticmethod
    def recip(vals: list) -> float:
        """
        Compute 1 - (min / max) of the two values provided.

        Parameters
        ----------
        vals : list
            List of 2 values.

        Returns
        -------
        float
            Result of applying the 1-(min / max) function to the two values.
        """
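        # Illustrative values (hand-worked, not taken from the TopoStats tests):
        #   vals = (2.0, 4.0)  ->  1 - 2.0 / 4.0 = 0.5
        #   vals = (3.0, 3.0)  ->  1 - 3.0 / 3.0 = 0.0  (identical FWHMs give no confidence)
        #   vals = (0.0, 4.0)  ->  0                    (a zero FWHM is treated as a failed measurement)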
        try:
            if min(vals) == 0:  # means fwhm variation hasn't worked
                return 0
            return 1 - min(vals) / max(vals)
        except ZeroDivisionError:
            return 0

    @staticmethod
    def get_vector(coords: npt.NDArray, origin: npt.NDArray) -> npt.NDArray:
        """
        Calculate the normalised vector of the coordinate means in a branch.

        Parameters
        ----------
        coords : npt.NDArray
            Nx2 array of x, y coordinates.
        origin : npt.NDArray
            Array of length 2 holding an x, y coordinate.

        Returns
        -------
        npt.NDArray
            Normalised vector from origin to the mean coordinate.
        """
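        # Illustrative example (hand-worked, values are made up): coords = [[1, 0], [3, 0]] with
        # origin = [0, 0] has mean coordinate [2, 0], so the vector [2, 0] is normalised to [1, 0].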
        vector = coords.mean(axis=0) - origin
        norm = np.sqrt(vector @ vector)
        return vector if norm == 0 else vector / norm  # normalise vector so length=1

    @staticmethod
    def calc_angles(vectors: npt.NDArray) -> npt.NDArray[np.float64]:
        """
        Calculate the angles between vectors in an array.

        Uses the formula:

        .. code-block:: RST

            cos(theta) = a•b / (|a||b|)

        Parameters
        ----------
        vectors : npt.NDArray
            Nx2 array of vectors, one vector per row.

        Returns
        -------
        npt.NDArray
            A square array of the angles, in degrees, between each pair of vectors.
        """
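        # Illustrative example (hand-worked, values are made up): for the unit vectors
        #   [[1, 0], [0, 1], [-1, 0]]
        # the pairwise dot products are [[1, 0, -1], [0, 1, 0], [-1, 0, 1]], giving angles (degrees) of
        #   [[  0,  90, 180],
        #    [ 90,   0,  90],
        #    [180,  90,   0]]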
        dot = vectors @ vectors.T
        norm = np.diag(dot) ** 0.5
        cos_angles = dot / (norm.reshape(-1, 1) @ norm.reshape(1, -1))
        np.fill_diagonal(cos_angles, 1)  # ensures vector_x • vector_x angles are 0
        return abs(np.arccos(cos_angles) / np.pi * 180)  # angles in degrees

    @staticmethod
    def pair_vectors(vectors: npt.NDArray) -> npt.NDArray[np.int32]:
        """
        Take a list of vectors and pair them based on the angle between them.

        Parameters
        ----------
        vectors : npt.NDArray
            Nx2 array of vectors to be paired, one vector per row.

        Returns
        -------
        npt.NDArray
            An array of the matching pair indices.
        """
        # calculate the angles (in degrees) between the vectors
        angles = nodeStats.calc_angles(vectors)
        # match angles
        return nodeStats.best_matches(angles)

    @staticmethod
    def best_matches(arr: npt.NDArray, max_weight_matching: bool = True) -> npt.NDArray:
        """
        Turn a matrix into a graph and calculate the best matching index pairs.

        Parameters
        ----------
        arr : npt.NDArray
            Symmetric MxM array where the value of index i, j represents a weight between i and j.
        max_weight_matching : bool
            Whether to obtain best matching pairs via maximum weight, or minimum weight matching.

        Returns
        -------
        npt.NDArray
            Array of pairs of indexes.
        """
        if max_weight_matching:
            G = nodeStats.create_weighted_graph(arr)
            matching = np.array(list(nx.max_weight_matching(G, maxcardinality=True)))
        else:
            np.fill_diagonal(arr, arr.max() + 1)
            G = nodeStats.create_weighted_graph(arr)
            matching = np.array(list(nx.min_weight_matching(G)))
        return matching

    @staticmethod
    def create_weighted_graph(matrix: npt.NDArray) -> nx.Graph:
        """
        Create an undirected weighted graph connecting i <-> j from a square matrix of weights matrix[i, j].

        Parameters
        ----------
        matrix : npt.NDArray
            Square array of weights between rows and columns.

        Returns
        -------
        nx.Graph
            Undirected graph whose edge i <-> j carries the weight matrix[i, j] for i < j.
        """
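        # Illustrative example (hand-worked, weights are made up): for the symmetric matrix
        #   [[0, 5, 7],
        #    [5, 0, 9],
        #    [7, 9, 0]]
        # only the upper triangle is read, so the graph has edges 0-1 (weight 5), 0-2 (weight 7) and
        # 1-2 (weight 9); the diagonal is ignored.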
        n = len(matrix)
        G = nx.Graph()
        for i in range(n):
            for j in range(i + 1, n):
                G.add_edge(i, j, weight=matrix[i, j])
        return G

    @staticmethod
    def pair_angles(angles: npt.NDArray) -> npt.NDArray:
        """
        Pair the indexes with the largest angle between them (closest to 180 degrees), removing each pair in turn.

        Parameters
        ----------
        angles : npt.NDArray
             Square array (i,j) of angles between i and j.

        Returns
        -------
        npt.NDArray
            An array of paired indexes, one pair per row.
        """
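        # Illustrative example (hand-worked, values are made up): for the angle matrix
        #   [[  0,  90, 180],
        #    [ 90,   0,  90],
        #    [180,  90,   0]]
        # one pair is selected (3 // 2 = 1 iteration); the first maximum is at (0, 2), so the result
        # is [[0, 2]] and index 1 is left unpaired.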
        angles_cp = angles.copy()
        pairs = []
        for _ in range(int(angles.shape[0] / 2)):
            pair = np.unravel_index(np.argmax(angles_cp), angles.shape)
            pairs.append(pair)  # add to list
            angles_cp[[pair]] = 0  # set rows 0 to avoid picking again
            angles_cp[:, [pair]] = 0  # set cols 0 to avoid picking again

        return np.asarray(pairs)

    @staticmethod
    def gaussian(x: npt.NDArray, h: float, mean: float, sigma: float):
        """
        Apply the gaussian function.

        Parameters
        ----------
        x : npt.NDArray
            X values to be passed into the gaussian.
        h : float
            The peak height of the gaussian.
        mean : float
            The mean (centre) of the gaussian.
        sigma : float
            The standard deviation of the gaussian.

        Returns
        -------
        npt.NDArray
            The y-values of the gaussian performed on the x values.
        """
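        # Illustrative check (hand-worked): at x = mean the exponent is 0, so the result is h; at
        # x = mean +/- sigma the exponent is -0.5, so the result is h * exp(-0.5), approximately 0.607 * h.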
        return h * np.exp(-((x - mean) ** 2) / (2 * sigma**2))

    @staticmethod
    def interpolate_between_yvalue(x: npt.NDArray, y: npt.NDArray, yvalue: float) -> float:
        """
        Calculate the x value between the two points either side of yvalue in y.

        Parameters
        ----------
        x : npt.NDArray
            Array of x values, the same length as y.
        y : npt.NDArray
            Array of y values, the same length as x.
        yvalue : float
            A value within the bounds of the y array.

        Returns
        -------
        float
            The linearly interpolated x value between the arrays.
        """
        for i in range(len(y) - 1):
            if y[i] <= yvalue <= y[i + 1] or y[i + 1] <= yvalue <= y[i]:  # if points cross through the hm value
                return nodeStats.lin_interp([x[i], y[i]], [x[i + 1], y[i + 1]], yvalue=yvalue)
        return 0

    @staticmethod
    def calculate_fwhm(
        heights: npt.NDArray, distances: npt.NDArray, hm: float | None = None
    ) -> dict[str, np.float64 | list[np.float64 | float | None]]:
        """
        Calculate the FWHM value.

        First identifies the half maximum (HM), then finds the closest values in the distances array and uses
        linear interpolation to calculate the FWHM.

        Parameters
        ----------
        heights : npt.NDArray
            Array of heights.
        distances : npt.NDArray
            Array of distances.
        hm : float | None, optional
            The halfmax value to match (if wanting the same HM between curves), by default None.

        Returns
        -------
        dict[str, np.float64 | list[np.float64 | float | None]]
            Dictionary with keys "fwhm" (the FWHM value), "half_maxs" ([distance at hm for 1st half of trace,
            distance at hm for 2nd half of trace, HM value]) and "peaks" ([index of the highest point, distance at
            highest point, height at highest point]).
        """
        centre_fraction = int(len(heights) * 0.2)  # in case zone approaches another node, look around centre for max
        if centre_fraction == 0:
            high_idx = np.argmax(heights)
        else:
            high_idx = np.argmax(heights[centre_fraction:-centre_fraction]) + centre_fraction
        # get array halves to find first points that cross hm
        arr1 = heights[:high_idx][::-1]
        dist1 = distances[:high_idx][::-1]
        arr2 = heights[high_idx:]
        dist2 = distances[high_idx:]
        if hm is None:
            # Get half max
            hm = (heights.max() - heights.min()) / 2 + heights.min()
            # half max value -> try to make it the same as other crossing branch?
            # raise hm to the lowest point of the peak if the trace doesn't reach hm on one side
            if np.min(arr1) > hm:
                arr1_local_min = argrelextrema(arr1, np.less)[-1]  # closest to end
                try:
                    hm = arr1[arr1_local_min][0]
                except IndexError:  # index error when no local minima
                    hm = np.min(arr1)
            elif np.min(arr2) > hm:
                arr2_local_min = argrelextrema(arr2, np.less)[0]  # closest to start
                try:
                    hm = arr2[arr2_local_min][0]
                except IndexError:  # index error when no local minima
                    hm = np.min(arr2)
        arr1_hm = nodeStats.interpolate_between_yvalue(x=dist1, y=arr1, yvalue=hm)
        arr2_hm = nodeStats.interpolate_between_yvalue(x=dist2, y=arr2, yvalue=hm)
        fwhm = np.float64(abs(arr2_hm - arr1_hm))
        return {
            "fwhm": fwhm,
            "half_maxs": [arr1_hm, arr2_hm, hm],
            "peaks": [high_idx, distances[high_idx], heights[high_idx]],
        }

    @staticmethod
    def lin_interp(point_1: list, point_2: list, xvalue: float | None = None, yvalue: float | None = None) -> float:
        """
        Linearly interpolate between 2 points by finding the line equation and substituting.

        Parameters
        ----------
        point_1 : list
            List of an x and y coordinate.
        point_2 : list
            List of an x and y coordinate.
        xvalue : Union[float, None], optional
            Value at which to interpolate to get a y coordinate, by default None.
        yvalue : Union[float, None], optional
            Value at which to interpolate to get an x coordinate, by default None.

        Returns
        -------
        float
            Value of x or y linear interpolation.
        """
        m = (point_1[1] - point_2[1]) / (point_1[0] - point_2[0])
        c = point_1[1] - (m * point_1[0])
        if xvalue is not None:
            return m * xvalue + c  # interp_y
        if yvalue is not None:
            return (yvalue - c) / m  # interp_x
        raise ValueError

    @staticmethod
    def order_branches(branch1: npt.NDArray, branch2: npt.NDArray) -> tuple:
        """
        Order the two ordered arrays based on the closest endpoint coordinates.

        Parameters
        ----------
        branch1 : npt.NDArray
            An Nx2 array describing coordinates.
        branch2 : npt.NDArray
            An Nx2 array describing coordinates.

        Returns
        -------
        tuple
            A tuple with each coordinate array ordered to follow on from one another.
        """
        endpoints1 = np.asarray([branch1[0], branch1[-1]])
        endpoints2 = np.asarray([branch2[0], branch2[-1]])
        sum1 = abs(endpoints1 - endpoints2).sum(axis=1)
        sum2 = abs(endpoints1[::-1] - endpoints2).sum(axis=1)
        if sum1.min() < sum2.min():
            if np.argmin(sum1) == 0:
                return branch1[::-1], branch2
            return branch1, branch2[::-1]
        if np.argmin(sum2) == 0:
            return branch1, branch2
        return branch1[::-1], branch2[::-1]

    @staticmethod
    def binary_line(start: npt.NDArray, end: npt.NDArray) -> npt.NDArray:
        """
        Create a binary path following the straight line between 2 points.

        Parameters
        ----------
        start : npt.NDArray
            A coordinate.
        end : npt.NDArray
            Another coordinate.

        Returns
        -------
        npt.NDArray
            An Nx2 coordinate array that the line passes through.
        """
        arr = []
        m_swap = False
        x_swap = False
        slope = (end - start)[1] / (end - start)[0]

        if abs(slope) > 1:  # swap x and y if slope will cause skips
            start, end = start[::-1], end[::-1]
            slope = 1 / slope
            m_swap = True

        if start[0] > end[0]:  # swap x coords if coords wrong way around
            start, end = end, start
            x_swap = True

        # code assumes slope < 1 hence swap
        x_start, y_start = start
        x_end, _ = end
        for x in range(x_start, x_end + 1):
            y_true = slope * (x - x_start) + y_start
            y_pixel = np.round(y_true)
            arr.append([x, y_pixel])

        if m_swap:  # if swapped due to slope, return
            arr = np.asarray(arr)[:, [1, 0]].reshape(-1, 2).astype(int)
            if x_swap:
                return arr[::-1]
            return arr
        arr = np.asarray(arr).reshape(-1, 2).astype(int)
        if x_swap:
            return arr[::-1]
        return arr

    @staticmethod
    def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray:
        """
        Calculate the distance from the centre coordinate to a point along the ordered coordinates.

        This differs from traversal along the coordinates taken. This also averages any common distance
        values and makes those in the trace before the node index negative.

        Parameters
        ----------
        coords : npt.NDArray
            Nx2 array of branch coordinates.
        centre : npt.NDArray
            A 1x2 array of the centre coordinates to identify a 0 point for the node.
        pixel_to_nm_scaling : float, optional
            The pixel to nanometer scaling factor to provide real units, by default 1.

        Returns
        -------
        npt.NDArray
            An Nx1 array of the distance from the node centre.
        """
        diff_coords = coords - centre
        if np.all(coords == centre, axis=1).sum() == 0:  # if centre not in coords, reassign centre
            diff_dists = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2)
            centre = coords[np.argmin(diff_dists)]
        cross_idx = np.argwhere(np.all(coords == centre, axis=1))
        rad_dist = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2)
        rad_dist[0 : cross_idx[0][0]] *= -1
        return rad_dist * pixel_to_nm_scaling

    @staticmethod
    def above_below_value_idx(array: npt.NDArray, value: float) -> list:
        """
        Identify indices of the array neighbouring the specified value.

        Parameters
        ----------
        array : npt.NDArray
            Array of values.
        value : float
            Value to identify indices between.

        Returns
        -------
        list
            List of the lower index and higher index around the value, or None if no such pair is found.

        Raises
        ------
        IndexError
            Raised and caught internally when the value is in the array; None is returned in that case.
        """
        idx1 = abs(array - value).argmin()
        try:
            if array[idx1] < value < array[idx1 + 1]:
                idx2 = idx1 + 1
            elif array[idx1 - 1] < value < array[idx1]:
                idx2 = idx1 - 1
            else:
                raise IndexError  # this will be if the number is the same
            indices = [idx1, idx2]
            indices.sort()
            return indices
        except IndexError:
            return None

    @staticmethod
    def average_height_trace(  # noqa: C901
        img: npt.NDArray, branch_mask: npt.NDArray, branch_coords: npt.NDArray, centre=(0, 0)
    ) -> tuple:
        """
        Average two side-by-side ordered skeleton distance and height traces.

        Dilate the original branch to create two additional side-by-side branches
        in order to get a more accurate average of the height traces. This function produces
        the common distances between these 3 branches, and their averaged heights.

        Parameters
        ----------
        img : npt.NDArray
            An array of numbers pertaining to an image.
        branch_mask : npt.NDArray
            A binary array of the branch, must share the same dimensions as the image.
        branch_coords : npt.NDArray
            Ordered coordinates of the branch mask.
        centre : tuple, optional
            The coordinates to centre the branch around, by default (0, 0).

        Returns
        -------
        tuple
            A tuple of the common distances from the crossing, the averaged heights of the linetrace,
            the binary mask of the traces, and a list of the individual heights and distances.
        """
        # get heights and dists of the original (middle) branch
        branch_dist = nodeStats.coord_dist_rad(branch_coords, centre)
        # branch_dist = self.coord_dist(branch_coords)
        branch_heights = img[branch_coords[:, 0], branch_coords[:, 1]]
        branch_dist, branch_heights = nodeStats.average_uniques(
            branch_dist, branch_heights
        )  # needs to be paired with coord_dist_rad
        dist_zero_point = branch_dist[
            np.argmin(np.sqrt((branch_coords[:, 0] - centre[0]) ** 2 + (branch_coords[:, 1] - centre[1]) ** 2))
        ]
        branch_dist_norm = branch_dist - dist_zero_point  # - 0  # branch_dist[branch_heights.argmax()]

        # want to get a 3 pixel line trace, one on each side of orig
        dilate = binary_dilation(branch_mask, iterations=1)
        dilate = nodeStats.fill_holes(dilate)
        dilate_minus = np.where(dilate != branch_mask, 1, 0)
        dilate2 = binary_dilation(dilate, iterations=1)
        dilate2[(dilate == 1) | (branch_mask == 1)] = 0
        labels = label(dilate2)
        # Cleanup stages - re-entering, early terminating, closer traces
        #   if parallel trace out and back in zone, can get > 2 labels
        labels = nodeStats._remove_re_entering_branches(labels, remaining_branches=2)
        #   if parallel trace doesn't exit window, can get 1 label
        #       occurs when skeleton has poor connections (extra branches which cut corners)
        if labels.max() == 1:
            conv = convolve_skeleton(branch_mask)
            endpoints = np.argwhere(conv == 2)
            for endpoint in endpoints:  # may be >1 endpoint
                para_trace_coords = np.argwhere(labels == 1)
                abs_diff = np.absolute(para_trace_coords - endpoint).sum(axis=1)
                min_idxs = np.where(abs_diff == abs_diff.min())
                trace_coords_remove = para_trace_coords[min_idxs]
                labels[trace_coords_remove[:, 0], trace_coords_remove[:, 1]] = 0
            labels = label(labels)
        #   reduce binary dilation distance
        parallel = np.zeros_like(branch_mask).astype(np.int32)
        for i in range(1, labels.max() + 1):
            single = labels.copy()
            single[single != i] = 0
            single[single == i] = 1
            sing_dil = binary_dilation(single)
            parallel[(sing_dil == dilate_minus) & (sing_dil == 1)] = i
        labels = parallel.copy()

        binary = labels.copy()
        binary[binary != 0] = 1
        binary += branch_mask

        # get and order coords, then get heights and distances relative to node centre / highest point
        heights = []
        distances = []
        for i in np.unique(labels)[1:]:
            trace_img = np.where(labels == i, 1, 0)
            trace_img = getSkeleton(img, trace_img, method="zhang").get_skeleton()
            trace = order_branch(trace_img, branch_coords[0])
            height_trace = img[trace[:, 0], trace[:, 1]]
            dist = nodeStats.coord_dist_rad(trace, centre)  # self.coord_dist(trace)
            dist, height_trace = nodeStats.average_uniques(dist, height_trace)  # needs to be paired with coord_dist_rad
            heights.append(height_trace)
            distances.append(
                dist - dist_zero_point  # - 0
            )  # branch_dist[branch_heights.argmax()]) #dist[central_heights.argmax()])
        # Make like coord system using original branch
        avg1 = []
        avg2 = []
        for mid_dist in branch_dist_norm:
            for i, (distance, height) in enumerate(zip(distances, heights)):
                # check if distance already in traces array
                if (mid_dist == distance).any():
                    idx = np.where(mid_dist == distance)
                    if i == 0:
                        avg1.append([mid_dist, height[idx][0]])
                    else:
                        avg2.append([mid_dist, height[idx][0]])
                # if not, linearly interpolate the mid-branch value
                else:
                    # get index after and before the mid branches' x coord
                    xidxs = nodeStats.above_below_value_idx(distance, mid_dist)
                    if xidxs is None:
                        pass  # if indexes outside of range, pass
                    else:
                        point1 = [distance[xidxs[0]], height[xidxs[0]]]
                        point2 = [distance[xidxs[1]], height[xidxs[1]]]
                        y = nodeStats.lin_interp(point1, point2, xvalue=mid_dist)
                        if i == 0:
                            avg1.append([mid_dist, y])
                        else:
                            avg2.append([mid_dist, y])
        avg1 = np.asarray(avg1)
        avg2 = np.asarray(avg2)
        # ensure arrays are same length to average
        temp_x = branch_dist_norm[np.isin(branch_dist_norm, avg1[:, 0])]
        common_dists = avg2[:, 0][np.isin(avg2[:, 0], temp_x)]

        common_avg_branch_heights = branch_heights[np.isin(branch_dist_norm, common_dists)]
        common_avg1_heights = avg1[:, 1][np.isin(avg1[:, 0], common_dists)]
        common_avg2_heights = avg2[:, 1][np.isin(avg2[:, 0], common_dists)]

        average_heights = (common_avg_branch_heights + common_avg1_heights + common_avg2_heights) / 3
        return (
            common_dists,
            average_heights,
            binary,
            [[heights[0], branch_heights, heights[1]], [distances[0], branch_dist_norm, distances[1]]],
        )

    @staticmethod
    def fill_holes(mask: npt.NDArray) -> npt.NDArray:
        """
        Fill all holes within a binary mask.

        Parameters
        ----------
        mask : npt.NDArray
            Binary array of object.

        Returns
        -------
        npt.NDArray
            Binary array of object with any interior holes filled in.
        """
        inv_mask = np.where(mask != 0, 0, 1)
        lbl_inv = label(inv_mask, connectivity=1)
        idxs, counts = np.unique(lbl_inv, return_counts=True)
        max_idx = idxs[np.argmax(counts)]
        return np.where(lbl_inv != max_idx, 1, 0)

    @staticmethod
    def _remove_re_entering_branches(mask: npt.NDArray, remaining_branches: int = 1) -> npt.NDArray:
        """
        Remove the smallest branches which exit and re-enter the viewing area.

        Continues until only <remaining_branches> remain.

        Parameters
        ----------
        mask : npt.NDArray
            Skeletonised binary mask of an object.
        remaining_branches : int, optional
            Number of objects (branches) to keep, by default 1.

        Returns
        -------
        npt.NDArray
            Mask with only <remaining_branches> skeletonised branches remaining.
        """
        rtn_image = mask.copy()
        binary_image = mask.copy()
        binary_image[binary_image != 0] = 1
        labels = label(binary_image)

        if labels.max() > remaining_branches:
            lens = [labels[labels == i].size for i in range(1, labels.max() + 1)]
            while len(lens) > remaining_branches:
                smallest_idx = min(enumerate(lens), key=lambda x: x[1])[0]
                rtn_image[labels == smallest_idx + 1] = 0
                lens.remove(min(lens))

        return rtn_image

    @staticmethod
    def only_centre_branches(node_image: npt.NDArray, node_coordinate: npt.NDArray) -> npt.NDArray[np.int32]:
        """
        Remove all branches not connected to the current node.

        Parameters
        ----------
        node_image : npt.NDArray
            An image of the skeletonised area surrounding the node where
            the background = 0, skeleton = 1, termini = 2, nodes = 3.
        node_coordinate : npt.NDArray
            2x1 coordinate describing the position of a node.

        Returns
        -------
        npt.NDArray[np.int32]
            The initial node image but only with skeletal branches
            connected to the middle node.
        """
        node_image_cp = node_image.copy()

        # get node-only image
        nodes = node_image_cp.copy()
        nodes[nodes != 3] = 0
        labeled_nodes = label(nodes)

        # find which cluster is closest to the centre
        node_coords = np.argwhere(nodes == 3)
        min_coords = node_coords[abs(node_coords - node_coordinate).sum(axis=1).argmin()]
        centre_idx = labeled_nodes[min_coords[0], min_coords[1]]

        # get nodeless image
        nodeless = node_image_cp.copy()
        nodeless = np.where(
            (node_image == 1) | (node_image == 2), 1, 0
        )  # if termini, need this in the labeled branches too
        nodeless[labeled_nodes == centre_idx] = 1  # return centre node
        labeled_nodeless = label(nodeless)

        # apply to return image
        for i in range(1, labeled_nodeless.max() + 1):
            if (node_image_cp[labeled_nodeless == i] == 3).any():
                node_image_cp[labeled_nodeless != i] = 0
                break

        # remove small area around other nodes
        labeled_nodes[labeled_nodes == centre_idx] = 0
        non_central_node_coords = np.argwhere(labeled_nodes != 0)
        for coord in non_central_node_coords:
            for j, coord_val in enumerate(coord):
                if coord_val - 1 < 0:
                    coord[j] = 1
                if coord_val + 2 > node_image_cp.shape[j]:
                    coord[j] = node_image_cp.shape[j] - 2
            node_image_cp[coord[0] - 1 : coord[0] + 2, coord[1] - 1 : coord[1] + 2] = 0

        return node_image_cp

    @staticmethod
    def average_uniques(arr1: npt.NDArray, arr2: npt.NDArray) -> tuple:
        """
        Obtain the unique values of both arrays, and the average of common values.

        Parameters
        ----------
        arr1 : npt.NDArray
            An array.
        arr2 : npt.NDArray
            An array.

        Returns
        -------
        tuple
            The unique values of both arrays, and the averaged common values.
        """
        arr1_uniq, index = np.unique(arr1, return_index=True)
        arr2_new = np.zeros_like(arr1_uniq).astype(np.float64)
        for i, val in enumerate(arr1[index]):
            mean = arr2[arr1 == val].mean()
            arr2_new[i] += mean

        return arr1[index], arr2_new

    @staticmethod
    def average_crossing_confs(node_dict) -> None | float:
        """
        Return the average crossing confidence of all crossings in the molecule.

        Parameters
        ----------
        node_dict : dict
            A dictionary containing node statistics and information.

        Returns
        -------
        Union[None, float]
            The value of average confidence or None if not possible.
        """
        sum_conf = 0
        valid_confs = 0
        for _, (_, values) in enumerate(node_dict.items()):
            confidence = values["confidence"]
            if confidence is not None:
                sum_conf += confidence
                valid_confs += 1
        try:
            return sum_conf / valid_confs
        except ZeroDivisionError:
            return None

    @staticmethod
    def minimum_crossing_confs(node_dict: dict) -> None | float:
        """
        Return the minimum crossing confidence of all crossings in the molecule.

        Parameters
        ----------
        node_dict : dict
            A dictionary containing node statistics and information.

        Returns
        -------
        Union[None, float]
            The value of minimum confidence or none if not possible.
        """
        confidences = []
        valid_confs = 0
        for _, (_, values) in enumerate(node_dict.items()):
            confidence = values["confidence"]
            if confidence is not None:
                confidences.append(confidence)
                valid_confs += 1
        try:
            return min(confidences)
        except ValueError:
            return None

    def compile_metrics(self) -> None:
        """Add the number of crossings, average and minimum crossing confidence to the metrics dictionary."""
        self.metrics["num_crossings"] = np.int64((self.node_centre_mask == 3).sum())
        self.metrics["avg_crossing_confidence"] = np.float64(nodeStats.average_crossing_confs(self.node_dicts))
        self.metrics["min_crossing_confidence"] = np.float64(nodeStats.minimum_crossing_confs(self.node_dicts))
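
The half-maximum interpolation used by calculate_fwhm, interpolate_between_yvalue and lin_interp in the listing above can be illustrated with a minimal standalone sketch. It is a simplification rather than the TopoStats implementation: the function names interp_x_at_y and simple_fwhm are hypothetical, and the sketch omits the central-window peak search and the local-minimum fallback for the half maximum.

```python
import numpy as np


def interp_x_at_y(p1, p2, yvalue):
    """Return the x value where the line through p1 and p2 reaches yvalue."""
    m = (p1[1] - p2[1]) / (p1[0] - p2[0])  # gradient of the line
    c = p1[1] - m * p1[0]  # y-intercept
    return (yvalue - c) / m


def simple_fwhm(heights, distances):
    """Hypothetical simplified FWHM: walk outwards from the peak to the half maximum."""
    heights = np.asarray(heights, dtype=float)
    distances = np.asarray(distances, dtype=float)
    hm = (heights.max() - heights.min()) / 2 + heights.min()
    peak = int(np.argmax(heights))

    def crossing(idxs):
        # find the first neighbouring pair of points that straddles the half maximum
        for i, j in zip(idxs[:-1], idxs[1:]):
            lo, hi = sorted((heights[i], heights[j]))
            if lo <= hm <= hi and heights[i] != heights[j]:
                return interp_x_at_y((distances[i], heights[i]), (distances[j], heights[j]), hm)
        return None  # the trace never reaches the half maximum on this side

    left = crossing(list(range(peak, -1, -1)))
    right = crossing(list(range(peak, len(heights))))
    if left is None or right is None:
        return None
    return abs(right - left)
```

For a symmetric trace such as heights [0, 3, 4, 3, 0] at distances [-2, -1, 0, 1, 2], the half maximum is 2 and the two interpolated crossings lie at -4/3 and 4/3. Unlike this sketch, the listing above returns 0 from interpolate_between_yvalue when no crossing is found.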

__init__(filename: str, image: npt.NDArray, mask: npt.NDArray, smoothed_mask: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: np.float64, n_grain: int, node_joining_length: float, node_extend_dist: float, branch_pairing_length: float, pair_odd_branches: bool) -> None

Initialise the nodeStats class.

Parameters:

Name Type Description Default
filename str

The name of the file being processed. For logging purposes.

required
image NDArray

The array of pixels.

required
mask NDArray

The binary segmentation mask.

required
smoothed_mask NDArray

A smoothed version of the binary segmentation mask.

required
skeleton NDArray

A binary single-pixel wide mask of objects in the 'image'.

required
pixel_to_nm_scaling float

The pixel to nm scaling factor.

required
n_grain int

The grain number.

required
node_joining_length float

The length over which to join skeletal intersections to be counted as one crossing.

required
node_joining_length float

The distance over which to join nearby odd-branched nodes.

required
node_extend_dist float

The distance under which to join odd-branched node regions.

required
branch_pairing_length float

The length from the crossing point to pair and trace, obtaining FWHMs.

required
pair_odd_branches bool

Whether to try and pair odd-branched nodes.

required
Source code in topostats\tracing\nodestats.py
def __init__(
    self,
    filename: str,
    image: npt.NDArray,
    mask: npt.NDArray,
    smoothed_mask: npt.NDArray,
    skeleton: npt.NDArray,
    pixel_to_nm_scaling: np.float64,
    n_grain: int,
    node_joining_length: float,
    node_extend_dist: float,
    branch_pairing_length: float,
    pair_odd_branches: bool,
) -> None:
    """
    Initialise the nodeStats class.

    Parameters
    ----------
    filename : str
        The name of the file being processed. For logging purposes.
    image : npt.NDArray
        The array of pixels.
    mask : npt.NDArray
        The binary segmentation mask.
    smoothed_mask : npt.NDArray
        A smoothed version of the binary segmentation mask.
    skeleton : npt.NDArray
        A binary single-pixel wide mask of objects in the 'image'.
    pixel_to_nm_scaling : float
        The pixel to nm scaling factor.
    n_grain : int
        The grain number.
    node_joining_length : float
        The length over which to join skeletal intersections to be counted as one crossing.
    node_joining_length : float
        The distance over which to join nearby odd-branched nodes.
    node_extend_dist : float
        The distance under which to join odd-branched node regions.
    branch_pairing_length : float
        The length from the crossing point to pair and trace, obtaining FWHMs.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    """
    self.filename = filename
    self.image = image
    self.mask = mask
    self.smoothed_mask = smoothed_mask  # only used to average traces
    self.skeleton = skeleton
    self.pixel_to_nm_scaling = pixel_to_nm_scaling
    self.n_grain = n_grain
    self.node_joining_length = node_joining_length
    self.node_extend_dist = node_extend_dist / self.pixel_to_nm_scaling
    self.branch_pairing_length = branch_pairing_length
    self.pair_odd_branches = pair_odd_branches

    self.conv_skelly = np.zeros_like(self.skeleton)
    self.connected_nodes = np.zeros_like(self.skeleton)
    self.all_connected_nodes = np.zeros_like(self.skeleton)
    self.whole_skel_graph: nx.classes.graph.Graph | None = None
    self.node_centre_mask = np.zeros_like(self.skeleton)

    self.metrics = {
        "num_crossings": np.int64(0),
        "avg_crossing_confidence": None,
        "min_crossing_confidence": None,
    }

    self.node_dicts: dict[str, NodeDict] = {}
    self.image_dict: ImageDict = {
        "nodes": {},
        "grain": {
            "grain_image": self.image,
            "grain_mask": self.mask,
            "grain_skeleton": self.skeleton,
        },
    }

    self.full_dict = {}
    self.mol_coords = {}
    self.visuals = {}
    self.all_visuals_img = None
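
The constructor stores node_extend_dist after dividing it by pixel_to_nm_scaling. A minimal sketch of that conversion follows, assuming the scaling factor is expressed in nanometres per pixel. The helper name nm_to_pixels is hypothetical and not part of TopoStats.

```python
def nm_to_pixels(distance_nm: float, pixel_to_nm_scaling: float) -> float:
    """Convert a distance in nanometres to pixels (assumes scaling is nm per pixel)."""
    return distance_nm / pixel_to_nm_scaling


# Example: 14 nm at an assumed 0.5 nm per pixel spans 28 pixels.
extend_dist_px = nm_to_pixels(14.0, 0.5)
```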

above_below_value_idx(array: npt.NDArray, value: float) -> list staticmethod

Identify indices of the array neighbouring the specified value.

Parameters:

Name Type Description Default
array NDArray

Array of values.

required
value float

Value to identify indices between.

required

Returns:

Type Description
list

List of the lower index and higher index around the value, or None if no such pair is found.

Raises:

Type Description
IndexError

Raised and caught internally when the value is in the array; None is returned in that case.

Source code in topostats\tracing\nodestats.py
@staticmethod
def above_below_value_idx(array: npt.NDArray, value: float) -> list:
    """
    Identify indices of the array neighbouring the specified value.

    Parameters
    ----------
    array : npt.NDArray
        Array of values.
    value : float
        Value to identify indices between.

    Returns
    -------
    list
        List of the lower index and higher index around the value, or None if no such pair is found.

    Raises
    ------
    IndexError
        Raised and caught internally when the value is in the array; None is returned in that case.
    """
    idx1 = abs(array - value).argmin()
    try:
        if array[idx1] < value < array[idx1 + 1]:
            idx2 = idx1 + 1
        elif array[idx1 - 1] < value < array[idx1]:
            idx2 = idx1 - 1
        else:
            raise IndexError  # this will be if the number is the same
        indices = [idx1, idx2]
        indices.sort()
        return indices
    except IndexError:
        return None
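
As an illustration of what "neighbouring indices" means here, the following sketch finds the pair of indices that bracket a value. It is not the method above: it assumes the array is sorted in ascending order and swaps in np.searchsorted for the nearest-value search, and the name bracket_indices is hypothetical.

```python
import numpy as np


def bracket_indices(sorted_array, value):
    """Return [lower, upper] indices bracketing value, or None if there is no such pair."""
    arr = np.asarray(sorted_array)
    # first index whose element is >= value
    upper = int(np.searchsorted(arr, value, side="left"))
    # no bracket if value lies outside the array or coincides with an element
    if upper == 0 or upper == len(arr) or arr[upper] == value:
        return None
    return [upper - 1, upper]
```

As with the method above, a value that is already present in the array, or that lies outside its range, yields None rather than a pair of indices.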

add_branches_to_labelled_image(branch_under_over_order: npt.NDArray[np.int32], matched_branches: dict[int, MatchedBranch], masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]], branch_start_coords: npt.NDArray[np.int32], ordered_branches: list[npt.NDArray[np.int32]], pairs: npt.NDArray[np.int32], average_trace_advised: bool, image_shape: tuple[int, int]) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]] staticmethod

Add branches to a labelled image.

Parameters:

Name Type Description Default
branch_under_over_order NDArray[int32]

The order of the branches.

required
matched_branches dict[int, MatchedBranch]

Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys:
- "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
- "heights" : npt.NDArray[np.number]. Heights of the branches.
- "distances" : npt.NDArray[np.number]. Distances of the branch coordinates.
- "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.

required
masked_image dict[int, dict[str, NDArray[bool_]]]

Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys: - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.

required
branch_start_coords NDArray[int32]

An Nx2 numpy array of the coordinates of the branches connected to the node.

required
ordered_branches list[NDArray[int32]]

List of numpy arrays of ordered branch coordinates.

required
pairs NDArray[int32]

Nx2 numpy array of pairs of branches that are matched through a node.

required
average_trace_advised bool

Flag to determine whether to use the average trace.

required
image_shape tuple[int, int]

The shape of the image, to create a mask from.

required

Returns:

Type Description
tuple[NDArray[int32], NDArray[int32]]

The branch image and the average image.

Source code in topostats\tracing\nodestats.py
@staticmethod
def add_branches_to_labelled_image(
    branch_under_over_order: npt.NDArray[np.int32],
    matched_branches: dict[int, MatchedBranch],
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]],
    branch_start_coords: npt.NDArray[np.int32],
    ordered_branches: list[npt.NDArray[np.int32]],
    pairs: npt.NDArray[np.int32],
    average_trace_advised: bool,
    image_shape: tuple[int, int],
) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]:
    """
    Add branches to a labelled image.

    Parameters
    ----------
    branch_under_over_order : npt.NDArray[np.int32]
        The order of the branches.
    matched_branches : dict[int, MatchedBranch]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
        - "heights" : npt.NDArray[np.number]. Heights of the branches.
        - "distances" : npt.NDArray[np.number]. Distances of the branch coordinates.
        - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
    masked_image : dict[int, dict[str, npt.NDArray[np.bool_]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
    branch_start_coords : npt.NDArray[np.int32]
        An Nx2 numpy array of the coordinates of the branches connected to the node.
    ordered_branches : list[npt.NDArray[np.int32]]
        List of numpy arrays of ordered branch coordinates.
    pairs : npt.NDArray[np.int32]
        Nx2 numpy array of pairs of branches that are matched through a node.
    average_trace_advised : bool
        Flag to determine whether to use the average trace.
    image_shape : tuple[int, int]
        The shape of the image, to create a mask from.

    Returns
    -------
    tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]
        The branch image and the average image.
    """
    branch_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32)
    avg_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32)

    for i, branch_index in enumerate(branch_under_over_order):
        branch_coords = matched_branches[branch_index]["ordered_coords"]

        # Add the matched branch to the image, starting at index 1
        branch_image[branch_coords[:, 0], branch_coords[:, 1]] = i + 1
        if average_trace_advised:
            # For type safety, check if avg_image is None and skip if so.
            # This is because the type hinting does not allow for None in the array.
            avg_image[masked_image[branch_index]["avg_mask"] != 0] = i + 1

    # Determine branches that were not able to be paired
    unpaired_branches = np.delete(np.arange(0, branch_start_coords.shape[0]), pairs.flatten())
    LOGGER.debug(f"Unpaired branches: {unpaired_branches}")
    # Ensure that unpaired branch labels continue on from the highest paired-branch label.
    branch_label = branch_image.max()
    # Add the unpaired branches back to the branch image
    for i in unpaired_branches:
        branch_label += 1
        branch_image[ordered_branches[i][:, 0], ordered_branches[i][:, 1]] = branch_label

    return branch_image, avg_image
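
The labelling scheme above can be sketched on a toy example. This is a minimal illustration, not the library's code: the image size, the coordinates and the variable names `matched_coords` and `n_branches` are hypothetical, chosen so that one pair is matched and two branches are left unpaired.

```python
import numpy as np

image_shape = (5, 5)
branch_image = np.zeros(image_shape, dtype=np.int32)

# Hypothetical data: four branches leave the node, only branches 0 and 2 were paired.
n_branches = 4
pairs = np.array([[0, 2]])
# Coordinates of the single matched branch (branch 0 joined to branch 2 through the node).
matched_coords = [np.array([[2, 0], [2, 1], [2, 2], [2, 3], [2, 4]])]
# Ordered coordinates of every individual branch.
ordered_branches = [
    np.array([[2, 0], [2, 1]]),
    np.array([[0, 2], [1, 2]]),
    np.array([[2, 3], [2, 4]]),
    np.array([[3, 2], [4, 2]]),
]

# Paired branches receive labels 1..K in the order they are supplied.
for i, coords in enumerate(matched_coords):
    branch_image[coords[:, 0], coords[:, 1]] = i + 1

# Every branch index that does not appear in `pairs` is unpaired.
unpaired = np.delete(np.arange(n_branches), pairs.flatten())

# Unpaired branches continue from the highest label already used.
label = branch_image.max()
for idx in unpaired:
    label += 1
    coords = ordered_branches[idx]
    branch_image[coords[:, 0], coords[:, 1]] = label

print(unpaired)
print(branch_image)
```

Here the matched branch takes label 1, and the two unpaired branches take labels 2 and 3.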

analyse_node_branches(p_to_nm: np.float64, reduced_node_area: npt.NDArray[np.int32], branch_start_coords: npt.NDArray[np.int32], max_length_px: np.float64, reduced_skeleton_graph: nx.classes.graph.Graph, image: npt.NDArray[np.number], average_trace_advised: bool, node_coord: tuple[np.int32, np.int32], pair_odd_branches: bool, filename: str, resolution_threshold: np.float64) -> tuple[npt.NDArray[np.int32], dict[int, MatchedBranch], list[npt.NDArray[np.int32]], dict[int, dict[str, npt.NDArray[np.bool_]]], npt.NDArray[np.int32], np.float64 | None] staticmethod

Analyse the branches of a single node.

Parameters:

Name Type Description Default
p_to_nm float64

The pixel to nm scaling factor.

required
reduced_node_area NDArray[int32]

An NxM numpy array of the node in question and the branches connected to it. Node is marked by 3, and branches by 1.

required
branch_start_coords NDArray[int32]

An Nx2 numpy array of the coordinates of the branches connected to the node.

required
max_length_px float64

The maximum length in pixels to traverse along while ordering.

required
reduced_skeleton_graph Graph

The graph representation of the reduced node area.

required
image NDArray[number]

The full image of the grain.

required
average_trace_advised bool

Flag to determine whether to use the average trace.

required
node_coord tuple[int32, int32]

The node coordinates.

required
pair_odd_branches bool

Whether to try and pair odd-branched nodes.

required
filename str

The filename of the image.

required
resolution_threshold float64

The resolution threshold below which to warn the user that the node is difficult to analyse.

required

Returns:

Name Type Description
pairs NDArray[int32]

Nx2 numpy array of pairs of branches that are matched through a node.

matched_branches dict[int, MatchedBranch]

Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys: - "ordered_coords" : npt.NDArray[np.int32]. - "heights" : npt.NDArray[np.number]. Heights of the branches. - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch. - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches. - "angles" : np.float64. The angle of the branch, added in later steps.

ordered_branches list[NDArray[int32]]

List of numpy arrays of ordered branch coordinates.

masked_image dict[int, dict[str, NDArray[bool_]]]

Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys: - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.

branch_under_over_order NDArray[int32]

The order of the branches based on the FWHM.

confidence float64 | None

The confidence of the crossing. Optional.

Source code in topostats\tracing\nodestats.py
@staticmethod
def analyse_node_branches(
    p_to_nm: np.float64,
    reduced_node_area: npt.NDArray[np.int32],
    branch_start_coords: npt.NDArray[np.int32],
    max_length_px: np.float64,
    reduced_skeleton_graph: nx.classes.graph.Graph,
    image: npt.NDArray[np.number],
    average_trace_advised: bool,
    node_coord: tuple[np.int32, np.int32],
    pair_odd_branches: bool,
    filename: str,
    resolution_threshold: np.float64,
) -> tuple[
    npt.NDArray[np.int32],
    dict[int, MatchedBranch],
    list[npt.NDArray[np.int32]],
    dict[int, dict[str, npt.NDArray[np.bool_]]],
    npt.NDArray[np.int32],
    np.float64 | None,
]:
    """
    Analyse the branches of a single node.

    Parameters
    ----------
    p_to_nm : np.float64
        The pixel to nm scaling factor.
    reduced_node_area : npt.NDArray[np.int32]
        An NxM numpy array of the node in question and the branches connected to it.
        Node is marked by 3, and branches by 1.
    branch_start_coords : npt.NDArray[np.int32]
        An Nx2 numpy array of the coordinates of the branches connected to the node.
    max_length_px : np.float64
        The maximum length in pixels to traverse along while ordering.
    reduced_skeleton_graph : nx.classes.graph.Graph
        The graph representation of the reduced node area.
    image : npt.NDArray[np.number]
        The full image of the grain.
    average_trace_advised : bool
        Flag to determine whether to use the average trace.
    node_coord : tuple[np.int32, np.int32]
        The node coordinates.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    filename : str
        The filename of the image.
    resolution_threshold : np.float64
        The resolution threshold below which to warn the user that the node is difficult to analyse.

    Returns
    -------
    pairs: npt.NDArray[np.int32]
        Nx2 numpy array of pairs of branches that are matched through a node.
    matched_branches: dict[int, MatchedBranch]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32].
        - "heights" : npt.NDArray[np.number]. Heights of the branches.
        - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch.
        - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
        - "angles" : np.float64. The angle of the branch, added in later steps.
    ordered_branches: list[npt.NDArray[np.int32]]
        List of numpy arrays of ordered branch coordinates.
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
    branch_under_over_order: npt.NDArray[np.int32]
        The order of the branches based on the FWHM.
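    confidence: np.float64 | None
        The confidence of the crossing. Optional.
    singlet_branch_vectors
        The starting vector of each branch, as returned by get_ordered_branches_and_vectors.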
    """
    if not p_to_nm <= resolution_threshold:
        LOGGER.debug(f"Resolution {p_to_nm} is below suggested {resolution_threshold}, node difficult to analyse.")

    # Pixel-wise order the branches coming from the node and calculate the starting vector for each branch
    ordered_branches, singlet_branch_vectors = nodeStats.get_ordered_branches_and_vectors(
        reduced_node_area, branch_start_coords, max_length_px
    )

    # Pair the singlet branch vectors based on their suitability using vector orientation.
    if len(branch_start_coords) % 2 == 0 or pair_odd_branches:
        pairs = nodeStats.pair_vectors(np.asarray(singlet_branch_vectors))
    else:
        pairs = np.array([], dtype=np.int32)

    # Match the branches up
    matched_branches, masked_image = nodeStats.join_matching_branches_through_node(
        pairs,
        ordered_branches,
        reduced_skeleton_graph,
        image,
        average_trace_advised,
        node_coord,
        filename,
    )

    # Redo the FWHMs after the processing for more accurate determination of under/overs.
    hms = []
    for _, values in matched_branches.items():
        hms.append(values["fwhm"]["half_maxs"][2])
    for _, values in matched_branches.items():
        values["fwhm"] = nodeStats.calculate_fwhm(values["heights"], values["distances"], hm=max(hms))

    # Get the confidence of the crossing
    crossing_fwhms = []
    for _, values in matched_branches.items():
        crossing_fwhms.append(values["fwhm"]["fwhm"])
    if len(crossing_fwhms) <= 1:
        confidence = None
    else:
        crossing_fwhm_combinations = list(combinations(crossing_fwhms, 2))
        confidence = np.float64(nodeStats.cross_confidence(crossing_fwhm_combinations))

    # Order the branch indexes based on the FWHM of the branches.
    branch_under_over_order = np.array(list(matched_branches.keys()))[np.argsort(np.array(crossing_fwhms))]

    return (
        pairs,
        matched_branches,
        ordered_branches,
        masked_image,
        branch_under_over_order,
        confidence,
        singlet_branch_vectors,
    )
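
The final ordering step sorts the matched-branch keys by ascending FWHM, and the confidence calculation receives every pairwise combination of FWHM values. A minimal sketch of those two steps, using made-up FWHM values and the hypothetical names `fwhms`, `under_over_order` and `fwhm_pairs`:

```python
from itertools import combinations

import numpy as np

# Hypothetical FWHM values (nm), keyed by matched-branch index.
fwhms = {0: 4.2, 1: 2.9, 2: 3.5}

keys = np.array(list(fwhms.keys()))
values = np.array(list(fwhms.values()))

# Branch indexes ordered from the narrowest to the widest FWHM.
under_over_order = keys[np.argsort(values)]

# Every unordered pair of FWHM values, as passed on to the confidence calculation.
fwhm_pairs = list(combinations(values.tolist(), 2))

print(under_over_order)
print(fwhm_pairs)
```

With three matched branches there are three FWHM pairs. With a single matched branch there are none, which is why the confidence is `None` in that case.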

analyse_nodes(max_branch_length: float = 20) -> None

Obtain the main analyses for the nodes of a single molecule along the 'max_branch_length' (nm) from the node.

Parameters:

Name Type Description Default
max_branch_length float

The side length of the box around the node to analyse (in nm).

20
Source code in topostats\tracing\nodestats.py
def analyse_nodes(self, max_branch_length: float = 20) -> None:
    """
    Obtain the main analyses for the nodes of a single molecule along the 'max_branch_length' (nm) from the node.

    Parameters
    ----------
    max_branch_length : float
        The side length of the box around the node to analyse (in nm).
    """
    # Get coordinates of nodes
    # This is a numpy array of coords, shape Nx2
    assert self.node_centre_mask is not None, "Node centre mask is not defined."
    node_coords: npt.NDArray[np.int32] = np.argwhere(self.node_centre_mask.copy() == 3)

    # Check whether average trace resides inside the grain mask
    # Checks if we dilate the skeleton once or twice, then all the pixels should fit in the grain mask
    dilate = binary_dilation(self.skeleton, iterations=2)
    # This flag determines whether to use average of 3 traces in calculation of FWHM
    average_trace_advised = dilate[self.smoothed_mask == 1].sum() == dilate.sum()
    LOGGER.debug(f"[{self.filename}] : Branch height traces will be averaged: {average_trace_advised}")

    # Iterate over the nodes and analyse the branches
    matched_branches = None
    branch_image = None
    avg_image = np.zeros_like(self.image)
    real_node_count = 0
    for node_no, (node_x, node_y) in enumerate(node_coords):
        unmatched_branches = {}
        error = False

        # Get branches relevant to the node
        max_length_px = max_branch_length / self.pixel_to_nm_scaling
        reduced_node_area: npt.NDArray[np.int32] = nodeStats.only_centre_branches(
            self.connected_nodes, np.array([node_x, node_y])
        )
        # Reduced skel graph is a networkx graph of the reduced node area.
        reduced_skel_graph: nx.classes.graph.Graph = nodeStats.skeleton_image_to_graph(reduced_node_area)

        # Binarise the reduced node area
        branch_mask = reduced_node_area.copy()
        branch_mask[branch_mask == 3] = 0
        branch_mask[branch_mask == 2] = 1
        node_coords = np.argwhere(reduced_node_area == 3)

        # Find the starting coordinates of any branches connected to the node
        branch_start_coords = self.find_branch_starts(reduced_node_area)

        # Stop processing if nib (node has 2 branches)
        if branch_start_coords.shape[0] <= 2:
            LOGGER.debug(
                f"node {node_no} has only two branches - skipped & nodes removed. {len(node_coords)} "
                "pixels in nib node."
            )
        else:
            try:
                real_node_count += 1
                LOGGER.debug(f"Node: {real_node_count}")

                # Analyse the node branches
                (
                    pairs,
                    matched_branches,
                    ordered_branches,
                    masked_image,
                    branch_under_over_order,
                    confidence,
                    singlet_branch_vectors,
                ) = nodeStats.analyse_node_branches(
                    p_to_nm=self.pixel_to_nm_scaling,
                    reduced_node_area=reduced_node_area,
                    branch_start_coords=branch_start_coords,
                    max_length_px=max_length_px,
                    reduced_skeleton_graph=reduced_skel_graph,
                    image=self.image,
                    average_trace_advised=average_trace_advised,
                    node_coord=(node_x, node_y),
                    pair_odd_branches=self.pair_odd_branches,
                    filename=self.filename,
                    resolution_threshold=np.float64(1000 / 512),
                )

                # Add the analysed branches to the labelled image
                branch_image, avg_image = nodeStats.add_branches_to_labelled_image(
                    branch_under_over_order=branch_under_over_order,
                    matched_branches=matched_branches,
                    masked_image=masked_image,
                    branch_start_coords=branch_start_coords,
                    ordered_branches=ordered_branches,
                    pairs=pairs,
                    average_trace_advised=average_trace_advised,
                    image_shape=(self.image.shape[0], self.image.shape[1]),
                )

                # Calculate crossing angles of unpaired branches and add to stats dict
                nodestats_calc_singlet_angles_result = nodeStats.calc_angles(np.asarray(singlet_branch_vectors))
                angles_between_singlet_branch_vectors: npt.NDArray[np.float64] = (
                    nodestats_calc_singlet_angles_result[0]
                )

                for branch_index, angle in enumerate(angles_between_singlet_branch_vectors):
                    unmatched_branches[branch_index] = {"angles": angle}

                # Get the vector of each branch based on ordered_coords. Ordered_coords is only the first N nm
                # of the branch so this is just a general vibe on what direction a branch is going.
                if len(branch_start_coords) % 2 == 0 or self.pair_odd_branches:
                    vectors: list[npt.NDArray[np.float64]] = []
                    for _, values in matched_branches.items():
                        vectors.append(nodeStats.get_vector(values["ordered_coords"], np.array([node_x, node_y])))
                    # Calculate angles between the vectors
                    nodestats_calc_angles_result = nodeStats.calc_angles(np.asarray(vectors))
                    angles_between_vectors_along_branch: npt.NDArray[np.float64] = nodestats_calc_angles_result[0]
                    for branch_index, angle in enumerate(angles_between_vectors_along_branch):
                        if len(branch_start_coords) % 2 == 0 or self.pair_odd_branches:
                            matched_branches[branch_index]["angles"] = angle
                else:
                    self.image_dict["grain"]["grain_skeleton"][node_coords[:, 0], node_coords[:, 1]] = 0

                # Eg: length 2 array: [array([ nan, 79.00]), array([79.00, 0.0])]
                # angles_between_vectors_along_branch

            except ResolutionError:
                LOGGER.debug(f"Node stats skipped as resolution too low: {self.pixel_to_nm_scaling}nm per pixel")
                error = True

            self.node_dicts[f"node_{real_node_count}"] = {
                "error": error,
                "pixel_to_nm_scaling": self.pixel_to_nm_scaling,
                "branch_stats": matched_branches,
                "unmatched_branch_stats": unmatched_branches,
                "node_coords": node_coords,
                "confidence": confidence,
            }

            assert reduced_node_area is not None, "Reduced node area is not defined."
            assert branch_image is not None, "Branch image is not defined."
            assert avg_image is not None, "Average image is not defined."
            node_images_dict: dict[str, npt.NDArray[np.int32]] = {
                "node_area_skeleton": reduced_node_area,
                "node_branch_mask": branch_image,
                "node_avg_mask": avg_image,
            }
            self.image_dict["nodes"][f"node_{real_node_count}"] = node_images_dict

        self.all_connected_nodes[self.connected_nodes != 0] = self.connected_nodes[self.connected_nodes != 0]
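
The `average_trace_advised` check can be illustrated in isolation. The sketch below assumes that `binary_dilation` is `scipy.ndimage.binary_dilation` (which accepts an `iterations` argument). The arrays `mask`, `thin_mask` and `skeleton` are toy data and the helper `fits_inside` is hypothetical.

```python
import numpy as np
from scipy.ndimage import binary_dilation

# Toy grain mask: a filled rectangle covering rows 2..6 and columns 1..7.
mask = np.zeros((9, 9), dtype=int)
mask[2:7, 1:8] = 1

# Toy skeleton: a short horizontal line through the middle of the grain.
skeleton = np.zeros((9, 9), dtype=bool)
skeleton[4, 3:6] = True

# Dilate the skeleton twice, as in the check above.
dilated = binary_dilation(skeleton, iterations=2)


def fits_inside(dilated_skeleton, grain_mask):
    """Return True when every dilated pixel lies inside the grain mask."""
    return bool(dilated_skeleton[grain_mask == 1].sum() == dilated_skeleton.sum())


# A narrower mask (rows 3..5 only) no longer contains the dilated skeleton.
thin_mask = np.zeros_like(mask)
thin_mask[3:6, 1:8] = 1

print(fits_inside(dilated, mask))
print(fits_inside(dilated, thin_mask))
```

When the dilated skeleton spills outside the mask, the parallel traces would sample background pixels, so averaging is not advised.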

average_crossing_confs(node_dict) -> None | float staticmethod

Return the average crossing confidence of all crossings in the molecule.

Parameters:

Name Type Description Default
node_dict dict

A dictionary containing node statistics and information.

required

Returns:

Type Description
Union[None, float]

The average confidence, or None if no crossing has a confidence value.

Source code in topostats\tracing\nodestats.py
@staticmethod
def average_crossing_confs(node_dict) -> None | float:
    """
    Return the average crossing confidence of all crossings in the molecule.

    Parameters
    ----------
    node_dict : dict
        A dictionary containing node statistics and information.

    Returns
    -------
    Union[None, float]
        The average confidence, or None if no crossing has a confidence value.
    """
    sum_conf = 0
    valid_confs = 0
    for values in node_dict.values():
        confidence = values["confidence"]
        if confidence is not None:
            sum_conf += confidence
            valid_confs += 1
    try:
        return sum_conf / valid_confs
    except ZeroDivisionError:
        return None
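
A short usage sketch of the same averaging logic, written as a standalone function so it can be run without the class. The function name `average_confidence` and the node dictionary contents are hypothetical. Only the `"confidence"` key is taken from the code above.

```python
def average_confidence(node_dict):
    """Average the non-None confidences, or return None when there are none."""
    confs = [v["confidence"] for v in node_dict.values() if v["confidence"] is not None]
    return sum(confs) / len(confs) if confs else None


# Hypothetical node statistics: one node has no confidence value.
node_dict = {
    "node_1": {"confidence": 0.8},
    "node_2": {"confidence": None},
    "node_3": {"confidence": 0.4},
}

avg = average_confidence(node_dict)
print(avg)
print(average_confidence({}))
```

Nodes whose confidence is `None` are excluded from both the sum and the count, so they do not pull the average down.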

average_height_trace(img: npt.NDArray, branch_mask: npt.NDArray, branch_coords: npt.NDArray, centre=(0, 0)) -> tuple staticmethod

Average two side-by-side ordered skeleton distance and height traces.

Dilate the original branch to create two additional side-by-side branches in order to get a more accurate average of the height traces. This function produces the common distances between these 3 branches, and their averaged heights.

Parameters:

Name Type Description Default
img NDArray

An array of numbers pertaining to an image.

required
branch_mask NDArray

A binary array of the branch, must share the same dimensions as the image.

required
branch_coords NDArray

Ordered coordinates of the branch mask.

required
centre tuple

The coordinates to centre the branch around.

(0, 0)

Returns:

Type Description
tuple

A tuple of the common distances from the crossing, the averaged heights at those distances, the binary mask of the three traces, and the individual [heights, distances] of the three traces.

Source code in topostats\tracing\nodestats.py
@staticmethod
def average_height_trace(  # noqa: C901
    img: npt.NDArray, branch_mask: npt.NDArray, branch_coords: npt.NDArray, centre=(0, 0)
) -> tuple:
    """
    Average two side-by-side ordered skeleton distance and height traces.

    Dilate the original branch to create two additional side-by-side branches
    in order to get a more accurate average of the height traces. This function produces
    the common distances between these 3 branches, and their averaged heights.

    Parameters
    ----------
    img : npt.NDArray
        An array of numbers pertaining to an image.
    branch_mask : npt.NDArray
        A binary array of the branch, must share the same dimensions as the image.
    branch_coords : npt.NDArray
        Ordered coordinates of the branch mask.
    centre : tuple
        The coordinates to centre the branch around.

    Returns
    -------
    tuple
        A tuple of the common distances from the crossing, the averaged heights at those distances,
        the binary mask of the three traces, and the individual [heights, distances] of the three traces.
    """
    # get heights and dists of the original (middle) branch
    branch_dist = nodeStats.coord_dist_rad(branch_coords, centre)
    # branch_dist = self.coord_dist(branch_coords)
    branch_heights = img[branch_coords[:, 0], branch_coords[:, 1]]
    branch_dist, branch_heights = nodeStats.average_uniques(
        branch_dist, branch_heights
    )  # needs to be paired with coord_dist_rad
    dist_zero_point = branch_dist[
        np.argmin(np.sqrt((branch_coords[:, 0] - centre[0]) ** 2 + (branch_coords[:, 1] - centre[1]) ** 2))
    ]
    branch_dist_norm = branch_dist - dist_zero_point  # - 0  # branch_dist[branch_heights.argmax()]

    # want to get a 3 pixel line trace, one on each side of orig
    dilate = binary_dilation(branch_mask, iterations=1)
    dilate = nodeStats.fill_holes(dilate)
    dilate_minus = np.where(dilate != branch_mask, 1, 0)
    dilate2 = binary_dilation(dilate, iterations=1)
    dilate2[(dilate == 1) | (branch_mask == 1)] = 0
    labels = label(dilate2)
    # Cleanup stages - re-entering, early terminating, closer traces
    #   if parallel trace out and back in zone, can get > 2 labels
    labels = nodeStats._remove_re_entering_branches(labels, remaining_branches=2)
    #   if parallel trace doesn't exit window, can get 1 label
    #       occurs when skeleton has poor connections (extra branches which cut corners)
    if labels.max() == 1:
        conv = convolve_skeleton(branch_mask)
        endpoints = np.argwhere(conv == 2)
        for endpoint in endpoints:  # may be >1 endpoint
            para_trace_coords = np.argwhere(labels == 1)
            abs_diff = np.absolute(para_trace_coords - endpoint).sum(axis=1)
            min_idxs = np.where(abs_diff == abs_diff.min())
            trace_coords_remove = para_trace_coords[min_idxs]
            labels[trace_coords_remove[:, 0], trace_coords_remove[:, 1]] = 0
        labels = label(labels)
    #   reduce binary dilation distance
    parallel = np.zeros_like(branch_mask).astype(np.int32)
    for i in range(1, labels.max() + 1):
        single = labels.copy()
        single[single != i] = 0
        single[single == i] = 1
        sing_dil = binary_dilation(single)
        parallel[(sing_dil == dilate_minus) & (sing_dil == 1)] = i
    labels = parallel.copy()

    binary = labels.copy()
    binary[binary != 0] = 1
    binary += branch_mask

    # get and order coords, then get heights and distances relative to node centre / highest point
    heights = []
    distances = []
    for i in np.unique(labels)[1:]:
        trace_img = np.where(labels == i, 1, 0)
        trace_img = getSkeleton(img, trace_img, method="zhang").get_skeleton()
        trace = order_branch(trace_img, branch_coords[0])
        height_trace = img[trace[:, 0], trace[:, 1]]
        dist = nodeStats.coord_dist_rad(trace, centre)  # self.coord_dist(trace)
        dist, height_trace = nodeStats.average_uniques(dist, height_trace)  # needs to be paired with coord_dist_rad
        heights.append(height_trace)
        distances.append(
            dist - dist_zero_point  # - 0
        )  # branch_dist[branch_heights.argmax()]) #dist[central_heights.argmax()])
    # Make like coord system using original branch
    avg1 = []
    avg2 = []
    for mid_dist in branch_dist_norm:
        for i, (distance, height) in enumerate(zip(distances, heights)):
            # check if distance already in traces array
            if (mid_dist == distance).any():
                idx = np.where(mid_dist == distance)
                if i == 0:
                    avg1.append([mid_dist, height[idx][0]])
                else:
                    avg2.append([mid_dist, height[idx][0]])
            # if not, linearly interpolate the mid-branch value
            else:
                # get index after and before the mid branches' x coord
                xidxs = nodeStats.above_below_value_idx(distance, mid_dist)
                if xidxs is None:
                    pass  # if indexes outside of range, pass
                else:
                    point1 = [distance[xidxs[0]], height[xidxs[0]]]
                    point2 = [distance[xidxs[1]], height[xidxs[1]]]
                    y = nodeStats.lin_interp(point1, point2, xvalue=mid_dist)
                    if i == 0:
                        avg1.append([mid_dist, y])
                    else:
                        avg2.append([mid_dist, y])
    avg1 = np.asarray(avg1)
    avg2 = np.asarray(avg2)
    # ensure arrays are same length to average
    temp_x = branch_dist_norm[np.isin(branch_dist_norm, avg1[:, 0])]
    common_dists = avg2[:, 0][np.isin(avg2[:, 0], temp_x)]

    common_avg_branch_heights = branch_heights[np.isin(branch_dist_norm, common_dists)]
    common_avg1_heights = avg1[:, 1][np.isin(avg1[:, 0], common_dists)]
    common_avg2_heights = avg2[:, 1][np.isin(avg2[:, 0], common_dists)]

    average_heights = (common_avg_branch_heights + common_avg1_heights + common_avg2_heights) / 3
    return (
        common_dists,
        average_heights,
        binary,
        [[heights[0], branch_heights, heights[1]], [distances[0], branch_dist_norm, distances[1]]],
    )
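
The core idea of the averaging step is to resample the two side traces onto the distance grid of the middle trace and to average the three heights wherever all traces overlap. The sketch below is a simplified stand-in, not the function above: it uses `np.interp` in place of the `above_below_value_idx` and `lin_interp` helpers, assumes each distance array is sorted in ascending order, and uses the hypothetical name `average_three_traces`.

```python
import numpy as np


def average_three_traces(mid_dist, mid_height, side_traces):
    """Average a middle trace with side traces on the middle trace's distance grid.

    side_traces is a list of (distances, heights) pairs with ascending distances.
    Only middle-trace distances covered by every side trace are kept.
    """
    lo = max(d.min() for d, _ in side_traces)
    hi = min(d.max() for d, _ in side_traces)
    keep = (mid_dist >= lo) & (mid_dist <= hi)
    common = mid_dist[keep]
    # Linearly interpolate each side trace at the common distances.
    resampled = [np.interp(common, d, h) for d, h in side_traces]
    n_traces = len(side_traces) + 1
    return common, (mid_height[keep] + sum(resampled)) / n_traces


# Toy traces: distances in nm from the crossing, heights in nm.
mid_dist = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
mid_height = np.array([1.0, 2.0, 3.0, 2.0, 1.0])
side_a = (np.array([-1.5, -0.5, 0.5, 1.5]), np.array([1.0, 2.0, 2.0, 1.0]))
side_b = (np.array([-2.0, -1.0, 0.0, 1.0, 2.0]), np.array([0.0, 1.0, 2.0, 1.0, 0.0]))

common, avg = average_three_traces(mid_dist, mid_height, [side_a, side_b])
print(common)
print(avg)
```

Restricting the output to the overlapping distance range mirrors the "ensure arrays are same length to average" step in the source.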

average_uniques(arr1: npt.NDArray, arr2: npt.NDArray) -> tuple staticmethod

Obtain the unique values of both arrays, and the average of common values.

Parameters:

Name Type Description Default
arr1 NDArray

An array.

required
arr2 NDArray

An array.

required

Returns:

Type Description
tuple

The unique values of both arrays, and the averaged common values.

Source code in topostats\tracing\nodestats.py
@staticmethod
def average_uniques(arr1: npt.NDArray, arr2: npt.NDArray) -> tuple:
    """
    Obtain the unique values of both arrays, and the average of common values.

    Parameters
    ----------
    arr1 : npt.NDArray
        An array.
    arr2 : npt.NDArray
        An array.

    Returns
    -------
    tuple
        The unique values of both arrays, and the averaged common values.
    """
    arr1_uniq, index = np.unique(arr1, return_index=True)
    arr2_new = np.zeros_like(arr1_uniq).astype(np.float64)
    for i, val in enumerate(arr1[index]):
        mean = arr2[arr1 == val].mean()
        arr2_new[i] += mean

    return arr1[index], arr2_new
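
In effect, the function groups the values of `arr2` by the unique values of `arr1` and returns the mean of each group. A minimal sketch of that behaviour follows. It swaps the explicit loop for `np.bincount`, which is an implementation choice of this example rather than the source's, and the name `average_by_unique` is hypothetical.

```python
import numpy as np


def average_by_unique(keys, values):
    """Return the sorted unique keys and the mean of `values` for each key."""
    keys = np.asarray(keys).ravel()
    values = np.asarray(values, dtype=np.float64).ravel()
    uniq, inverse = np.unique(keys, return_inverse=True)
    sums = np.bincount(inverse, weights=values)
    counts = np.bincount(inverse)
    return uniq, sums / counts


# Two pixels share the distance 1.0, so their heights are averaged.
distances = np.array([0.0, 1.0, 1.0, 2.0])
heights = np.array([10.0, 20.0, 40.0, 5.0])

uniq_dist, mean_heights = average_by_unique(distances, heights)
print(uniq_dist)
print(mean_heights)
```

This is why the function is paired with `coord_dist_rad` in the source: radial distances can repeat, and each repeated distance needs a single height.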

best_matches(arr: npt.NDArray, max_weight_matching: bool = True) -> npt.NDArray staticmethod

Turn a matrix into a graph and calculate the best matching index pairs.

Parameters:

Name Type Description Default
arr NDArray

Transpose symmetric MxM array where the value of index i, j represents a weight between i and j.

required
max_weight_matching bool

Whether to obtain best matching pairs via maximum weight, or minimum weight matching.

True

Returns:

Type Description
NDArray

Array of pairs of indexes.

Source code in topostats\tracing\nodestats.py
@staticmethod
def best_matches(arr: npt.NDArray, max_weight_matching: bool = True) -> npt.NDArray:
    """
    Turn a matrix into a graph and calculate the best matching index pairs.

    Parameters
    ----------
    arr : npt.NDArray
        Transpose symmetric MxM array where the value of index i, j represents a weight between i and j.
    max_weight_matching : bool
        Whether to obtain best matching pairs via maximum weight, or minimum weight matching.

    Returns
    -------
    npt.NDArray
        Array of pairs of indexes.
    """
    if max_weight_matching:
        G = nodeStats.create_weighted_graph(arr)
        matching = np.array(list(nx.max_weight_matching(G, maxcardinality=True)))
    else:
        np.fill_diagonal(arr, arr.max() + 1)
        G = nodeStats.create_weighted_graph(arr)
        matching = np.array(list(nx.min_weight_matching(G)))
    return matching
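
A minimal sketch of the maximum-weight case using `networkx` directly. `create_weighted_graph` is not shown in this section, so the graph is built by hand here, which is an assumption about what that helper does. The weight matrix is made up.

```python
import networkx as nx
import numpy as np

# Symmetric weight matrix: entry (i, j) scores how well branches i and j match.
weights = np.array(
    [
        [0.0, 0.9, 0.1, 0.2],
        [0.9, 0.0, 0.3, 0.1],
        [0.1, 0.3, 0.0, 0.8],
        [0.2, 0.1, 0.8, 0.0],
    ]
)

# One weighted edge per unordered index pair (upper triangle of the matrix).
G = nx.Graph()
n = weights.shape[0]
for i in range(n):
    for j in range(i + 1, n):
        G.add_edge(i, j, weight=float(weights[i, j]))

# maxcardinality=True prefers matchings that pair as many nodes as possible.
matching = nx.max_weight_matching(G, maxcardinality=True)

# The matching is a set of 2-tuples in arbitrary order; normalise for display.
pairs = sorted(tuple(sorted(p)) for p in matching)
print(pairs)
```

The pairing (0, 1) and (2, 3) has total weight 1.7, which beats the other two perfect matchings (0.2 and 0.5).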

binary_line(start: npt.NDArray, end: npt.NDArray) -> npt.NDArray staticmethod

Create a binary path following the straight line between 2 points.

Parameters:

Name Type Description Default
start NDArray

A coordinate.

required
end NDArray

Another coordinate.

required

Returns:

Type Description
NDArray

An Nx2 coordinate array that the line passes through.

Source code in topostats\tracing\nodestats.py
@staticmethod
def binary_line(start: npt.NDArray, end: npt.NDArray) -> npt.NDArray:
    """
    Create a binary path following the straight line between 2 points.

    Parameters
    ----------
    start : npt.NDArray
        A coordinate.
    end : npt.NDArray
        Another coordinate.

    Returns
    -------
    npt.NDArray
        An Nx2 coordinate array that the line passes through.
    """
    arr = []
    m_swap = False
    x_swap = False
    slope = (end - start)[1] / (end - start)[0]

    if abs(slope) > 1:  # swap x and y if slope will cause skips
        start, end = start[::-1], end[::-1]
        slope = 1 / slope
        m_swap = True

    if start[0] > end[0]:  # swap x coords if coords wrong way around
        start, end = end, start
        x_swap = True

    # code assumes slope < 1 hence swap
    x_start, y_start = start
    x_end, _ = end
    for x in range(x_start, x_end + 1):
        y_true = slope * (x - x_start) + y_start
        y_pixel = np.round(y_true)
        arr.append([x, y_pixel])

    if m_swap:  # if swapped due to slope, return
        arr = np.asarray(arr)[:, [1, 0]].reshape(-1, 2).astype(int)
        if x_swap:
            return arr[::-1]
        return arr
    arr = np.asarray(arr).reshape(-1, 2).astype(int)
    if x_swap:
        return arr[::-1]
    return arr
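
For comparison, a straight pixel path can also be produced by sampling both coordinates evenly along the longer axis and rounding. This is a different technique from the slope-based loop above, shown only to illustrate the expected output shape. The function name `straight_line_pixels` is hypothetical.

```python
import numpy as np


def straight_line_pixels(start, end):
    """Return an Nx2 integer array of pixels on the straight line from start to end."""
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)
    # One sample per pixel along the longer axis, so no pixel is skipped.
    n_points = int(np.abs(end - start).max()) + 1
    rows = np.rint(np.linspace(start[0], end[0], n_points))
    cols = np.rint(np.linspace(start[1], end[1], n_points))
    return np.column_stack([rows, cols]).astype(int)


line = straight_line_pixels((0, 0), (3, 6))
print(line)
```

Because the sampling follows the longer axis, steep and shallow lines are handled the same way and no axis swap is needed.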

calc_angles(vectors: npt.NDArray) -> npt.NDArray[np.float64] staticmethod

Calculate the angles between vectors in an array.

Uses the formula:

cos(theta) = (a • b) / (|a||b|)

Parameters:

Name Type Description Default
vectors NDArray

Array of 2x1 vectors.

required

Returns:

Type Description
NDArray

An array of the angles between the vectors, in degrees.

Source code in topostats\tracing\nodestats.py
@staticmethod
def calc_angles(vectors: npt.NDArray) -> npt.NDArray[np.float64]:
    """
    Calculate the angles between vectors in an array.

    Uses the formula:

    .. code-block:: RST

        cos(theta) = (a • b) / (|a||b|)

    Parameters
    ----------
    vectors : npt.NDArray
        Array of 2x1 vectors.

    Returns
    -------
    npt.NDArray
        An array of the angles between the vectors, in degrees.
    """
    dot = vectors @ vectors.T
    norm = np.diag(dot) ** 0.5
    cos_angles = dot / (norm.reshape(-1, 1) @ norm.reshape(1, -1))
    np.fill_diagonal(cos_angles, 1)  # ensures vector_x • vector_x angles are 0
    return abs(np.arccos(cos_angles) / np.pi * 180)  # angles in degrees
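
A standalone sketch of the same pairwise-angle calculation, for checking results by hand. The name `pairwise_angles_deg` is hypothetical, and the `np.clip` call is an addition of this example to guard against rounding pushing a cosine slightly outside [-1, 1].

```python
import numpy as np


def pairwise_angles_deg(vectors):
    """Return the matrix of angles (degrees) between every pair of row vectors."""
    vectors = np.asarray(vectors, dtype=float)
    dot = vectors @ vectors.T              # a . b for every pair of rows
    norm = np.sqrt(np.diag(dot))           # |a| for every row
    cos = dot / np.outer(norm, norm)       # cos(theta) = (a . b) / (|a||b|)
    cos = np.clip(cos, -1.0, 1.0)          # guard against floating point overshoot
    return np.degrees(np.arccos(cos))


vectors = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
angles = pairwise_angles_deg(vectors)
print(angles)
```

Entry (i, j) of the result is the angle between vector i and vector j, so the matrix is symmetric with zeros on the diagonal.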

calculate_fwhm(heights: npt.NDArray, distances: npt.NDArray, hm: float | None = None) -> dict[str, np.float64 | list[np.float64 | float | None]] staticmethod

Calculate the FWHM value.

First identifies the HM, then finds the closest values in the distances array and uses linear interpolation to calculate the FWHM.

Parameters:

Name Type Description Default
heights NDArray

Array of heights.

required
distances NDArray

Array of distances.

required
hm Union[None, float]

The halfmax value to match (if wanting the same HM between curves), by default None.

None

Returns:

Type Description
dict[str, float64 | list]

Dictionary with the keys "fwhm" (the FWHM value), "half_maxs" ([distance at hm for 1st half of trace, distance at hm for 2nd half of trace, HM value]) and "peaks" ([index of the highest point, distance at highest point, height at highest point]).

Source code in topostats\tracing\nodestats.py
@staticmethod
def calculate_fwhm(
    heights: npt.NDArray, distances: npt.NDArray, hm: float | None = None
) -> dict[str, np.float64 | list[np.float64 | float | None]]:
    """
    Calculate the FWHM value.

    First identifies the HM, then finds the closest values in the distances array and uses
    linear interpolation to calculate the FWHM.

    Parameters
    ----------
    heights : npt.NDArray
        Array of heights.
    distances : npt.NDArray
        Array of distances.
    hm : Union[None, float], optional
        The halfmax value to match (if wanting the same HM between curves), by default None.

    Returns
    -------
    dict[str, np.float64 | list]
        Dictionary with the keys "fwhm" (the FWHM value), "half_maxs" ([distance at hm for 1st half of
        trace, distance at hm for 2nd half of trace, HM value]) and "peaks" ([index of the highest point,
        distance at highest point, height at highest point]).
    """
    centre_fraction = int(len(heights) * 0.2)  # in case zone approaches another node, look around centre for max
    if centre_fraction == 0:
        high_idx = np.argmax(heights)
    else:
        high_idx = np.argmax(heights[centre_fraction:-centre_fraction]) + centre_fraction
    # get array halves to find first points that cross hm
    arr1 = heights[:high_idx][::-1]
    dist1 = distances[:high_idx][::-1]
    arr2 = heights[high_idx:]
    dist2 = distances[high_idx:]
    if hm is None:
        # Get half max
        hm = (heights.max() - heights.min()) / 2 + heights.min()
        # half max value -> try to make it the same as other crossing branch?
        # raise hm to the lowest point beside the peak if the trace never drops to hm on one side
        if np.min(arr1) > hm:
            arr1_local_min = argrelextrema(arr1, np.less)[-1]  # closest to end
            try:
                hm = arr1[arr1_local_min][0]
            except IndexError:  # index error when no local minima
                hm = np.min(arr1)
        elif np.min(arr2) > hm:
            arr2_local_min = argrelextrema(arr2, np.less)[0]  # closest to start
            try:
                hm = arr2[arr2_local_min][0]
            except IndexError:  # index error when no local minima
                hm = np.min(arr2)
    arr1_hm = nodeStats.interpolate_between_yvalue(x=dist1, y=arr1, yvalue=hm)
    arr2_hm = nodeStats.interpolate_between_yvalue(x=dist2, y=arr2, yvalue=hm)
    fwhm = np.float64(abs(arr2_hm - arr1_hm))
    return {
        "fwhm": fwhm,
        "half_maxs": [arr1_hm, arr2_hm, hm],
        "peaks": [high_idx, distances[high_idx], heights[high_idx]],
    }
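
As a minimal, self-contained sketch of the half-maximum idea described above (not the TopoStats implementation itself), the snippet below finds the half maximum of a symmetric triangular profile and linearly interpolates where each side of the peak crosses it. The helper name `half_max_crossing` and the sample arrays are illustrative assumptions.

```python
import numpy as np


def half_max_crossing(x, y, level):
    """Return the x position where the piecewise-linear curve (x, y) first crosses `level`."""
    for i in range(len(y) - 1):
        low, high = sorted((y[i], y[i + 1]))
        if low <= level <= high:
            if y[i] == y[i + 1]:  # flat segment sitting exactly on the level
                return float(x[i])
            fraction = (level - y[i]) / (y[i + 1] - y[i])
            return float(x[i] + fraction * (x[i + 1] - x[i]))
    return 0.0  # same fallback value as interpolate_between_yvalue when nothing crosses


# Hypothetical height profile across a crossing, with distances centred on the peak.
heights = np.array([0.0, 1.0, 2.0, 3.0, 2.0, 1.0, 0.0])
distances = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])

peak = int(np.argmax(heights))
half_max = (heights.max() - heights.min()) / 2 + heights.min()

# Walk outwards from the peak on each side, as calculate_fwhm does with arr1 and arr2.
left = half_max_crossing(distances[:peak][::-1], heights[:peak][::-1], half_max)
right = half_max_crossing(distances[peak:], heights[peak:], half_max)
fwhm = abs(right - left)
print(half_max, left, right, fwhm)
```

The left-hand half is reversed so that both searches start next to the peak and stop at the first crossing, which keeps distant bumps in the trace from affecting the width.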

compile_metrics() -> None

Add the number of crossings, average and minimum crossing confidence to the metrics dictionary.

Source code in topostats\tracing\nodestats.py
def compile_metrics(self) -> None:
    """Add the number of crossings, average and minimum crossing confidence to the metrics dictionary."""
    self.metrics["num_crossings"] = np.int64((self.node_centre_mask == 3).sum())
    self.metrics["avg_crossing_confidence"] = np.float64(nodeStats.average_crossing_confs(self.node_dicts))
    self.metrics["min_crossing_confidence"] = np.float64(nodeStats.minimum_crossing_confs(self.node_dicts))

connect_close_nodes(conv_skelly: npt.NDArray, node_width: float = 2.85) -> npt.NDArray

Connect nodes within the 'node_width' boundary distance.

This labels them as part of the same node.

Parameters:

Name Type Description Default
conv_skelly NDArray

A labeled skeleton image with skeleton = 1, endpoints = 2, crossing points = 3.

required
node_width float

The width of the DNA in the grain, used to connect close nodes.

2.85

Returns:

Type Description
ndarray

The skeleton (label=1) with close nodes connected (label=3).

Source code in topostats\tracing\nodestats.py
def connect_close_nodes(self, conv_skelly: npt.NDArray, node_width: float = 2.85) -> npt.NDArray:
    """
    Connect nodes within the 'node_width' boundary distance.

    This labels them as part of the same node.

    Parameters
    ----------
    conv_skelly : npt.NDArray
        A labeled skeleton image with skeleton = 1, endpoints = 2, crossing points = 3.
    node_width : float
        The width of the DNA in the grain, used to connect close nodes.

    Returns
    -------
    np.ndarray
        The skeleton (label=1) with close nodes connected (label=3).
    """
    self.connected_nodes = conv_skelly.copy()
    nodeless = conv_skelly.copy()
    nodeless[(nodeless == 3) | (nodeless == 2)] = 0  # remove node & termini points
    nodeless_labels = label(nodeless)
    for i in range(1, nodeless_labels.max() + 1):
        if nodeless[nodeless_labels == i].size < (node_width / self.pixel_to_nm_scaling):
            # maybe also need to select based on height? and also ensure small branches classified
            self.connected_nodes[nodeless_labels == i] = 3

    return self.connected_nodes
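
The size test above compares a segment's pixel count with `node_width / pixel_to_nm_scaling`. The sketch below illustrates that comparison on a one-dimensional skeleton, assuming a simple run-length scan in place of the 2D `label` call; the function name `connect_close_nodes_1d` and the sample line are hypothetical.

```python
import numpy as np


def connect_close_nodes_1d(line, node_width, pixel_to_nm_scaling):
    """Relabel short runs of skeleton pixels (1) as node pixels (3) on a 1D skeleton."""
    out = line.copy()
    max_pixels = node_width / pixel_to_nm_scaling  # threshold expressed in pixels
    start = None
    for idx, value in enumerate(np.append(line, 0)):  # trailing 0 closes the final run
        if value == 1 and start is None:
            start = idx
        elif value != 1 and start is not None:
            if idx - start < max_pixels:  # run is shorter than the node width
                out[start:idx] = 3
            start = None
    return out


# 3 = crossing point, 1 = skeleton, 2 = endpoint
line = np.array([3, 1, 1, 3, 1, 1, 1, 1, 1, 2])
coarse = connect_close_nodes_1d(line, node_width=2.85, pixel_to_nm_scaling=1.0)
fine = connect_close_nodes_1d(line, node_width=2.85, pixel_to_nm_scaling=0.5)
print(coarse)
print(fine)
```

With a scaling of 1.0 only the two-pixel run between the crossings is absorbed into the node; with a scaling of 0.5 the threshold grows to 5.7 pixels, so the five-pixel run is absorbed as well.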

connect_extended_nodes_nearest(connected_nodes: npt.NDArray, node_extend_dist: float = -1) -> npt.NDArray[np.int32]

Extend the odd-branched nodes to other odd-branched nodes within the 'node_extend_dist' threshold.

Parameters:

Name Type Description Default
connected_nodes NDArray

A 2D array representing the network with background = 0, skeleton = 1, endpoints = 2, node_centres = 3.

required
node_extend_dist int | float

The distance over which to connect odd-branched nodes, by default -1 for no-limit.

-1

Returns:

Type Description
NDArray[int32]

Connected nodes array with odd-branched nodes connected.

Source code in topostats\tracing\nodestats.py
def connect_extended_nodes_nearest(
    self, connected_nodes: npt.NDArray, node_extend_dist: float = -1
) -> npt.NDArray[np.int32]:
    """
    Extend the odd-branched nodes to other odd-branched nodes within the 'node_extend_dist' threshold.

    Parameters
    ----------
    connected_nodes : npt.NDArray
        A 2D array representing the network with background = 0, skeleton = 1, endpoints = 2,
        node_centres = 3.
    node_extend_dist : int | float, optional
        The distance over which to connect odd-branched nodes, by default -1 for no-limit.

    Returns
    -------
    npt.NDArray[np.int32]
        Connected nodes array with odd-branched nodes connected.
    """
    just_nodes = np.where(connected_nodes == 3, 1, 0)  # remove branches & termini points
    labelled_nodes = label(just_nodes)

    just_branches = np.where(connected_nodes == 1, 1, 0)  # remove node & termini points
    just_branches[connected_nodes == 1] = labelled_nodes.max() + 1
    labelled_branches = label(just_branches)

    nodes_with_branch_starting_coords = find_branches_for_nodes(
        network_array_representation=connected_nodes,
        labelled_nodes=labelled_nodes,
        labelled_branches=labelled_branches,
    )

    # If there is only one node, then there is no need to connect the nodes since there is nothing to
    # connect it to. Return the original connected_nodes instead.
    if len(nodes_with_branch_starting_coords) <= 1:
        self.connected_nodes = connected_nodes
        return self.connected_nodes

    assert self.whole_skel_graph is not None, "Whole skeleton graph is not defined."  # for type safety
    shortest_node_dists, shortest_dists_branch_idxs, _shortest_dist_coords = calculate_shortest_branch_distances(
        nodes_with_branch_starting_coords=nodes_with_branch_starting_coords,
        whole_skeleton_graph=self.whole_skel_graph,
    )

    # Matches is an Nx2 numpy array of indexes of the best matching nodes.
    # Eg: np.array([[1, 0], [2, 3]]) means that the best matching nodes are
    # node 1 and node 0, and node 2 and node 3.
    matches: npt.NDArray[np.int32] = self.best_matches(shortest_node_dists, max_weight_matching=False)

    # Connect the nodes by their best matches, using the shortest distances between their branch starts.
    connected_nodes = connect_best_matches(
        network_array_representation=connected_nodes,
        whole_skeleton_graph=self.whole_skel_graph,
        match_indexes=matches,
        shortest_distances_between_nodes=shortest_node_dists,
        shortest_distances_branch_indexes=shortest_dists_branch_idxs,
        emanating_branch_starts_by_node=nodes_with_branch_starting_coords,
        extend_distance=node_extend_dist,
    )

    self.connected_nodes = connected_nodes
    return self.connected_nodes

coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray staticmethod

Calculate the distance from the centre coordinate to a point along the ordered coordinates.

This differs from the distance traversed along the coordinates. This also averages any common distance values and makes those in the trace before the node index negative.

Parameters:

Name Type Description Default
coords NDArray

Nx2 array of branch coordinates.

required
centre NDArray

A 1x2 array of the centre coordinates to identify a 0 point for the node.

required
pixel_to_nm_scaling float

The pixel to nanometer scaling factor to provide real units, by default 1.

1

Returns:

Type Description
NDArray

A Nx1 array of the distance from the node centre.

Source code in topostats\tracing\nodestats.py
@staticmethod
def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray:
    """
    Calculate the distance from the centre coordinate to a point along the ordered coordinates.

    This differs from the distance traversed along the coordinates. This also averages any common distance
    values and makes those in the trace before the node index negative.

    Parameters
    ----------
    coords : npt.NDArray
        Nx2 array of branch coordinates.
    centre : npt.NDArray
        A 1x2 array of the centre coordinates to identify a 0 point for the node.
    pixel_to_nm_scaling : float, optional
        The pixel to nanometer scaling factor to provide real units, by default 1.

    Returns
    -------
    npt.NDArray
        A Nx1 array of the distance from the node centre.
    """
    diff_coords = coords - centre
    if np.all(coords == centre, axis=1).sum() == 0:  # if centre not in coords, reassign centre
        diff_dists = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2)
        centre = coords[np.argmin(diff_dists)]
    cross_idx = np.argwhere(np.all(coords == centre, axis=1))
    rad_dist = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2)
    rad_dist[0 : cross_idx[0][0]] *= -1
    return rad_dist * pixel_to_nm_scaling
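
To illustrate how a signed radial distance differs from the distance travelled along the branch, here is a small NumPy sketch on an L-shaped branch. It is an independent re-derivation under assumed inputs, not a call into TopoStats; the coordinates and the scaling value are made up for the example.

```python
import numpy as np

coords = np.array([[0, 0], [1, 0], [2, 0], [2, 1], [2, 2]])  # ordered, L-shaped branch
centre = np.array([1, 0])  # node centre, assumed to lie on the branch
pixel_to_nm_scaling = 0.5  # illustrative value

# Straight-line distance of every coordinate from the centre.
radial = np.linalg.norm(coords - centre, axis=1)
centre_idx = int(np.argmin(radial))

# Points before the centre in the ordering get a negative sign.
signed_radial = radial.copy()
signed_radial[:centre_idx] *= -1
signed_radial_nm = signed_radial * pixel_to_nm_scaling

# For comparison: distance travelled along the branch, zeroed at the centre.
steps = np.linalg.norm(np.diff(coords, axis=0), axis=1)
path = np.concatenate([[0.0], np.cumsum(steps)])
signed_path_nm = (path - path[centre_idx]) * pixel_to_nm_scaling

print(signed_radial_nm)
print(signed_path_nm)
```

The two measures agree along the straight part of the branch and diverge after the corner, where the radial distance grows more slowly than the path length.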

create_weighted_graph(matrix: npt.NDArray) -> nx.Graph staticmethod

Create a bipartite graph connecting i <-> j from a square matrix of weights matrix[i, j].

Parameters:

Name Type Description Default
matrix NDArray

Square array of weights between rows and columns.

required

Returns:

Type Description
Graph

Bipartite graph with edge weight i->j matching matrix[i,j].

Source code in topostats\tracing\nodestats.py
@staticmethod
def create_weighted_graph(matrix: npt.NDArray) -> nx.Graph:
    """
    Create a bipartite graph connecting i <-> j from a square matrix of weights matrix[i, j].

    Parameters
    ----------
    matrix : npt.NDArray
        Square array of weights between rows and columns.

    Returns
    -------
    nx.Graph
        Bipartite graph with edge weight i->j matching matrix[i,j].
    """
    n = len(matrix)
    G = nx.Graph()
    for i in range(n):
        for j in range(i + 1, n):
            G.add_edge(i, j, weight=matrix[i, j])
    return G

cross_confidence(pair_combinations: list) -> float staticmethod

Obtain the average confidence of the combinations using a reciprocal function.

Parameters:

Name Type Description Default
pair_combinations list

List of length 2 combinations of FWHM values.

required

Returns:

Type Description
float

The average crossing confidence.

Source code in topostats\tracing\nodestats.py
@staticmethod
def cross_confidence(pair_combinations: list) -> float:
    """
    Obtain the average confidence of the combinations using a reciprocal function.

    Parameters
    ----------
    pair_combinations : list
        List of length 2 combinations of FWHM values.

    Returns
    -------
    float
        The average crossing confidence.
    """
    c = 0
    for pair in pair_combinations:
        c += nodeStats.recip(pair)
    return c / len(pair_combinations)

fill_holes(mask: npt.NDArray) -> npt.NDArray staticmethod

Fill all holes within a binary mask.

Parameters:

Name Type Description Default
mask NDArray

Binary array of object.

required

Returns:

Type Description
NDArray

Binary array of object with any interior holes filled in.

Source code in topostats\tracing\nodestats.py
@staticmethod
def fill_holes(mask: npt.NDArray) -> npt.NDArray:
    """
    Fill all holes within a binary mask.

    Parameters
    ----------
    mask : npt.NDArray
        Binary array of object.

    Returns
    -------
    npt.NDArray
        Binary array of object with any interior holes filled in.
    """
    inv_mask = np.where(mask != 0, 0, 1)
    lbl_inv = label(inv_mask, connectivity=1)
    idxs, counts = np.unique(lbl_inv, return_counts=True)
    max_idx = idxs[np.argmax(counts)]
    return np.where(lbl_inv != max_idx, 1, 0)
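
The idea behind hole filling is that the largest connected region of background is treated as the true outside, and every other background region is a hole. The sketch below shows this without external dependencies by swapping the `label(..., connectivity=1)` call for a hand-written breadth-first flood fill over 4-connected neighbours; the function name `fill_holes_sketch` is hypothetical and the code is an illustration rather than the library routine.

```python
from collections import deque

import numpy as np


def fill_holes_sketch(mask):
    """Fill every background region except the largest one (4-connectivity)."""
    background = mask == 0
    labels = np.zeros(mask.shape, dtype=int)
    sizes = {}
    current = 0
    for seed in zip(*np.nonzero(background)):
        if labels[seed]:
            continue  # already reached from an earlier seed
        current += 1
        labels[seed] = current
        queue = deque([seed])
        size = 0
        while queue:
            r, c = queue.popleft()
            size += 1
            for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                inside = 0 <= nr < mask.shape[0] and 0 <= nc < mask.shape[1]
                if inside and background[nr, nc] and not labels[nr, nc]:
                    labels[nr, nc] = current
                    queue.append((nr, nc))
        sizes[current] = size
    if not sizes:  # no background at all, nothing to fill
        return mask.copy()
    largest = max(sizes, key=sizes.get)
    return np.where(labels != largest, 1, 0)


ring = np.array(
    [
        [0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 1, 1, 0],
        [0, 0, 0, 0, 0],
    ]
)
filled = fill_holes_sketch(ring)
print(filled)
```

Using 4-connectivity for the background matters: with 8-connectivity a hole could leak to the outside through a diagonal gap in a one-pixel-wide outline.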

find_branch_starts(reduced_node_image: npt.NDArray) -> npt.NDArray staticmethod

Find the coordinates where the branches connect to the node region through binary dilation of the node.

Parameters:

Name Type Description Default
reduced_node_image NDArray

A 2D numpy array containing a single node region (=3) and its connected branches (=1).

required

Returns:

Type Description
NDArray

Coordinate array of pixels next to crossing points (=3 in input).

Source code in topostats\tracing\nodestats.py
@staticmethod
def find_branch_starts(reduced_node_image: npt.NDArray) -> npt.NDArray:
    """
    Find the coordinates where the branches connect to the node region through binary dilation of the node.

    Parameters
    ----------
    reduced_node_image : npt.NDArray
        A 2D numpy array containing a single node region (=3) and its connected branches (=1).

    Returns
    -------
    npt.NDArray
        Coordinate array of pixels next to crossing points (=3 in input).
    """
    node = np.where(reduced_node_image == 3, 1, 0)
    nodeless = np.where(reduced_node_image == 1, 1, 0)
    thick_node = binary_dilation(node, structure=np.ones((3, 3)))

    return np.argwhere(thick_node * nodeless == 1)
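
The dilation step can be pictured as growing the node by one pixel in every direction and keeping the branch pixels that the grown node now overlaps. The sketch below reproduces that on a small cross-shaped example, assuming a hand-written 3x3 dilation (`dilate_3x3`, a hypothetical helper) in place of `binary_dilation`.

```python
import numpy as np


def dilate_3x3(binary):
    """Dilate a 2D boolean array with a full 3x3 structuring element."""
    binary = np.asarray(binary, dtype=bool)
    padded = np.pad(binary, 1)  # constant padding, i.e. False around the border
    rows, cols = binary.shape
    out = np.zeros_like(binary)
    for dr in range(3):
        for dc in range(3):
            out |= padded[dr : dr + rows, dc : dc + cols]
    return out


# A single node pixel (3) with four branches (1) leaving it.
reduced_node_image = np.zeros((5, 5), dtype=int)
reduced_node_image[2, :] = 1
reduced_node_image[:, 2] = 1
reduced_node_image[2, 2] = 3

thick_node = dilate_3x3(reduced_node_image == 3)
branch_starts = np.argwhere(thick_node & (reduced_node_image == 1))
print(branch_starts)
```

Each returned row is a (row, column) coordinate of a branch pixel touching the node, so this example yields one start per arm of the cross.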

gaussian(x: npt.NDArray, h: float, mean: float, sigma: float) staticmethod

Apply the gaussian function.

Parameters:

Name Type Description Default
x NDArray

X values to be passed into the gaussian.

required
h float

The peak height of the gaussian.

required
mean float

The mean of the x values.

required
sigma float

The standard deviation of the gaussian.

required

Returns:

Type Description
NDArray

The y-values of the gaussian performed on the x values.

Source code in topostats\tracing\nodestats.py
@staticmethod
def gaussian(x: npt.NDArray, h: float, mean: float, sigma: float):
    """
    Apply the gaussian function.

    Parameters
    ----------
    x : npt.NDArray
        X values to be passed into the gaussian.
    h : float
        The peak height of the gaussian.
    mean : float
        The mean of the x values.
    sigma : float
        The standard deviation of the gaussian.

    Returns
    -------
    npt.NDArray
        The y-values of the gaussian performed on the x values.
    """
    return h * np.exp(-((x - mean) ** 2) / (2 * sigma**2))
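
For a curve of this form the full width at half maximum follows directly from sigma as FWHM = 2 * sqrt(2 * ln 2) * sigma. The short check below, written with only the standard library and made-up parameter values, confirms that the curve falls to half its peak height at mean ± FWHM / 2.

```python
import math


def gaussian(x, h, mean, sigma):
    """Scalar version of the gaussian expression shown above."""
    return h * math.exp(-((x - mean) ** 2) / (2 * sigma**2))


h, mean, sigma = 3.0, 10.0, 1.5  # illustrative values
fwhm = 2 * math.sqrt(2 * math.log(2)) * sigma

left_value = gaussian(mean - fwhm / 2, h, mean, sigma)
right_value = gaussian(mean + fwhm / 2, h, mean, sigma)
print(fwhm, left_value, right_value)
```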

get_node_stats() -> tuple[dict, dict]

Run the workflow to obtain the node statistics.

.. code-block:: RST

node_dict key structure:  <grain_number>
                            └-> <node_number>
                                |-> 'error'
                                |-> 'node_coords'
                                └-> 'branch_stats'
                                    └-> <branch_number>
                                        |-> 'ordered_coords'
                                        |-> 'heights'
                                        |-> 'gaussian_fit'
                                        |-> 'fwhm'
                                        └-> 'angles'

image_dict key structure:  'nodes'
                                <node_number>
                                    |-> 'node_area_skeleton'
                                    |-> 'node_branch_mask'
                                    └-> 'node_avg_mask'
                            'grain'
                                |-> 'grain_image'
                                |-> 'grain_mask'
                                └-> 'grain_skeleton'

Returns:

Type Description
tuple[dict, dict]

Dictionaries of the node_information and images.

Source code in topostats\tracing\nodestats.py
def get_node_stats(self) -> tuple[dict, dict]:
    """
    Run the workflow to obtain the node statistics.

    .. code-block:: RST

        node_dict key structure:  <grain_number>
                                    └-> <node_number>
                                        |-> 'error'
                                        |-> 'node_coords'
                                        └-> 'branch_stats'
                                            └-> <branch_number>
                                                |-> 'ordered_coords'
                                                |-> 'heights'
                                                |-> 'gaussian_fit'
                                                |-> 'fwhm'
                                                └-> 'angles'

        image_dict key structure:  'nodes'
                                        <node_number>
                                            |-> 'node_area_skeleton'
                                            |-> 'node_branch_mask'
                                            └-> 'node_avg_mask'
                                    'grain'
                                        |-> 'grain_image'
                                        |-> 'grain_mask'
                                        └-> 'grain_skeleton'

    Returns
    -------
    tuple[dict, dict]
        Dictionaries of the node_information and images.
    """
    LOGGER.debug(f"Node Stats - Processing Grain: {self.n_grain}")
    self.conv_skelly = convolve_skeleton(self.skeleton)
    if len(self.conv_skelly[self.conv_skelly == 3]) != 0:  # check if any nodes
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} contains crossings.")
        # convolve to see crossing and end points
        # self.conv_skelly = self.tidy_branches(self.conv_skelly, self.image)
        # reset skeleton var as tidy branches may have modified it
        self.skeleton = np.where(self.conv_skelly != 0, 1, 0)
        self.image_dict["grain"]["grain_skeleton"] = self.skeleton
        # get graph of skeleton
        self.whole_skel_graph = self.skeleton_image_to_graph(self.skeleton)
        # connect the close nodes
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} connecting close nodes.")
        self.connected_nodes = self.connect_close_nodes(self.conv_skelly, node_width=self.node_joining_length)
        # connect the odd-branch nodes
        self.connected_nodes = self.connect_extended_nodes_nearest(
            self.connected_nodes, node_extend_dist=self.node_extend_dist
        )
        # obtain a mask of node centers and their count
        self.node_centre_mask = self.highlight_node_centres(self.connected_nodes)
        # Begin the hefty crossing analysis
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} analysing found crossings.")
        self.analyse_nodes(max_branch_length=self.branch_pairing_length)
        self.compile_metrics()
    else:
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} has no crossings.")
    return self.node_dicts, self.image_dict
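
As a rough sketch of how the nested node dictionary might be consumed, the snippet below walks a hand-built dictionary that follows the documented key structure and collects one row per branch. The key names `"grain_0"` and `"node_1"`, the numeric values, and the assumption that each `'fwhm'` entry is itself a dictionary with an `"fwhm"` key (as returned by `calculate_fwhm`) are all illustrative and not taken from real output.

```python
# Hand-built placeholder mirroring: <grain_number> -> <node_number> -> 'branch_stats' -> <branch_number>
node_dict = {
    "grain_0": {
        "node_1": {
            "error": False,
            "node_coords": [[5, 5]],
            "branch_stats": {
                0: {"fwhm": {"fwhm": 4.2}},
                1: {"fwhm": {"fwhm": 6.1}},
            },
        }
    }
}

rows = []
for grain_key, nodes in node_dict.items():
    for node_key, node in nodes.items():
        for branch_key, branch in node["branch_stats"].items():
            # One flat record per branch makes the values easy to tabulate later.
            rows.append((grain_key, node_key, branch_key, branch["fwhm"]["fwhm"]))

print(rows)
```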

get_ordered_branches_and_vectors(reduced_node_area: npt.NDArray[np.int32], branch_start_coords: npt.NDArray[np.int32], max_length_px: np.float64) -> tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]] staticmethod

Get ordered branches and vectors for a node.

Branches are ordered so they are no longer just a disordered set of coordinates, and vectors are calculated to represent the general direction tendency of the branch; this allows for alignment matching later on.

Parameters:

Name Type Description Default
reduced_node_area NDArray[int32]

An NxM numpy array of the node in question and the branches connected to it. Node is marked by 3, and branches by 1.

required
branch_start_coords NDArray[int32]

A Px2 numpy array of coordinates representing the start of branches where P is the number of branches.

required
max_length_px float64

The maximum length in pixels to traverse along while ordering.

required

Returns:

Type Description
tuple[list[NDArray[int32]], list[NDArray[int32]]]

A tuple containing a list of ordered branches and a list of vectors.

Source code in topostats\tracing\nodestats.py
@staticmethod
def get_ordered_branches_and_vectors(
    reduced_node_area: npt.NDArray[np.int32],
    branch_start_coords: npt.NDArray[np.int32],
    max_length_px: np.float64,
) -> tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]:
    """
    Get ordered branches and vectors for a node.

    Branches are ordered so they are no longer just a disordered set of coordinates, and vectors are calculated to
    represent the general direction tendency of the branch; this allows for alignment matching later on.

    Parameters
    ----------
    reduced_node_area : npt.NDArray[np.int32]
        An NxM numpy array of the node in question and the branches connected to it.
        Node is marked by 3, and branches by 1.
    branch_start_coords : npt.NDArray[np.int32]
        A Px2 numpy array of coordinates representing the start of branches where P is the number of branches.
    max_length_px : np.float64
        The maximum length in pixels to traverse along while ordering.

    Returns
    -------
    tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]
        A tuple containing a list of ordered branches and a list of vectors.
    """
    ordered_branches = []
    vectors = []
    nodeless = np.where(reduced_node_area == 1, 1, 0)
    for branch_start_coord in branch_start_coords:
        # Order the branch coordinates so they're no longer just a disordered set of coordinates
        ordered_branch = order_branch_from_start(nodeless.copy(), branch_start_coord, max_length=max_length_px)
        ordered_branches.append(ordered_branch)

        # Calculate vector to represent the general direction tendency of the branch (for alignment matching)
        vector = nodeStats.get_vector(ordered_branch, branch_start_coord)
        vectors.append(vector)

    return ordered_branches, vectors

get_vector(coords: npt.NDArray, origin: npt.NDArray) -> npt.NDArray staticmethod

Calculate the normalised vector of the coordinate means in a branch.

Parameters:

Name Type Description Default
coords NDArray

Nx2 array of x, y coordinates.

required
origin NDArray

Length-2 array of an x, y coordinate.

required

Returns:

Type Description
NDArray

Normalised vector from origin to the mean coordinate.

Source code in topostats\tracing\nodestats.py
@staticmethod
def get_vector(coords: npt.NDArray, origin: npt.NDArray) -> npt.NDArray:
    """
    Calculate the normalised vector of the coordinate means in a branch.

    Parameters
    ----------
    coords : npt.NDArray
        Nx2 array of x, y coordinates.
    origin : npt.NDArray
        Length-2 array of an x, y coordinate.

    Returns
    -------
    npt.NDArray
        Normalised vector from origin to the mean coordinate.
    """
    vector = coords.mean(axis=0) - origin
    norm = np.sqrt(vector @ vector)
    return vector if norm == 0 else vector / norm  # normalise vector so length=1
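
A small sketch of the same calculation on made-up branches shows what the direction vector looks like in practice. The function name `unit_vector_to_mean` and the sample coordinates are assumptions for illustration; the coordinates are passed as an Nx2 array so that the mean over axis 0 is a single coordinate pair.

```python
import numpy as np


def unit_vector_to_mean(coords, origin):
    """Unit vector pointing from `origin` to the mean of `coords` (Nx2)."""
    vector = coords.mean(axis=0) - origin
    norm = np.linalg.norm(vector)
    return vector if norm == 0 else vector / norm  # avoid dividing by zero


straight = unit_vector_to_mean(np.array([[2, 3], [2, 4], [2, 5]]), np.array([2, 2]))
diagonal = unit_vector_to_mean(np.array([[1, 1], [2, 2], [3, 3]]), np.array([0, 0]))
degenerate = unit_vector_to_mean(np.array([[0, 0]]), np.array([0, 0]))
print(straight, diagonal, degenerate)
```

Because each vector has unit length, the dot product of two of them is the cosine of the angle between the branches, which is convenient for judging how well two branches line up.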

graph_to_skeleton_image(g: nx.Graph, im_shape: tuple[int]) -> npt.NDArray staticmethod

Convert the skeleton graph back to a binary image.

Parameters:

Name Type Description Default
g Graph

Graph with coordinates as node labels.

required
im_shape tuple[int]

The shape of the image to dump.

required

Returns:

Type Description
NDArray

Skeleton binary image from the graph representation.

Source code in topostats\tracing\nodestats.py
@staticmethod
def graph_to_skeleton_image(g: nx.Graph, im_shape: tuple[int]) -> npt.NDArray:
    """
    Convert the skeleton graph back to a binary image.

    Parameters
    ----------
    g : nx.Graph
        Graph with coordinates as node labels.
    im_shape : tuple[int]
        The shape of the image to dump.

    Returns
    -------
    npt.NDArray
        Skeleton binary image from the graph representation.
    """
    im = np.zeros(im_shape)
    for node in g:
        im[node] = 1

    return im

highlight_node_centres(mask: npt.NDArray) -> npt.NDArray

Calculate the node centres based on height and re-plot on the mask.

Parameters:

Name Type Description Default
mask NDArray

2D array with background = 0, skeleton = 1, endpoints = 2, node_centres = 3.

required

Returns:

Type Description
NDArray

2D array with the highest node coordinate for each node labeled as 3.

Source code in topostats\tracing\nodestats.py
def highlight_node_centres(self, mask: npt.NDArray) -> npt.NDArray:
    """
    Calculate the node centres based on height and re-plot on the mask.

    Parameters
    ----------
    mask : npt.NDArray
        2D array with background = 0, skeleton = 1, endpoints = 2, node_centres = 3.

    Returns
    -------
    npt.NDArray
        2D array with the highest node coordinate for each node labeled as 3.
    """
    small_node_mask = mask.copy()
    small_node_mask[mask == 3] = 1  # remap nodes to skeleton
    big_nodes = mask.copy()
    big_nodes = np.where(mask == 3, 1, 0)  # remove non-nodes & set nodes to 1
    big_node_mask = label(big_nodes)

    for i in np.delete(np.unique(big_node_mask), 0):  # get node indices
        centre = np.unravel_index((self.image * (big_node_mask == i).astype(int)).argmax(), self.image.shape)
        small_node_mask[centre] = 3

    return small_node_mask
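
The centre of each node is taken as its highest pixel, found by masking the height image with the node region and converting the flat `argmax` index back into a 2D coordinate. The sketch below shows that step for one region, using a made-up 3x3 height image.

```python
import numpy as np

# Hypothetical heights; the tallest pixel overall (3.0) lies outside the node region.
image = np.array(
    [
        [0.1, 0.2, 0.1],
        [0.3, 2.5, 1.9],
        [0.2, 3.0, 0.4],
    ]
)
node_region = np.array(
    [
        [0, 0, 0],
        [0, 1, 1],
        [0, 0, 0],
    ],
    dtype=bool,
)

# Pixels outside the region are zeroed, so only heights inside it compete.
masked = image * node_region.astype(int)
centre = np.unravel_index(masked.argmax(), image.shape)
print(centre)
```

Note that zeroing by multiplication only isolates the region when the heights inside it are positive; if every height in the region were negative, a zeroed pixel outside it would win the `argmax`.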

interpolate_between_yvalue(x: npt.NDArray, y: npt.NDArray, yvalue: float) -> float staticmethod

Calculate the x value between the two points either side of yvalue in y.

Parameters:

Name Type Description Default
x NDArray

Array of x values, the same length as y.

required
y NDArray

Array of y values, the same length as x.

required
yvalue float

A value within the bounds of the y array.

required

Returns:

Type Description
float

The linearly interpolated x value at which y crosses yvalue, or 0 if no crossing is found.

Source code in topostats\tracing\nodestats.py
@staticmethod
def interpolate_between_yvalue(x: npt.NDArray, y: npt.NDArray, yvalue: float) -> float:
    """
    Calculate the x value between the two points either side of yvalue in y.

    Parameters
    ----------
    x : npt.NDArray
        Array of x values, the same length as y.
    y : npt.NDArray
        Array of y values, the same length as x.
    yvalue : float
        A value within the bounds of the y array.

    Returns
    -------
    float
        The linearly interpolated x value at which y crosses yvalue, or 0 if no crossing is found.
    """
    for i in range(len(y) - 1):
        if y[i] <= yvalue <= y[i + 1] or y[i + 1] <= yvalue <= y[i]:  # if points cross through the hm value
            return nodeStats.lin_interp([x[i], y[i]], [x[i + 1], y[i + 1]], yvalue=yvalue)
    return 0

join_matching_branches_through_node(pairs: npt.NDArray[np.int32], ordered_branches: list[npt.NDArray[np.int32]], reduced_skeleton_graph: nx.classes.graph.Graph, image: npt.NDArray[np.number], average_trace_advised: bool, node_coords: tuple[np.int32, np.int32], filename: str) -> tuple[dict[int, MatchedBranch], dict[int, dict[str, npt.NDArray[np.bool_]]]] staticmethod

Join branches that are matched through a node.

Parameters:

Name Type Description Default
pairs NDArray[int32]

Nx2 numpy array of pairs of branches that are matched through a node.

required
ordered_branches list[NDArray[int32]]

List of numpy arrays of ordered branch coordinates.

required
reduced_skeleton_graph Graph

Graph representation of the skeleton.

required
image NDArray[number]

The full image of the grain.

required
average_trace_advised bool

Flag to determine whether to use the average trace.

required
node_coords tuple[int32, int32]

The node coordinates.

required
filename str

The filename of the image.

required

Returns:

Name Type Description
matched_branches dict[int, dict[str, NDArray[number]]]

Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys: - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch. - "heights" : npt.NDArray[np.number]. Heights of the branches. - "distances" : npt.NDArray[np.number]. Distances of the branch coordinates. - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.

masked_image dict[int, dict[str, NDArray[bool_]]]

Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys: - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.

Source code in topostats\tracing\nodestats.py
@staticmethod
def join_matching_branches_through_node(
    pairs: npt.NDArray[np.int32],
    ordered_branches: list[npt.NDArray[np.int32]],
    reduced_skeleton_graph: nx.classes.graph.Graph,
    image: npt.NDArray[np.number],
    average_trace_advised: bool,
    node_coords: tuple[np.int32, np.int32],
    filename: str,
) -> tuple[dict[int, MatchedBranch], dict[int, dict[str, npt.NDArray[np.bool_]]]]:
    """
    Join branches that are matched through a node.

    Parameters
    ----------
    pairs : npt.NDArray[np.int32]
        Nx2 numpy array of pairs of branches that are matched through a node.
    ordered_branches : list[npt.NDArray[np.int32]]
        List of numpy arrays of ordered branch coordinates.
    reduced_skeleton_graph : nx.classes.graph.Graph
        Graph representation of the skeleton.
    image : npt.NDArray[np.number]
        The full image of the grain.
    average_trace_advised : bool
        Flag to determine whether to use the average trace.
    node_coords : tuple[np.int32, np.int32]
        The node coordinates.
    filename : str
        The filename of the image.

    Returns
    -------
    matched_branches: dict[int, dict[str, npt.NDArray[np.number]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch.
        - "heights" : npt.NDArray[np.number]. Heights of the branches.
        - "distances" : npt.NDArray[np.number]. Distances of the branch coordinates.
        - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
    """
    matched_branches: dict[int, MatchedBranch] = {}
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] = (
        {}
    )  # Masked image is a dictionary of pairs of branches
    for i, (branch_1, branch_2) in enumerate(pairs):
        matched_branches[i] = MatchedBranch(
            ordered_coords=np.array([], dtype=np.int32),
            heights=np.array([], dtype=np.float64),
            distances=np.array([], dtype=np.float64),
            fwhm={},
            angles=None,
        )
        masked_image[i] = {}
        # find close ends by rearranging branch coords
        branch_1_coords, branch_2_coords = nodeStats.order_branches(
            ordered_branches[branch_1], ordered_branches[branch_2]
        )
        # Get graphical shortest path between branch ends on the skeleton
        crossing = nx.shortest_path(
            reduced_skeleton_graph,
            source=tuple(branch_1_coords[-1]),
            target=tuple(branch_2_coords[0]),
            weight="weight",
        )
        crossing = np.asarray(crossing[1:-1])  # remove start and end points & turn into array
        # Branch coords and crossing
        if crossing.shape == (0,):
            branch_coords = np.vstack([branch_1_coords, branch_2_coords])
        else:
            branch_coords = np.vstack([branch_1_coords, crossing, branch_2_coords])
        # make images of single branch joined and multiple branches joined
        single_branch_img: npt.NDArray[np.bool_] = np.zeros_like(image).astype(bool)
        single_branch_img[branch_coords[:, 0], branch_coords[:, 1]] = True
        single_branch_coords = order_branch(single_branch_img.astype(bool), [0, 0])
        # calc image-wide coords
        matched_branches[i]["ordered_coords"] = single_branch_coords
        # get heights and trace distance of branch
        try:
            assert average_trace_advised
            distances, heights, mask, _ = nodeStats.average_height_trace(
                image, single_branch_img, single_branch_coords, [node_coords[0], node_coords[1]]
            )
            masked_image[i]["avg_mask"] = mask
        except (
            AssertionError,
            IndexError,
        ) as e:  # Assertion - avg trace not advised, Index - wiggy branches
            LOGGER.debug(f"[{filename}] : avg trace failed with {e}, single trace only.")
            average_trace_advised = False
            distances = nodeStats.coord_dist_rad(single_branch_coords, np.array([node_coords[0], node_coords[1]]))
            # distances = self.coord_dist(single_branch_coords)
            zero_dist = distances[
                np.argmin(
                    np.sqrt(
                        (single_branch_coords[:, 0] - node_coords[0]) ** 2
                        + (single_branch_coords[:, 1] - node_coords[1]) ** 2
                    )
                )
            ]
            heights = image[single_branch_coords[:, 0], single_branch_coords[:, 1]]  # self.hess
            distances = distances - zero_dist
            distances, heights = nodeStats.average_uniques(
                distances, heights
            )  # needs to be paired with coord_dist_rad
        matched_branches[i]["heights"] = heights
        matched_branches[i]["distances"] = distances
        # identify over/under
        matched_branches[i]["fwhm"] = nodeStats.calculate_fwhm(heights, distances)

    return matched_branches, masked_image
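
The fallback path above re-zeroes the distances so that zero falls at the branch pixel nearest the node. The following is a minimal NumPy sketch of that step only. The coordinates are made up, and the cumulative path length is a stand-in for `coord_dist_rad`, whose exact behaviour is not shown here.

```python
import numpy as np

# Hypothetical ordered branch coordinates (row, col) and a node position.
single_branch_coords = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
node_coords = np.array([2, 2])

# Stand-in distance along the branch (illustrative; NOT coord_dist_rad).
steps = np.linalg.norm(np.diff(single_branch_coords, axis=0), axis=1)
distances = np.concatenate([[0.0], np.cumsum(steps)])

# Index of the branch pixel closest to the node (Euclidean distance).
nearest = np.argmin(np.linalg.norm(single_branch_coords - node_coords, axis=1))

# Shift the distances so the nearest pixel sits at zero.
distances = distances - distances[nearest]
print(distances)
```

Pixels before the node end up with negative distances and pixels after it with positive ones, which is what lets the height profile be centred on the crossing.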

keep_biggest_object(mask: npt.NDArray) -> npt.NDArray staticmethod

Retain the largest object in a binary mask.

Parameters:

Name Type Description Default
mask NDArray

Binary mask.

required

Returns:

Type Description
NDArray

A binary mask with only one object.

Source code in topostats\tracing\nodestats.py
@staticmethod
def keep_biggest_object(mask: npt.NDArray) -> npt.NDArray:
    """
    Retain the largest object in a binary mask.

    Parameters
    ----------
    mask : npt.NDArray
        Binary mask.

    Returns
    -------
    npt.NDArray
        A binary mask with only one object.
    """
    labelled_mask = label(mask)
    # count pixels per label; counting the binary mask would only ever give labels 0 and 1
    idxs, counts = np.unique(labelled_mask, return_counts=True)
    try:
        max_idx = idxs[np.argmax(counts[1:]) + 1]
        return np.where(labelled_mask == max_idx, 1, 0)
    except ValueError as e:
        LOGGER.debug(f"{e}: mask is empty.")
        return mask
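
The core of this method is counting pixels per label and keeping the largest non-background one. A minimal sketch of that step, assuming the input has already been labelled (for example by `skimage.measure.label`). The helper name `largest_label_mask` is hypothetical and `np.bincount` is swapped in for `np.unique` for brevity.

```python
import numpy as np


def largest_label_mask(labelled: np.ndarray) -> np.ndarray:
    """Return a 0/1 mask of the largest non-background label (hypothetical helper)."""
    counts = np.bincount(labelled.ravel())  # counts[i] = number of pixels with label i
    if counts.size <= 1:  # only background (label 0) is present
        return np.zeros_like(labelled)
    biggest = np.argmax(counts[1:]) + 1  # skip label 0, then undo the offset
    return np.where(labelled == biggest, 1, 0)


# Three labelled objects with 3, 2 and 4 pixels respectively.
labelled = np.array(
    [
        [1, 1, 0, 2],
        [1, 0, 0, 2],
        [0, 0, 0, 0],
        [3, 3, 3, 3],
    ]
)
result = largest_label_mask(labelled)
print(result)
```

Only the bottom row (label 3, four pixels) survives in `result`.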

lin_interp(point_1: list, point_2: list, xvalue: float | None = None, yvalue: float | None = None) -> float staticmethod

Linearly interpolate between two points by finding the equation of the line through them and substituting a value.

Parameters:

Name Type Description Default
point_1 list

List of an x and y coordinate.

required
point_2 list

List of an x and y coordinate.

required
xvalue Union[float, None]

Value at which to interpolate to get a y coordinate, by default None.

None
yvalue Union[float, None]

Value at which to interpolate to get an x coordinate, by default None.

None

Returns:

Type Description
float

Value of x or y linear interpolation.

Source code in topostats\tracing\nodestats.py
@staticmethod
def lin_interp(point_1: list, point_2: list, xvalue: float | None = None, yvalue: float | None = None) -> float:
    """
    Linearly interpolate between two points by finding the equation of the line through them and substituting a value.

    Parameters
    ----------
    point_1 : list
        List of an x and y coordinate.
    point_2 : list
        List of an x and y coordinate.
    xvalue : Union[float, None], optional
        Value at which to interpolate to get a y coordinate, by default None.
    yvalue : Union[float, None], optional
        Value at which to interpolate to get an x coordinate, by default None.

    Returns
    -------
    float
        Value of x or y linear interpolation.
    """
    m = (point_1[1] - point_2[1]) / (point_1[0] - point_2[0])
    c = point_1[1] - (m * point_1[0])
    if xvalue is not None:
        return m * xvalue + c  # interp_y
    if yvalue is not None:
        return (yvalue - c) / m  # interp_x
    raise ValueError
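
A short worked example of the interpolation, using a standalone copy of the function and made-up points. For the points (0, 0) and (2, 4) the line is y = 2x, so x = 1 gives y = 2 and y = 2 gives x = 1.

```python
def lin_interp(point_1, point_2, xvalue=None, yvalue=None):
    # gradient and intercept of the line through the two points
    m = (point_1[1] - point_2[1]) / (point_1[0] - point_2[0])
    c = point_1[1] - (m * point_1[0])
    if xvalue is not None:
        return m * xvalue + c  # interpolated y
    if yvalue is not None:
        return (yvalue - c) / m  # interpolated x
    raise ValueError


y_at_1 = lin_interp([0, 0], [2, 4], xvalue=1)
x_at_2 = lin_interp([0, 0], [2, 4], yvalue=2)
print(y_at_1, x_at_2)
```

With plain Python numbers, two points sharing an x coordinate (a vertical line) raise `ZeroDivisionError`, as does asking for an x value on a horizontal line, so callers may want to guard against those cases.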

minimum_crossing_confs(node_dict: dict) -> None | float staticmethod

Return the minimum crossing confidence of all crossings in the molecule.

Parameters:

Name Type Description Default
node_dict dict

A dictionary containing node statistics and information.

required

Returns:

Type Description
Union[None, float]

The minimum confidence value, or None if no crossing has a confidence.

Source code in topostats\tracing\nodestats.py
@staticmethod
def minimum_crossing_confs(node_dict: dict) -> None | float:
    """
    Return the minimum crossing confidence of all crossings in the molecule.

    Parameters
    ----------
    node_dict : dict
        A dictionary containing node statistics and information.

    Returns
    -------
    Union[None, float]
        The minimum confidence value, or None if no crossing has a confidence.
    """
    # gather the confidence of every crossing that has one
    confidences = [values["confidence"] for values in node_dict.values() if values["confidence"] is not None]
    # no valid confidences means there is no minimum to report
    return min(confidences) if confidences else None
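
A small usage sketch with a made-up `node_dict`. The keys and values are hypothetical and only the `"confidence"` entry matters here. The function body is a standalone equivalent of the method above.

```python
def minimum_crossing_confs(node_dict):
    # keep only crossings that actually have a confidence value
    confidences = [v["confidence"] for v in node_dict.values() if v["confidence"] is not None]
    return min(confidences) if confidences else None


node_dict = {
    "node_1": {"confidence": 0.8},
    "node_2": {"confidence": None},  # e.g. a crossing that could not be scored
    "node_3": {"confidence": 0.35},
}

lowest = minimum_crossing_confs(node_dict)
nothing = minimum_crossing_confs({"node_1": {"confidence": None}})
print(lowest, nothing)
```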

only_centre_branches(node_image: npt.NDArray, node_coordinate: npt.NDArray) -> npt.NDArray[np.int32] staticmethod

Remove all branches not connected to the current node.

Parameters:

Name Type Description Default
node_image NDArray

An image of the skeletonised area surrounding the node where the background = 0, skeleton = 1, termini = 2, nodes = 3.

required
node_coordinate NDArray

2x1 coordinate describing the position of a node.

required

Returns:

Type Description
NDArray[int32]

The initial node image but only with skeletal branches connected to the middle node.

Source code in topostats\tracing\nodestats.py
@staticmethod
def only_centre_branches(node_image: npt.NDArray, node_coordinate: npt.NDArray) -> npt.NDArray[np.int32]:
    """
    Remove all branches not connected to the current node.

    Parameters
    ----------
    node_image : npt.NDArray
        An image of the skeletonised area surrounding the node where
        the background = 0, skeleton = 1, termini = 2, nodes = 3.
    node_coordinate : npt.NDArray
        2x1 coordinate describing the position of a node.

    Returns
    -------
    npt.NDArray[np.int32]
        The initial node image but only with skeletal branches
        connected to the middle node.
    """
    node_image_cp = node_image.copy()

    # get node-only image
    nodes = node_image_cp.copy()
    nodes[nodes != 3] = 0
    labeled_nodes = label(nodes)

    # find which cluster is closest to the centre
    node_coords = np.argwhere(nodes == 3)
    min_coords = node_coords[abs(node_coords - node_coordinate).sum(axis=1).argmin()]
    centre_idx = labeled_nodes[min_coords[0], min_coords[1]]

    # get nodeless image
    nodeless = node_image_cp.copy()
    nodeless = np.where(
        (node_image == 1) | (node_image == 2), 1, 0
    )  # if termini, need this in the labeled branches too
    nodeless[labeled_nodes == centre_idx] = 1  # return centre node
    labeled_nodeless = label(nodeless)

    # apply to return image
    for i in range(1, labeled_nodeless.max() + 1):
        if (node_image_cp[labeled_nodeless == i] == 3).any():
            node_image_cp[labeled_nodeless != i] = 0
            break

    # remove small area around other nodes
    labeled_nodes[labeled_nodes == centre_idx] = 0
    non_central_node_coords = np.argwhere(labeled_nodes != 0)
    for coord in non_central_node_coords:
        for j, coord_val in enumerate(coord):
            if coord_val - 1 < 0:
                coord[j] = 1
            if coord_val + 2 > node_image_cp.shape[j]:
                coord[j] = node_image_cp.shape[j] - 2
        node_image_cp[coord[0] - 1 : coord[0] + 2, coord[1] - 1 : coord[1] + 2] = 0

    return node_image_cp

order_branches(branch1: npt.NDArray, branch2: npt.NDArray) -> tuple staticmethod

Order the two ordered arrays based on the closest endpoint coordinates.

Parameters:

Name Type Description Default
branch1 NDArray

An Nx2 array describing coordinates.

required
branch2 NDArray

An Nx2 array describing coordinates.

required

Returns:

Type Description
tuple

A tuple of the two coordinate arrays, ordered so that one follows on from the other.

Source code in topostats\tracing\nodestats.py
@staticmethod
def order_branches(branch1: npt.NDArray, branch2: npt.NDArray) -> tuple:
    """
    Order the two ordered arrays based on the closest endpoint coordinates.

    Parameters
    ----------
    branch1 : npt.NDArray
        An Nx2 array describing coordinates.
    branch2 : npt.NDArray
        An Nx2 array describing coordinates.

    Returns
    -------
    tuple
        A tuple of the two coordinate arrays, ordered so that one follows on from the other.
    """
    endpoints1 = np.asarray([branch1[0], branch1[-1]])
    endpoints2 = np.asarray([branch2[0], branch2[-1]])
    sum1 = abs(endpoints1 - endpoints2).sum(axis=1)
    sum2 = abs(endpoints1[::-1] - endpoints2).sum(axis=1)
    if sum1.min() < sum2.min():
        if np.argmin(sum1) == 0:
            return branch1[::-1], branch2
        return branch1, branch2[::-1]
    if np.argmin(sum2) == 0:
        return branch1, branch2
    return branch1[::-1], branch2[::-1]
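
A worked example with two made-up collinear branches, using a standalone copy of the function. Here `branch2` runs towards `branch1`, so it is the one that gets reversed.

```python
import numpy as np


def order_branches(branch1, branch2):
    # same logic as the method above, written as a free function
    endpoints1 = np.asarray([branch1[0], branch1[-1]])
    endpoints2 = np.asarray([branch2[0], branch2[-1]])
    sum1 = abs(endpoints1 - endpoints2).sum(axis=1)  # start-start, end-end
    sum2 = abs(endpoints1[::-1] - endpoints2).sum(axis=1)  # end-start, start-end
    if sum1.min() < sum2.min():
        if np.argmin(sum1) == 0:
            return branch1[::-1], branch2
        return branch1, branch2[::-1]
    if np.argmin(sum2) == 0:
        return branch1, branch2
    return branch1[::-1], branch2[::-1]


branch1 = np.array([[0, 0], [0, 1], [0, 2]])
branch2 = np.array([[0, 5], [0, 4], [0, 3]])  # ends next to the end of branch1

first, second = order_branches(branch1, branch2)
print(first.tolist(), second.tolist())
```

After ordering, the last coordinate of `first` sits next to the first coordinate of `second`, so the two can be stacked into one continuous trace.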

pair_angles(angles: npt.NDArray) -> list staticmethod

Pair angles that are 180 degrees to each other, removing each pair before selecting the next.

Parameters:

Name Type Description Default
angles NDArray

Square array (i,j) of angles between i and j.

required

Returns:

Type Description
list

The paired indexes, one pair per row.

Source code in topostats\tracing\nodestats.py
@staticmethod
def pair_angles(angles: npt.NDArray) -> list:
    """
    Pair angles that are 180 degrees to each other, removing each pair before selecting the next.

    Parameters
    ----------
    angles : npt.NDArray
         Square array (i,j) of angles between i and j.

    Returns
    -------
    list
         The paired indexes, one pair per row.
    """
    angles_cp = angles.copy()
    pairs = []
    for _ in range(int(angles.shape[0] / 2)):
        pair = np.unravel_index(np.argmax(angles_cp), angles.shape)
        pairs.append(pair)  # add to list
        angles_cp[[pair]] = 0  # set rows 0 to avoid picking again
        angles_cp[:, [pair]] = 0  # set cols 0 to avoid picking again

    return np.asarray(pairs)
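
The pairing is greedy: take the largest remaining angle, record its two indexes, then zero their rows and columns so neither can be picked again. A sketch with a made-up symmetric 4x4 angle matrix, written as a free function and with explicit row and column indexing for readability:

```python
import numpy as np


def pair_angles(angles):
    angles_cp = angles.copy()
    pairs = []
    for _ in range(angles.shape[0] // 2):
        # (row, col) of the largest remaining angle
        pair = np.unravel_index(np.argmax(angles_cp), angles.shape)
        pairs.append(pair)
        angles_cp[list(pair), :] = 0  # zero both rows
        angles_cp[:, list(pair)] = 0  # zero both columns
    return np.asarray(pairs)


# Hypothetical angles (degrees) between four branch vectors.
angles = np.array(
    [
        [0, 170, 90, 80],
        [170, 0, 85, 95],
        [90, 85, 0, 175],
        [80, 95, 175, 0],
    ]
)
pairs = pair_angles(angles)
print(pairs.tolist())
```

Branches 2 and 3 are paired first (175 degrees), which leaves branches 0 and 1 (170 degrees) as the second pair.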

pair_vectors(vectors: npt.NDArray) -> npt.NDArray[np.int32] staticmethod

Take a list of vectors and pair them based on the angle between them.

Parameters:

Name Type Description Default
vectors NDArray

Array of 2x1 vectors to be paired.

required

Returns:

Type Description
NDArray

An array of the matching pair indices.

Source code in topostats\tracing\nodestats.py
@staticmethod
def pair_vectors(vectors: npt.NDArray) -> npt.NDArray[np.int32]:
    """
    Take a list of vectors and pair them based on the angle between them.

    Parameters
    ----------
    vectors : npt.NDArray
        Array of 2x1 vectors to be paired.

    Returns
    -------
    npt.NDArray
        An array of the matching pair indices.
    """
    # calculate cosine of angle
    angles = nodeStats.calc_angles(vectors)
    # match angles
    return nodeStats.best_matches(angles)

recip(vals: list) -> float staticmethod

Compute 1 - (min / max) of the two values provided.

Parameters:

Name Type Description Default
vals list

List of 2 values.

required

Returns:

Type Description
float

Result of applying the 1-(min / max) function to the two values.

Source code in topostats\tracing\nodestats.py
@staticmethod
def recip(vals: list) -> float:
    """
    Compute 1 - (min / max) of the two values provided.

    Parameters
    ----------
    vals : list
        List of 2 values.

    Returns
    -------
    float
        Result of applying the 1-(min / max) function to the two values.
    """
    try:
        if min(vals) == 0:  # means fwhm variation hasn't worked
            return 0
        return 1 - min(vals) / max(vals)
    except ZeroDivisionError:
        return 0
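
A few worked values, using a standalone copy of the function and made-up inputs. The result is 0 when the two values are equal and approaches 1 as they diverge.

```python
def recip(vals):
    try:
        if min(vals) == 0:  # treated as a failed measurement
            return 0
        return 1 - min(vals) / max(vals)
    except ZeroDivisionError:
        return 0


different = recip([2.0, 4.0])  # 1 - 2/4
identical = recip([3.0, 3.0])  # 1 - 3/3
failed = recip([0, 5.0])       # a zero value short-circuits to 0
print(different, identical, failed)
```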

skeleton_image_to_graph(skeleton: npt.NDArray) -> nx.classes.graph.Graph staticmethod

Convert a skeletonised mask into a Graph representation.

Graphs conserve the coordinates via the node label.

Parameters:

Name Type Description Default
skeleton NDArray

A binary single-pixel wide mask, or result from conv_skelly().

required

Returns:

Type Description
Graph

A networkX graph connecting the pixels in the skeleton to their neighbours.

Source code in topostats\tracing\nodestats.py
@staticmethod
def skeleton_image_to_graph(skeleton: npt.NDArray) -> nx.classes.graph.Graph:
    """
    Convert a skeletonised mask into a Graph representation.

    Graphs conserve the coordinates via the node label.

    Parameters
    ----------
    skeleton : npt.NDArray
        A binary single-pixel wide mask, or result from conv_skelly().

    Returns
    -------
    nx.classes.graph.Graph
        A networkX graph connecting the pixels in the skeleton to their neighbours.
    """
    skeImPos = np.argwhere(skeleton).T
    g = nx.Graph()
    neigh = np.array([[0, 1], [0, -1], [1, 0], [-1, 0], [1, 1], [1, -1], [-1, 1], [-1, -1]])

    for idx in range(skeImPos[0].shape[0]):
        for neighIdx in range(neigh.shape[0]):
            curNeighPos = skeImPos[:, idx] + neigh[neighIdx]
            if np.any(curNeighPos < 0) or np.any(curNeighPos >= skeleton.shape):
                continue
            if skeleton[curNeighPos[0], curNeighPos[1]] > 0:
                idx_coord = skeImPos[0, idx], skeImPos[1, idx]
                curNeigh_coord = curNeighPos[0], curNeighPos[1]
                # assign lower weight to nodes if not a binary image
                if skeleton[idx_coord] == 3 and skeleton[curNeigh_coord] == 3:
                    weight = 0
                else:
                    weight = 1
                g.add_edge(idx_coord, curNeigh_coord, weight=weight)
    g.graph["physicalPos"] = skeImPos.T
    return g
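
To show the neighbour search and the edge weighting without a graph library, here is a dependency-light sketch that builds a plain adjacency dictionary instead of a networkx graph. The helper name `skeleton_to_adjacency` and the tiny skeleton are hypothetical. The weighting rule (0 between two node pixels, 1 otherwise) follows the code above.

```python
import numpy as np

# offsets to the 8 neighbours of a pixel
NEIGHBOURS = [(0, 1), (0, -1), (1, 0), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]


def skeleton_to_adjacency(skeleton):
    """Map each skeleton pixel to {neighbour pixel: edge weight} (hypothetical helper)."""
    adjacency = {}
    for r, c in np.argwhere(skeleton):
        r, c = int(r), int(c)
        for dr, dc in NEIGHBOURS:
            nr, nc = r + dr, c + dc
            # skip neighbours that fall outside the image
            if not (0 <= nr < skeleton.shape[0] and 0 <= nc < skeleton.shape[1]):
                continue
            if skeleton[nr, nc] > 0:
                # edges between two node pixels (value 3) cost nothing
                weight = 0 if skeleton[r, c] == 3 and skeleton[nr, nc] == 3 else 1
                adjacency.setdefault((r, c), {})[(nr, nc)] = weight
    return adjacency


skeleton = np.array(
    [
        [1, 0, 0],
        [0, 3, 3],
        [0, 0, 1],
    ]
)
graph = skeleton_to_adjacency(skeleton)
print(graph[(1, 1)])
```

Zero-weight edges inside node regions mean a weighted shortest path can cross a node cluster at no cost, which is presumably why the weight is passed to the path search used when joining branches.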

tidy_branches(connect_node_mask: npt.NDArray, image: npt.NDArray) -> npt.NDArray

Wrangle distant connected nodes back towards the main cluster.

Works by filling and reskeletonising solely the node areas.

Parameters:

Name Type Description Default
connect_node_mask NDArray

The connected node mask - a skeleton where node regions = 3, endpoints = 2, and skeleton = 1.

required
image NDArray

The intensity image.

required

Returns:

Type Description
NDArray

The wrangled connected_node_mask.

Source code in topostats\tracing\nodestats.py
def tidy_branches(self, connect_node_mask: npt.NDArray, image: npt.NDArray) -> npt.NDArray:
    """
    Wrangle distant connected nodes back towards the main cluster.

    Works by filling and reskeletonising solely the node areas.

    Parameters
    ----------
    connect_node_mask : npt.NDArray
        The connected node mask - a skeleton where node regions = 3, endpoints = 2, and skeleton = 1.
    image : npt.NDArray
        The intensity image.

    Returns
    -------
    npt.NDArray
        The wrangled connected_node_mask.
    """
    new_skeleton = np.where(connect_node_mask != 0, 1, 0)
    labeled_nodes = label(np.where(connect_node_mask == 3, 1, 0))
    for node_num in range(1, labeled_nodes.max() + 1):
        solo_node = np.where(labeled_nodes == node_num, 1, 0)
        coords = np.argwhere(solo_node == 1)
        node_centre = coords.mean(axis=0).astype(np.int32)
        node_wid = coords[:, 0].max() - coords[:, 0].min() + 2  # +2 so always 2 by default
        node_len = coords[:, 1].max() - coords[:, 1].min() + 2  # +2 so always 2 by default
        overflow = int(10 / self.pixel_to_nm_scaling) if int(10 / self.pixel_to_nm_scaling) != 0 else 1
        # grain mask fill
        new_skeleton[
            node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow,
            node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow,
        ] = self.mask[
            node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow,
            node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow,
        ]
    # remove any artifacts of the grain caught in the overflow areas
    new_skeleton = self.keep_biggest_object(new_skeleton)
    # Re-skeletonise
    new_skeleton = getSkeleton(image, new_skeleton, method="topostats", height_bias=0.6).get_skeleton()
    # new_skeleton = pruneSkeleton(image, new_skeleton).prune_skeleton(
    #     {"method": "topostats", "max_length": -1}
    # )
    new_skeleton = prune_skeleton(
        image, new_skeleton, self.pixel_to_nm_scaling, **{"method": "topostats", "max_length": -1}
    )
    # cleanup around nibs
    new_skeleton = getSkeleton(image, new_skeleton, method="zhang").get_skeleton()
    # might also need to remove segments that have squares connected

    return convolve_skeleton(new_skeleton)
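
The `overflow` margin above converts a fixed distance into whole pixels and never lets it drop below one pixel. A minimal sketch of that calculation, assuming `pixel_to_nm_scaling` is in nm per pixel so that the constant 10 corresponds to about 10 nm. The helper name `overflow_pixels` is hypothetical.

```python
def overflow_pixels(pixel_to_nm_scaling: float, distance_nm: float = 10) -> int:
    """Whole-pixel margin for a distance in nm, at least 1 (hypothetical helper)."""
    pixels = int(distance_nm / pixel_to_nm_scaling)  # truncates towards zero
    return pixels if pixels != 0 else 1


fine = overflow_pixels(0.5)     # high-resolution scan
medium = overflow_pixels(3.0)   # 10 / 3 truncates to 3
coarse = overflow_pixels(20.0)  # would be 0, so clamped to 1
print(fine, medium, coarse)
```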

nodestats_image(image: npt.NDArray, disordered_tracing_direction_data: dict, filename: str, pixel_to_nm_scaling: float, node_joining_length: float, node_extend_dist: float, branch_pairing_length: float, pair_odd_branches: bool, pad_width: int) -> tuple

Initialise the nodeStats class for each grain and collate the results.

Parameters:

Name Type Description Default
image NDArray

The array of pixels.

required
disordered_tracing_direction_data dict

The images and bbox coordinates of the pruned skeletons.

required
filename str

The name of the file being processed. For logging purposes.

required
pixel_to_nm_scaling float

The pixel to nm scaling factor.

required
node_joining_length float

The length over which to join skeletal intersections to be counted as one crossing.

required
node_extend_dist float

The distance under which to join odd-branched node regions.

required
branch_pairing_length float

The length from the crossing point over which to pair and trace branches to obtain the FWHMs.

required
pair_odd_branches bool

Whether to try and pair odd-branched nodes.

required
pad_width int

The number of edge pixels to pad the image by.

required

Returns:

Type Description
tuple[dict, DataFrame, dict, dict]

The nodestats statistics for each crossing, crossing statistics to be added to the grain statistics, an image dictionary of nodestats steps for the entire image, and single grain images.

Source code in topostats\tracing\nodestats.py
def nodestats_image(
    image: npt.NDArray,
    disordered_tracing_direction_data: dict,
    filename: str,
    pixel_to_nm_scaling: float,
    node_joining_length: float,
    node_extend_dist: float,
    branch_pairing_length: float,
    pair_odd_branches: bool,
    pad_width: int,
) -> tuple:
    """
    Initialise the nodeStats class for each grain and collate the results.

    Parameters
    ----------
    image : npt.NDArray
        The array of pixels.
    disordered_tracing_direction_data : dict
        The images and bbox coordinates of the pruned skeletons.
    filename : str
        The name of the file being processed. For logging purposes.
    pixel_to_nm_scaling : float
        The pixel to nm scaling factor.
    node_joining_length : float
        The length over which to join skeletal intersections to be counted as one crossing.
    node_extend_dist : float
        The distance under which to join odd-branched node regions.
    branch_pairing_length : float
        The length from the crossing point over which to pair and trace branches to obtain the FWHMs.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    pad_width : int
        The number of edge pixels to pad the image by.

    Returns
    -------
    tuple[dict, pd.DataFrame, dict, dict]
        The nodestats statistics for each crossing, crossing statistics to be added to the grain statistics,
        an image dictionary of nodestats steps for the entire image, and single grain images.
    """
    n_grains = len(disordered_tracing_direction_data)
    img_base = np.zeros_like(image)
    nodestats_data = {}

    # want to get each cropped image, use some anchor coords to match them onto the image,
    #   and compile all the grain images onto a single image
    all_images = {
        "convolved_skeletons": img_base.copy(),
        "node_centres": img_base.copy(),
        "connected_nodes": img_base.copy(),
    }
    nodestats_branch_images = {}
    grainstats_additions = {}

    LOGGER.info(f"[{filename}] : Calculating NodeStats statistics for {n_grains} grains...")

    for n_grain, disordered_tracing_grain_data in disordered_tracing_direction_data.items():
        nodestats = None  # reset the nodestats variable
        try:
            nodestats = nodeStats(
                image=disordered_tracing_grain_data["original_image"],
                mask=disordered_tracing_grain_data["original_grain"],
                smoothed_mask=disordered_tracing_grain_data["smoothed_grain"],
                skeleton=disordered_tracing_grain_data["pruned_skeleton"],
                pixel_to_nm_scaling=pixel_to_nm_scaling,
                filename=filename,
                n_grain=n_grain,
                node_joining_length=node_joining_length,
                node_extend_dist=node_extend_dist,
                branch_pairing_length=branch_pairing_length,
                pair_odd_branches=pair_odd_branches,
            )
            nodestats_dict, node_image_dict = nodestats.get_node_stats()
            LOGGER.debug(f"[{filename}] : Nodestats processed {n_grain} of {n_grains}")

            # compile images
            nodestats_images = {
                "convolved_skeletons": nodestats.conv_skelly,
                "node_centres": nodestats.node_centre_mask,
                "connected_nodes": nodestats.connected_nodes,
            }
            nodestats_branch_images[n_grain] = node_image_dict

            # compile metrics
            grainstats_additions[n_grain] = {
                "image": filename,
                "grain_number": int(n_grain.split("_")[-1]),
            }
            grainstats_additions[n_grain].update(nodestats.metrics)
            if nodestats_dict:  # if the grain's nodestats dict is not empty
                nodestats_data[n_grain] = nodestats_dict

            # remap the cropped images back onto the original
            for image_name, full_image in all_images.items():
                crop = nodestats_images[image_name]
                bbox = disordered_tracing_grain_data["bbox"]
                full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width]

        except Exception as e:  # pylint: disable=broad-exception-caught
            LOGGER.error(
                f"[{filename}] : Nodestats for {n_grain} failed. Consider raising an issue on GitHub. Error: ",
                exc_info=e,
            )
            nodestats_data[n_grain] = {}

    # turn the grainstats additions into a dataframe; built after the loop so that it is
    # defined (as an empty dataframe) even when there are no grains to process
    grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index")

    return nodestats_data, grainstats_additions_df, all_images, nodestats_branch_images
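
The remapping step strips the padding from each cropped grain image and adds what remains into the full-size image at the grain's bounding box. A minimal NumPy sketch with made-up shapes, assuming the bounding box is ordered `(min_row, min_col, max_row, max_col)` as the slicing above implies:

```python
import numpy as np

pad_width = 1
full_image = np.zeros((6, 6), dtype=int)

# Hypothetical 2x3 grain result, padded by pad_width pixels on every side.
crop = np.pad(np.ones((2, 3), dtype=int), pad_width)  # shape (4, 5)
bbox = (2, 1, 4, 4)  # assumed order: (min_row, min_col, max_row, max_col)

# Strip the padding and add the crop back at its original position.
full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width]
print(full_image)
```

This slicing relies on `pad_width` being at least 1: with `pad_width = 0` the slice `crop[0:-0]` is empty and the addition fails on mismatched shapes.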