diff --git a/sc2reader/engine/plugins/context.py b/sc2reader/engine/plugins/context.py index ba8fccfe..3372d84b 100644 --- a/sc2reader/engine/plugins/context.py +++ b/sc2reader/engine/plugins/context.py @@ -14,6 +14,9 @@ def handleInitGame(self, event, replay): replay.units = set() replay.unit = dict() + # keep track of last TargetAbilityEvent for UpdateTargetAbilityEvent + self.last_target_ability_event = {} + def handleGameEvent(self, event, replay): self.load_message_game_player(event, replay) @@ -25,6 +28,10 @@ def handleAbilityEvent(self, event, replay): return if event.ability_id not in replay.datapack.abilities: + # safeguard against missing abilities + if event.player.pid in self.last_target_ability_event: + del self.last_target_ability_event[event.player.pid] + if not getattr(replay, 'marked_error', None): replay.marked_error = True event.logger.error(replay.filename) @@ -47,6 +54,8 @@ def handleAbilityEvent(self, event, replay): self.logger.error("Other unit {0} not found".format(event.other_unit_id)) def handleTargetAbilityEvent(self, event, replay): + self.last_target_ability_event[event.player.pid] = event + if not replay.datapack: return @@ -62,6 +71,17 @@ def handleTargetAbilityEvent(self, event, replay): event.target = unit replay.objects[event.target_unit_id] = unit + def handleUpdateTargetAbilityEvent(self, event, replay): + # We may not find a TargetAbilityEvent before finding an + # UpdateTargetAbilityEvent, perhaps due to Missing Abilities in the + # datapack + if event.player.pid in self.last_target_ability_event: + # store corresponding TargetAbilityEvent data in this event + # currently using for *MacroTracker only, so only need ability name + event.ability_name = self.last_target_ability_event[event.player.pid].ability_name + + self.handleTargetAbilityEvent(event, replay) + def handleSelectionEvent(self, event, replay): if not replay.datapack: return diff --git a/sc2reader/events/game.py b/sc2reader/events/game.py index a6ff2403..59996bf9 100644 --- a/sc2reader/events/game.py +++ b/sc2reader/events/game.py @@ -339,6 +339,22 @@ def __init__(self, frame, pid, data): self.location = (self.x, self.y, self.z) +class UpdateTargetAbilityEvent(TargetAbilityEvent): + """ + Extends :class:`TargetAbilityEvent` + + This event is generated when a TargetAbilityEvent is updated, likely due to + changing the target unit. It is unclear if this needs to be a separate event + from TargetAbilityEvent, but for flexibility, it will be treated + differently. + + One example of this event occuring is casting inject on a hatchery while + holding shift, and then shift clicking on a second hatchery. + """ + + name = 'UpdateTargetAbilityEvent' + + class SelfAbilityEvent(AbilityEvent): """ Extends :class:`AbilityEvent` diff --git a/sc2reader/readers.py b/sc2reader/readers.py index d8286c4a..d84f7122 100644 --- a/sc2reader/readers.py +++ b/sc2reader/readers.py @@ -1525,7 +1525,7 @@ def __init__(self): 61: (None, self.trigger_hotkey_pressed_event), 103: (None, self.command_manager_state_event), 104: (None, self.command_update_target_point_event), - 105: (None, self.command_update_target_unit_event), + 105: (UpdateTargetAbilityEvent, self.command_update_target_unit_event), 106: (None, self.trigger_anim_length_query_by_name_event), 107: (None, self.trigger_anim_length_query_by_props_event), 108: (None, self.trigger_anim_offset_event), @@ -1596,19 +1596,24 @@ def command_update_target_point_event(self, data): def command_update_target_unit_event(self, data): return dict( - target=dict( - target_unit_flags=data.read_uint16(), + flags=0, # fill me with previous TargetUnitEvent.flags + ability=None, # fill me with previous TargetUnitEvent.ability + data=('TargetUnit', dict( + flags=data.read_uint16(), timer=data.read_uint8(), - tag=data.read_uint32(), - snapshot_unit_link=data.read_uint16(), - snapshot_control_player_id=data.read_bits(4) if data.read_bool() else None, - snapshot_upkeep_player_id=data.read_bits(4) if data.read_bool() else None, - snapshot_point=dict( + unit_tag=data.read_uint32(), + unit_link=data.read_uint16(), + control_player_id=data.read_bits(4) if data.read_bool() else None, + upkeep_player_id=data.read_bits(4) if data.read_bool() else None, + point=dict( x=data.read_bits(20), y=data.read_bits(20), z=data.read_bits(32) - 2147483648, - ) - ) + ), + )), + sequence=0, # fill me with previous TargetUnitEvent.flags + other_unit_tag=None, # fill me with previous TargetUnitEvent.flags + unit_group=None, # fill me with previous TargetUnitEvent.flags ) def command_event(self, data): diff --git a/test_replays/3.3.0/ggissue48.SC2Replay b/test_replays/3.3.0/ggissue48.SC2Replay new file mode 100644 index 00000000..57f06b56 Binary files /dev/null and b/test_replays/3.3.0/ggissue48.SC2Replay differ diff --git a/test_replays/3.3.0/ggissue49.SC2Replay b/test_replays/3.3.0/ggissue49.SC2Replay new file mode 100644 index 00000000..d64f4458 Binary files /dev/null and b/test_replays/3.3.0/ggissue49.SC2Replay differ diff --git a/test_replays/test_all.py b/test_replays/test_all.py index 61d5b69e..94b58a66 100644 --- a/test_replays/test_all.py +++ b/test_replays/test_all.py @@ -495,12 +495,24 @@ def test_33(self): replay = sc2reader.load_replay("test_replays/3.3.0/{}.SC2Replay".format(replaynum)) self.assertTrue(replay is not None) + def test_33_shift_click_calldown_mule(self): + replay = sc2reader.load_replay("test_replays/3.3.0/ggissue48.SC2Replay") + def efilter(e): + return hasattr(e, "ability") and e.ability_name == "CalldownMULE" + self.assertEqual(len(filter(efilter, replay.events)), 29) + + def test_33_shift_click_spawn_larva(self): + replay = sc2reader.load_replay("test_replays/3.3.0/ggissue49.SC2Replay") + def efilter(e): + return hasattr(e, "ability") and e.ability_name == "SpawnLarva" + self.assertEqual(len(filter(efilter, replay.events)), 23) + def test_lotv_time(self): - replay = sc2reader.load_replay("test_replays/lotv/lotv1.SC2Replay") - self.assertEqual(replay.length.seconds, 1002) - self.assertEqual(replay.real_length.seconds, 1002) + replay = sc2reader.load_replay("test_replays/lotv/lotv1.SC2Replay") + self.assertEqual(replay.length.seconds, 1002) + self.assertEqual(replay.real_length.seconds, 1002) + - class TestGameEngine(unittest.TestCase): class TestEvent(object): name='TestEvent'