|  | 
| 12 | 12 | from re import compile as re_compile | 
| 13 | 13 | 
 | 
| 14 | 14 | from netCDF4 import Dataset | 
| 15 |  | -from numpy import bytes_, empty, unique | 
|  | 15 | +from numpy import bincount, bytes_, empty, in1d, unique | 
| 16 | 16 | from yaml import safe_load | 
| 17 | 17 | 
 | 
| 18 | 18 | from .. import EddyParser | 
| 19 |  | -from ..observations.observation import EddiesObservations | 
|  | 19 | +from ..observations.observation import EddiesObservations, reverse_index | 
| 20 | 20 | from ..observations.tracking import TrackEddiesObservations | 
| 21 | 21 | from ..tracking import Correspondances | 
| 22 | 22 | 
 | 
| @@ -373,3 +373,137 @@ def track( | 
| 373 | 373 |     short_track.write_file( | 
| 374 | 374 |         filename="%(path)s/%(sign_type)s_track_too_short.nc", **kw_write | 
| 375 | 375 |     ) | 
|  | 376 | + | 
|  | 377 | + | 
|  | 378 | +def get_group( | 
|  | 379 | +    dataset1, | 
|  | 380 | +    dataset2, | 
|  | 381 | +    index1, | 
|  | 382 | +    index2, | 
|  | 383 | +    score, | 
|  | 384 | +    invalid=2, | 
|  | 385 | +    low=10, | 
|  | 386 | +    high=60, | 
|  | 387 | +): | 
|  | 388 | +    group1, group2 = dict(), dict() | 
|  | 389 | +    m_valid = (score * 100) >= invalid | 
|  | 390 | +    i1, i2, score = index1[m_valid], index2[m_valid], score[m_valid] * 100 | 
|  | 391 | +    # Eddies with no association & scores < invalid | 
|  | 392 | +    group1["nomatch"] = reverse_index(i1, len(dataset1)) | 
|  | 393 | +    group2["nomatch"] = reverse_index(i2, len(dataset2)) | 
|  | 394 | +    # Select all eddies involved in multiple associations | 
|  | 395 | +    i1_, nb1 = unique(i1, return_counts=True) | 
|  | 396 | +    i2_, nb2 = unique(i2, return_counts=True) | 
|  | 397 | +    i1_multi = i1_[nb1 >= 2] | 
|  | 398 | +    i2_multi = i2_[nb2 >= 2] | 
|  | 399 | +    m_multi = in1d(i1, i1_multi) + in1d(i2, i2_multi) | 
|  | 400 | +    group1["multi_match"] = unique(i1[m_multi]) | 
|  | 401 | +    group2["multi_match"] = unique(i2[m_multi]) | 
|  | 402 | + | 
|  | 403 | +    # Low scores | 
|  | 404 | +    m_low = score <= low | 
|  | 405 | +    m_low *= ~m_multi | 
|  | 406 | +    group1["low"] = i1[m_low] | 
|  | 407 | +    group2["low"] = i2[m_low] | 
|  | 408 | +    # Intermediate scores | 
|  | 409 | +    m_i = (score > low) * (score <= high) | 
|  | 410 | +    m_i *= ~m_multi | 
|  | 411 | +    group1["intermediate"] = i1[m_i] | 
|  | 412 | +    group2["intermediate"] = i2[m_i] | 
|  | 413 | +    # High scores | 
|  | 414 | +    m_high = score > high | 
|  | 415 | +    m_high *= ~m_multi | 
|  | 416 | +    group1["high"] = i1[m_high] | 
|  | 417 | +    group2["high"] = i2[m_high] | 
|  | 418 | + | 
|  | 419 | +    def get_twin(j2, j1): | 
|  | 420 | +        # True only if j1 is used only one | 
|  | 421 | +        m = bincount(j1)[j1] == 1 | 
|  | 422 | +        # We keep only link of this mask j1 have exactly one parent | 
|  | 423 | +        j2_ = j2[m] | 
|  | 424 | +        # We count parent times | 
|  | 425 | +        m_ = (bincount(j2_)[j2_] == 2) * (bincount(j2)[j2_] == 2) | 
|  | 426 | +        # we fill first mask with second one | 
|  | 427 | +        m[m] = m_ | 
|  | 428 | +        return m | 
|  | 429 | + | 
|  | 430 | +    m1 = get_twin(i1, i2) | 
|  | 431 | +    m2 = get_twin(i2, i1) | 
|  | 432 | +    group1["parent"] = unique(i1[m1]) | 
|  | 433 | +    group2["parent"] = unique(i2[m2]) | 
|  | 434 | +    group1["twin"] = i1[m2] | 
|  | 435 | +    group2["twin"] = i2[m1] | 
|  | 436 | + | 
|  | 437 | +    m = ~m1 * ~m2 * m_multi | 
|  | 438 | +    group1["complex"] = unique(i1[m]) | 
|  | 439 | +    group2["complex"] = unique(i2[m]) | 
|  | 440 | + | 
|  | 441 | +    return group1, group2 | 
|  | 442 | + | 
|  | 443 | + | 
|  | 444 | +def quick_compare(): | 
|  | 445 | +    parser = EddyParser( | 
|  | 446 | +        "Tool to have a quick comparison between several identification" | 
|  | 447 | +    ) | 
|  | 448 | +    parser.add_argument("ref", help="Identification file of reference") | 
|  | 449 | +    parser.add_argument("others", nargs="+", help="Identifications files to compare") | 
|  | 450 | +    parser.add_argument("--high", default=40, type=float) | 
|  | 451 | +    parser.add_argument("--low", default=20, type=float) | 
|  | 452 | +    parser.add_argument("--invalid", default=5, type=float) | 
|  | 453 | +    parser.contour_intern_arg() | 
|  | 454 | +    args = parser.parse_args() | 
|  | 455 | + | 
|  | 456 | +    kw = dict( | 
|  | 457 | +        include_vars=[ | 
|  | 458 | +            "longitude", | 
|  | 459 | +            *EddiesObservations.intern(args.intern, public_label=True), | 
|  | 460 | +        ] | 
|  | 461 | +    ) | 
|  | 462 | + | 
|  | 463 | +    ref = EddiesObservations.load_file(args.ref, **kw) | 
|  | 464 | +    print(f"[ref] {args.ref} -> {len(ref)} obs") | 
|  | 465 | +    groups_ref, groups_other = dict(), dict() | 
|  | 466 | +    others = {other: EddiesObservations.load_file(other, **kw) for other in args.others} | 
|  | 467 | +    for i, other_ in enumerate(args.others): | 
|  | 468 | +        other = others[other_] | 
|  | 469 | +        print(f"[{i}] {other_} -> {len(other)} obs") | 
|  | 470 | +        gr1, gr2 = get_group( | 
|  | 471 | +            ref, | 
|  | 472 | +            other, | 
|  | 473 | +            *ref.match(other, intern=args.intern), | 
|  | 474 | +            invalid=args.invalid, | 
|  | 475 | +            low=args.low, | 
|  | 476 | +            high=args.high, | 
|  | 477 | +        ) | 
|  | 478 | +        groups_ref[other_] = gr1 | 
|  | 479 | +        groups_other[other_] = gr2 | 
|  | 480 | + | 
|  | 481 | +    def display(value, ref=None): | 
|  | 482 | +        outs = list() | 
|  | 483 | +        for v in value: | 
|  | 484 | +            if ref: | 
|  | 485 | +                outs.append(f"{v/ref * 100:.1f}% ({v})") | 
|  | 486 | +            else: | 
|  | 487 | +                outs.append(v) | 
|  | 488 | +        return "".join([f"{v:^15}" for v in outs]) | 
|  | 489 | + | 
|  | 490 | +    keys = list(gr1.keys()) | 
|  | 491 | +    print("     ", display(keys)) | 
|  | 492 | +    for i, v in enumerate(groups_ref.values()): | 
|  | 493 | +        print( | 
|  | 494 | +            f"[{i:2}] ", | 
|  | 495 | +            display( | 
|  | 496 | +                (v_.sum() if v_.dtype == "bool" else v_.shape[0] for v_ in v.values()), | 
|  | 497 | +                ref=len(ref), | 
|  | 498 | +            ), | 
|  | 499 | +        ) | 
|  | 500 | + | 
|  | 501 | +    print(display(keys)) | 
|  | 502 | +    for i, (k, v) in enumerate(groups_other.items()): | 
|  | 503 | +        print( | 
|  | 504 | +            f"[{i:2}] ", | 
|  | 505 | +            display( | 
|  | 506 | +                (v_.sum() if v_.dtype == "bool" else v_.shape[0] for v_ in v.values()), | 
|  | 507 | +                ref=len(others[k]), | 
|  | 508 | +            ), | 
|  | 509 | +        ) | 
0 commit comments