Skip to content

Commit 7e8d4a5

Browse files
committed
Small ReplayBuffer.read optimization.
Pre-generate and store the state for reads (masks in particular). Looking up the state is actually faster the regenerating it. I guess because python is slow with bit level operations.
1 parent 4289ba0 commit 7e8d4a5

File tree

1 file changed

+45
-24
lines changed

1 file changed

+45
-24
lines changed

sc2reader/utils.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import struct
77
import textwrap
88

9-
import sc2reader.exceptions
10-
119
from itertools import groupby
1210

11+
from .exceptions import FileError
12+
1313
LITTLE_ENDIAN,BIG_ENDIAN = '<','>'
1414

1515
class ReplayBuffer(object):
@@ -80,6 +80,34 @@ def __init__(self, file):
8080
self.read_basic = self.io.read
8181
self.char_buffer = cStringIO.StringIO()
8282

83+
# Pre-generate the state for all reads, marginal time savings
84+
self.read_state = dict()
85+
for old in range(0,8):
86+
for new in range(0,8):
87+
self.read_state[(old,new)] = self.load_state(old, new)
88+
89+
def load_state(self, old_bit_shift, new_bit_shift):
90+
old_bit_shift_inv = 8-old_bit_shift
91+
92+
# Masks
93+
lo_mask = 2**old_bit_shift-1
94+
lo_mask_inv = 0xFF - 2**(8-old_bit_shift)+1
95+
hi_mask = 0xFF ^ lo_mask
96+
hi_mask_inv = 0xFF ^ lo_mask_inv
97+
98+
#last byte parameters
99+
if new_bit_shift == 0: #this means we filled the last byte (8)
100+
last_mask = 0xFF
101+
adjustment = 8-old_bit_shift
102+
adjustment_mask = 2**adjustment-1
103+
else:
104+
last_mask = 2**new_bit_shift-1
105+
adjustment = new_bit_shift-old_bit_shift
106+
adjustment_mask = 2**adjustment-1
107+
108+
return (old_bit_shift_inv, lo_mask, lo_mask_inv, hi_mask,
109+
hi_mask_inv, last_mask, adjustment, adjustment_mask)
110+
83111
'''
84112
Additional Properties
85113
'''
@@ -333,29 +361,20 @@ def read(self, bytes=0, bits=0):
333361
return base+[self.shift(bits)]
334362
return base
335363

336-
# Calculated shifts
364+
# Calculated shifts as our keys
337365
old_bit_shift = self.bit_shift
338366
new_bit_shift = (self.bit_shift+bits) % 8
339367

340-
# Masks
341-
lo_mask = 2**old_bit_shift-1
342-
lo_mask_inv = 0xFF - 2**(8-old_bit_shift)+1
343-
hi_mask = 0xFF ^ lo_mask
344-
hi_mask_inv = 0xFF ^ lo_mask_inv
345-
346-
#last byte parameters
347-
if new_bit_shift == 0: #this means we filled the last byte (8)
348-
last_mask = 0xFF
349-
adjustment = 8-old_bit_shift
350-
else:
351-
last_mask = 2**new_bit_shift-1
352-
adjustment = new_bit_shift-old_bit_shift
368+
# Load the precalculated state variables
369+
(old_bit_shift_inv, lo_mask, lo_mask_inv,
370+
hi_mask, hi_mask_inv, last_mask, adjustment,
371+
adjustment_mask) = self.read_state[(old_bit_shift,new_bit_shift)]
353372

354373
#Set up for the looping with a list, the bytes, and an initial part
355374
raw_bytes = list()
356375
prev, next = self.last_byte, ord(self.read_basic(1))
357376
first = prev & hi_mask
358-
bit_count -= 8-old_bit_shift
377+
bit_count -= old_bit_shift_inv
359378

360379
while bit_count > 0:
361380

@@ -369,15 +388,17 @@ def read(self, bytes=0, bits=0):
369388
# if the adjustment is lower than 0
370389
if adjustment < 0:
371390
first = first >> abs(adjustment)
372-
373-
raw_bytes.append(first | (last >> max(adjustment,0)))
374-
if adjustment > 0:
375-
raw_bytes.append(last & (2**adjustment-1))
391+
raw_bytes.append(first | last)
392+
elif adjustment > 0:
393+
raw_bytes.append(last & adjustment_mask)
394+
raw_bytes.append(first | (last >> adjustment))
395+
else:
396+
raw_bytes.append(first | last)
376397

377398
bit_count = 0
378399

379400
if bit_count > 8: #We can do simple wrapping for middle bytes
380-
second = (next & lo_mask_inv) >> (8-old_bit_shift)
401+
second = (next & lo_mask_inv) >> old_bit_shift_inv
381402
raw_bytes.append(first | second)
382403

383404
#To remain consistent, always shfit these bits into the hi_mask
@@ -530,7 +551,7 @@ def read_header(file):
530551

531552
#Sanity check that the input is in fact an MPQ file
532553
if buffer.empty or buffer.read_hex(4).upper() != "4D50511B":
533-
raise exceptions.FileError("File '%s' is not an MPQ file" % file.name)
554+
raise FileError("File '%s' is not an MPQ file" % file.name)
534555

535556
#Extract replay header data, we are unlikely to ever use most of this
536557
max_data_size = buffer.read_int(LITTLE_ENDIAN)
@@ -668,4 +689,4 @@ def _split_lines(self, text, width):
668689
# Blank lines get eaten by textwrap, put it back with [' ']
669690
lines.extend(new_lines or [' '])
670691

671-
return lines
692+
return lines

0 commit comments

Comments
 (0)