Skip to content

Commit 386b20a

Browse files
committed
Updates ByteStream class to leverage struct.unpack and StringIO package for increased parse speed (hopefully). Also updates the rest of the code base to reflect interface changes
1 parent f29d60f commit 386b20a

File tree

4 files changed

+260
-87
lines changed

4 files changed

+260
-87
lines changed

sc2reader/eventparsers.py

Lines changed: 49 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ class AbilityEventParser(object):
3737
def load(self, event, bytes):
3838
event.bytes += bytes.skip(4, byte_code=True)
3939
event.bytes += bytes.peek(4)
40-
event.ability = bytes.get_big_int(1) << 16 | bytes.get_big_int(1) << 8 | bytes.get_big_int(1)
41-
req_target = bytes.get_big_int(1)
40+
event.ability = bytes.get_big_8() << 16 | bytes.get_big_8() << 8 | bytes.get_big_8()
41+
req_target = bytes.get_big_8()
4242

4343
#In certain cases we need an extra byte
4444
if req_target == 0x30 or req_target == 0x05:
@@ -49,8 +49,8 @@ def load(self, event, bytes):
4949
class AbilityEventParser_16561(AbilityEventParser):
5050
def load(self, event, bytes):
5151
event.bytes += bytes.peek(5)
52-
first, atype = (bytes.get_big_int(1), bytes.get_big_int(1))
53-
event.ability = bytes.get_big_int(1) << 16 | bytes.get_big_int(1) << 8 | (bytes.get_big_int(1) & 0x3F)
52+
first, atype = (bytes.get_big_8(), bytes.get_big_8())
53+
event.ability = bytes.get_big_8() << 16 | bytes.get_big_8() << 8 | (bytes.get_big_8() & 0x3F)
5454

5555
if event.ability in abilities:
5656
event.abilitystr = abilities[event.ability]
@@ -90,7 +90,7 @@ class SelectionEventParser(object):
9090
def load(self, event, bytes):
9191
event.name = 'selection'
9292
event.bytes += bytes.peek(2)
93-
select_flags, deselect_count = bytes.get_big_int(1), bytes.get_big_int(1)
93+
select_flags, deselect_count = bytes.get_big_8(), bytes.get_big_8()
9494

9595
if deselect_count > 0:
9696
#Skip over the deselection bytes
@@ -100,17 +100,18 @@ def load(self, event, bytes):
100100
extras = deselect_count & 0x07
101101
if extras == 0:
102102
#This is easy when we are byte aligned
103-
unit_type_count, byte = bytes.get_big_int(1, byte_code=True)
103+
unit_type_count, byte = bytes.get_big_8(byte_code=True)
104104
event.bytes += byte
105105

106106
event.bytes += bytes.peek(unit_type_count*4+1)
107107
for i in range(0, unit_type_count):
108-
unit_type, unit_count = bytes.get_big_int(3), bytes.get_big_int(1)
109-
totalUnits = bytes.get_big_int(1)
108+
unit_type_block = bytes.get_big_32()
109+
unit_type, unit_count = unit_type_block >> 8, unit_type_block & 0xFF
110+
totalUnits = bytes.get_big_8()
110111

111112
event.bytes += bytes.peek(totalUnits*4)
112113
for i in range(0, totalUnits):
113-
unit_id, use_count = bytes.get_big_int(2), bytes.get_big_int(2)
114+
unit_id, use_count = bytes.get_big_16(), bytes.get_big_16()
114115
else:
115116
#We're not byte aligned, so need do so some bit shifting
116117
#This seems like 1000% wrong to me, but its what the people
@@ -121,40 +122,40 @@ def load(self, event, bytes):
121122
w_head_mask = ~w_tail_mask & 0xFF
122123

123124
event.bytes += bytes.peek(2)
124-
prev_byte, next_byte = bytes.get_big_int(1), bytes.get_big_int(1)
125+
prev_byte, next_byte = bytes.get_big_8(), bytes.get_big_8()
125126

126127
unit_type_count = prev_byte & head_mask | next_byte & tail_mask
127128

128129
event.bytes += bytes.peek(unit_type_count*4+1)
129130
for i in range(0, unit_type_count):
130-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
131+
prev_byte, next_byte = next_byte, bytes.get_big_8()
131132
unit_type = prev_byte & head_mask | ((next_byte & w_head_mask) >> (8-extras))
132-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
133+
prev_byte, next_byte = next_byte, bytes.get_big_8()
133134
unit_type = unit_type << 8 | (prev_byte & w_tail_mask) << extras | next_byte & tail_mask
134-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
135+
prev_byte, next_byte = next_byte, bytes.get_big_8()
135136
unit_type = unit_type << 8 | (prev_byte & head_mask) << extras | next_byte & tail_mask
136-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
137+
prev_byte, next_byte = next_byte, bytes.get_big_8()
137138
unit_count = prev_byte & head_mask | next_byte & tail_mask
138139

139-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
140+
prev_byte, next_byte = next_byte, bytes.get_big_8()
140141
totalUnits = prev_byte & head_mask | next_byte & tail_mask
141142

142143
event.bytes = bytes.peek(totalUnits*4)
143144
for i in range(0, totalUnits):
144-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
145+
prev_byte, next_byte = next_byte, bytes.get_big_8()
145146
unit_id = prev_byte & head_mask | ((next_byte & w_head_mask) >> (8-extras))
146-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
147+
prev_byte, next_byte = next_byte, bytes.get_big_8()
147148
unit_id = unit_id << 8 | prev_byte & w_tail_mask << extras | ((next_byte & w_head_mask) >> (8-extras))
148-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
149+
prev_byte, next_byte = next_byte, bytes.get_big_8()
149150
unit_id = unit_id << 8 | prev_byte & w_tail_mask << extras | ((next_byte & w_head_mask) >> (8-extras))
150-
prev_byte, next_byte = next_byte, bytes.get_big_int(1)
151+
prev_byte, next_byte = next_byte, bytes.get_big_8()
151152
unit_id = unit_id << 8 | prev_byte & w_tail_mask << extras | next_byte & tail_mask
152153

153154
class SelectionEventParser_16561(SelectionEventParser):
154155
def load(self, event, bytes):
155156
event.name = 'selection'
156157
event.bytes += bytes.peek(2)
157-
select_flags, deselect_type = bytes.get_big_int(1), bytes.get_big_int(1)
158+
select_flags, deselect_type = bytes.get_big_8(), bytes.get_big_8()
158159

159160
#No deselection to do here
160161
if deselect_type & 3 == 0:
@@ -164,7 +165,7 @@ def load(self, event, bytes):
164165
#deselection by bit counted indicators
165166
elif deselect_type & 3 == 1:
166167
#use the 6 left bits on top and the 2 right bits on bottom
167-
count_byte, byte = bytes.get_big_int(1, byte_code=True)
168+
count_byte, byte = bytes.get_big_8(byte_code=True)
168169
deselect_count = deselect_type & 0xFC | count_byte & 0x03
169170
event.bytes += byte
170171

@@ -175,7 +176,7 @@ def load(self, event, bytes):
175176
#while count > 6 we need to eat into more bytes because
176177
#we only have 6 bits left in our current byte
177178
while deselect_count > 6:
178-
last_byte, byte = bytes.get_big_int(1, byte_code=True)
179+
last_byte, byte = bytes.get_big_8(byte_code=True)
179180
deselect_count -= 8
180181
event.bytes += byte
181182

@@ -193,7 +194,7 @@ def load(self, event, bytes):
193194
#and as such probably has a deselect_count always == 0, not sure though
194195
else:
195196
#use the 6 left bits on top and the 2 right bits on bottom
196-
count_byte, byte = bytes.get_big_int(1, byte_code=True)
197+
count_byte, byte = bytes.get_big_8(byte_code=True)
197198
deselect_count = deselect_type & 0xFC | count_byte & 0x03
198199
event.bytes += byte
199200

@@ -206,7 +207,7 @@ def load(self, event, bytes):
206207
else:
207208
event.bytes += bytes.peek(deselect_count)
208209
bytes.skip(deselect_count-1)
209-
last_byte = bytes.get_big_int(1)
210+
last_byte = bytes.get_big_8()
210211

211212
mask = 0x03 #default mask of '11' applies
212213

@@ -217,7 +218,7 @@ def load(self, event, bytes):
217218

218219
#Get the number of selected unit types
219220
event.bytes += bytes.peek(1)
220-
next_byte = bytes.get_big_int(1)
221+
next_byte = bytes.get_big_8()
221222
numunit_types = combine(last_byte, next_byte)
222223

223224
#Read them all into a dictionary for later
@@ -229,20 +230,20 @@ def load(self, event, bytes):
229230
byte_list = list()
230231
for i in range(0, 3):
231232
#Swap the bytes, grab another, and combine w/ the mask
232-
last_byte, next_byte = next_byte, bytes.get_big_int(1)
233+
last_byte, next_byte = next_byte, bytes.get_big_8()
233234
byte_list.append( combine(last_byte, next_byte) )
234235
unit_type_id = byte_list[0] << 16 | byte_list[1] << 8 | byte_list[2]
235236

236237
#Get the count for that type in the next byte
237-
last_byte, next_byte = next_byte, bytes.get_big_int(1)
238+
last_byte, next_byte = next_byte, bytes.get_big_8()
238239
unit_type_count = combine(last_byte, next_byte)
239240

240241
#Store for later
241242
unit_types[unit_type_id] = unit_type_count
242243

243244
#Get total unit count
244245
event.bytes += bytes.peek(1)
245-
last_byte, next_byte = next_byte, bytes.get_big_int(1)
246+
last_byte, next_byte = next_byte, bytes.get_big_8()
246247
unit_count = combine(last_byte, next_byte)
247248

248249
#Pull all the unit_ids in for later
@@ -252,7 +253,7 @@ def load(self, event, bytes):
252253
#build the unit_id over the next 4 bytes
253254
byte_list = list()
254255
for i in range(0, 4):
255-
last_byte, next_byte = next_byte, bytes.get_big_int(1)
256+
last_byte, next_byte = next_byte, bytes.get_big_8()
256257
byte_list.append( combine(last_byte, next_byte) )
257258

258259
#The first 2 bytes are unique and the last 2 mark reusage count
@@ -283,7 +284,7 @@ def load_get_hotkey_changed(self, event, bytes, first):
283284

284285
extras = first >> 3
285286
event.bytes += bytes.peek(extras+1)
286-
second = bytes.get_big_int(1)
287+
second = bytes.get_big_8()
287288
bytes.skip(extras)
288289

289290
if first & 0x04:
@@ -295,7 +296,7 @@ def load(self, event, bytes):
295296
event.name = 'hotkey'
296297
event.hotkey = str(event.code >> 4)
297298
#print "Time %s - Player %s is using hotkey %s" % (self.timestr, self.player, eventCode >> 4)
298-
first, byte = bytes.get_big_int(1, byte_code=True)
299+
first, byte = bytes.get_big_8(byte_code=True)
299300
event.bytes += byte
300301

301302
if first == 0x00: self.load_set_hotkey(event, bytes, first)
@@ -308,7 +309,7 @@ def load(self, event, bytes):
308309
class HotkeyEventParser_16561(HotkeyEventParser):
309310
def load_get_hotkey_changed(self, event, bytes, first):
310311
name = 'get_hotkey_changed'
311-
second, byte = bytes.get_big_int(1, byte_code=True)
312+
second, byte = bytes.get_big_8(byte_code=True)
312313
event.bytes += byte
313314

314315
if first & 0x08:
@@ -339,8 +340,8 @@ def load(self, event, bytes):
339340

340341
#I might need to shift these two things to 19, 11, 3 for first 3 shifts
341342
event.bytes += bytes.peek(8)
342-
event.minerals = bytes.get_big_int(1) << 20 | bytes.get_big_int(1) << 12 | bytes.get_big_int(1) << 4 | bytes.get_big_int(1) >> 4
343-
event.gas = bytes.get_big_int(1) << 20 | bytes.get_big_int(1) << 12 | bytes.get_big_int(1) << 4 | bytes.get_big_int(1) >> 4
343+
event.minerals = bytes.get_big_8() << 20 | bytes.get_big_8() << 12 | bytes.get_big_8() << 4 | bytes.get_big_8() >> 4
344+
event.gas = bytes.get_big_8() << 20 | bytes.get_big_8() << 12 | bytes.get_big_8() << 4 | bytes.get_big_8() >> 4
344345

345346
#unknown extra stuff
346347
event.bytes += bytes.skip(2, byte_code=True)
@@ -355,12 +356,14 @@ def load(self, event, bytes):
355356
event.sender = event.player
356357
event.reciever = event.code >> 4
357358

358-
bytes.get_big_int(1) #Always 84
359+
bytes.get_big_8() #Always 84
359360

360361
#Minerals and Gas are encoded the same way
361-
base, extension = bytes.get_big_int(3), bytes.get_big_int(1)
362+
resource_block = bytes.get_big_32()
363+
base, extension = resource_block >> 8, resource_block & 0xFF
362364
event.minerals = base*(extension >> 4)+ (extension & 0x0F)
363-
base, extension = bytes.get_big_int(3), bytes.get_big_int(1)
365+
resource_block = bytes.get_big_32()
366+
base, extension = resource_block >> 8, resource_block & 0xFF
364367
event.gas = base*(extension >> 4)+ (extension & 0x0F)
365368

366369
#Another 8 bytes that don't make sense
@@ -395,17 +398,17 @@ def load(self, event, bytes):
395398
event.name = 'cameramovement_X1'
396399
#Get the X and Y, last byte is also a flag
397400
event.bytes += bytes.skip(3, byte_code=True)+bytes.peek(1)
398-
flag = bytes.get_big_int(1)
401+
flag = bytes.get_big_8()
399402

400403
#Get the zoom, last byte is a flag
401404
if flag & 0x10 != 0:
402405
event.bytes += bytes.skip(1, byte_code=True)+bytes.peek(1)
403-
flag = bytes.get_big_int(1)
406+
flag = bytes.get_big_8()
404407

405408
#If we are currently zooming get more?? idk
406409
if flag & 0x20 != 0:
407410
event.bytes += bytes.skip(1, byte_code=True)+bytes.peek(1)
408-
flag = bytes.get_big_int(1)
411+
flag = bytes.get_big_8()
409412

410413
#Do camera rotation as applies
411414
if flag & 0x40 != 0:
@@ -449,10 +452,10 @@ class UnknownEventParser_04C6(object):
449452
def load(self, event, bytes):
450453
event.name = 'unknown04C6'
451454
event.bytes += bytes.peek(16)
452-
block1 = bytes.get_big(4)
453-
block2 = bytes.get_big(4)
454-
block3 = bytes.get_big(4)
455-
block4 = bytes.get_big(4)
455+
block1 = bytes.get_big_32()
456+
block2 = bytes.get_big_32()
457+
block3 = bytes.get_big_32()
458+
block4 = bytes.get_big_32()
456459
return event
457460

458461
class UnknownEventParser_041C(object):
@@ -464,7 +467,7 @@ def load(self, event, bytes):
464467
class UnknownEventParser_0487(object):
465468
def load(self, event, bytes):
466469
event.name = 'unknown0418-87'
467-
event.data, databytes = bytes.get_big(4, byte_code=True) #Always 00 00 00 01??
470+
event.data, databytes = bytes.get_big_32(byte_code=True) #Always 00 00 00 01??
468471
event.bytes += databytes
469472
return event
470473

@@ -487,5 +490,4 @@ class UnknownEventParser_0589(object):
487490
def load(self, event, bytes):
488491
event.name = 'unknown0589'
489492
event.bytes += bytes.skip(4, byte_code=True)
490-
return event
491-
493+
return event

sc2reader/parsers.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,38 +42,38 @@ def get_initdata_parser(build):
4242
class InitdataParser(object):
4343
def load(self,replay,filecontents):
4444
bytes = ByteStream(filecontents)
45-
num_players = bytes.get_big_int(1)
45+
num_players = bytes.get_big_8()
4646
for p in range(0,num_players):
47-
name_length = bytes.get_big_int(1)
47+
name_length = bytes.get_big_8()
4848
name = bytes.get_string(name_length)
4949
bytes.skip(5)
5050

5151
bytes.skip(5) # Unknown
5252
bytes.get_string(4) # Always Dflt
5353
bytes.skip(15) #Unknown
54-
id_length = bytes.get_big_int(1)
54+
id_length = bytes.get_big_8()
5555
sc_account_id = bytes.get_string(id_length)
5656
bytes.skip(684) # Fixed Length data for unknown purpose
5757
while( bytes.get_string(4).lower() == 's2ma' ):
5858
bytes.skip(2)
5959
replay.realm = bytes.get_string(2).lower()
60-
unknown_map_hash = bytes.get_big(32)
60+
unknown_map_hash = bytes.get_bytes(32)
6161

6262
#################################################
6363
# replay.attributes.events Parsing classes
6464
#################################################
6565
class AttributeParser(object):
6666
def load_header(self, replay, bytes):
67-
bytes.skip(4, byte_code=True) #Always start with 4 nulls
68-
self.count = bytes.get_little_int(4) #get total attribute count
67+
bytes.skip(4) #Always start with 4 nulls
68+
self.count = bytes.get_little_32() #get total attribute count
6969

7070
def load_attribute(self, replay, bytes):
7171
#Get the attribute data elements
7272
attr_data = [
73-
bytes.get_little_int(4), #Header
74-
bytes.get_little_int(4), #Attr Id
75-
bytes.get_little_int(1), #Player
76-
bytes.get_little(4) #Value
73+
bytes.get_little_32(), #Header
74+
bytes.get_little_32(), #Attr Id
75+
bytes.get_little_8(), #Player
76+
bytes.get_little_bytes(4).encode("hex") #Value
7777
]
7878

7979
#Complete the decoding in the attribute object
@@ -121,8 +121,8 @@ def load(self, replay, filecontents):
121121

122122
class AttributeParser_17326(AttributeParser):
123123
def load_header(self, replay, bytes):
124-
bytes.skip(5, byte_code=True) #Always start with 4 nulls
125-
self.count = bytes.get_little_int(4) #get total attribute count
124+
bytes.skip(5) #Always start with 4 nulls
125+
self.count = bytes.get_little_32() #get total attribute count
126126

127127
##################################################
128128
# replay.details parsing classes
@@ -150,8 +150,8 @@ def load(self, replay, filecontents):
150150

151151
while(bytes.remaining!=0):
152152
time += bytes.get_timestamp()
153-
player_id = bytes.get_big_int(1) & 0x0F
154-
flags = bytes.get_big_int(1)
153+
player_id = bytes.get_big_8() & 0x0F
154+
flags = bytes.get_big_8()
155155

156156
if flags & 0xF0 == 0x80:
157157

@@ -169,7 +169,7 @@ def load(self, replay, filecontents):
169169

170170
elif flags & 0x80 == 0:
171171
target = flags & 0x03
172-
length = bytes.get_big_int(1)
172+
length = bytes.get_big_8()
173173

174174
if flags & 0x08:
175175
length += 64
@@ -238,7 +238,7 @@ def load(self, replay, filecontents):
238238
#event_type, the 4th bit 000X0000 marks the eventObjectas local or global,
239239
#and the remaining bits 0000XXXX mark the player id number.
240240
#The following byte completes the unique eventObjectidentifier
241-
first, event_code = bytes.get_big_int(1), bytes.get_big_int(1)
241+
first, event_code = bytes.get_big_8(), bytes.get_big_8()
242242
event_type, global_flag, player_id = first >> 5, first & 0x10, first & 0xF
243243

244244
#Create a barebones event from the gathered information

0 commit comments

Comments
 (0)