Skip to content

Commit c626528

Browse files
committed
add test for network
1 parent cb87f3d commit c626528

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

src/py_eddy_tracker/observations/network.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,14 +1104,28 @@ def get_group_array(self, results, nb_obs):
11041104
if g0 > g1:
11051105
g0, g1 = g1, g0
11061106
merge_id.append((g0, g1))
1107+
gr_transfer = self.group_translator(id_free, set(merge_id))
1108+
return gr_transfer[gr]
1109+
1110+
@staticmethod
1111+
def group_translator(nb, duos):
1112+
"""
1113+
Create a translator with all duos
11071114
1108-
# FIXME: how it's work when several merge ? like (0,1), (0,2), (1,3)
1109-
gr_transfer = arange(id_free, dtype="u4")
1110-
for i, j in set(merge_id):
1111-
gr_i, gr_j = gr_transfer[i], gr_transfer[j]
1115+
:param int nb: size of translator
1116+
:param set((int, int)) duos: set of all group which must be join
1117+
1118+
Examples
1119+
--------
1120+
>>> NetworkObservations.group_translator(5, ((0, 1), (0, 2), (1, 3)))
1121+
[3, 3, 3, 3, 5]
1122+
"""
1123+
translate = arange(nb, dtype="u4")
1124+
for i, j in sorted(duos):
1125+
gr_i, gr_j = translate[i], translate[j]
11121126
if gr_i != gr_j:
1113-
apply_replace(gr_transfer, gr_i, gr_j)
1114-
return gr_transfer[gr]
1127+
apply_replace(translate, gr_i, gr_j)
1128+
return translate
11151129

11161130
def group_observations(self, **kwargs):
11171131
results, nb_obs = list(), list()

tests/test_network.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from py_eddy_tracker.observations.network import Network
2+
3+
4+
def test_group_translate():
5+
translate = Network.group_translator(5, ((0, 1), (0, 2), (1, 3)))
6+
assert (translate == [3, 3, 3, 3, 4]).all()
7+
8+
translate = Network.group_translator(5, ((1, 3), (0, 1), (0, 2)))
9+
assert (translate == [3, 3, 3, 3, 4]).all()
10+
11+
translate = Network.group_translator(8, ((1, 3), (2, 3), (2, 4), (5, 6), (4, 5)))
12+
assert (translate == [0, 6, 6, 6, 6, 6, 6, 7]).all()
13+
14+
translate = Network.group_translator(6, ((0, 1), (0, 2), (1, 3), (4, 5)))
15+
assert (translate == [3, 3, 3, 3, 5, 5]).all()

0 commit comments

Comments
 (0)