Skip to content

Commit 1b656fc

Browse files
committed
Refactor ReplayBuffer into Byte/BitDecoders.
While we are at it, improve clarity of the method names and fix several small issues with the read_struct logic. Also improve method and member documentation in preparation for auto-docs.
1 parent b1230a8 commit 1b656fc

File tree

5 files changed

+441
-478
lines changed

5 files changed

+441
-478
lines changed

sc2reader/decoders.py

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import
3+
4+
from cStringIO import StringIO
5+
6+
import struct
7+
8+
class ByteDecoder(object):
    """ Cursor based reader for unsigned integers and raw byte strings.

    The contents are held in a ``StringIO`` buffer whose ``read``,
    ``seek`` and ``tell`` methods are exposed directly on the decoder.
    Multi-byte values are decoded in the byte order given at
    construction time.
    """

    #: The StringIO object used internally for reading from the
    #: decoder contents. cStringIO is faster than managing our
    #: own string access in python. For PyPy installations a
    #: managed string implementation might be faster.
    _buffer = None

    #: The string buffer being decoded. A direct reference
    #: is kept around to make read_range and peek faster.
    _contents = ""

    def __init__(self, contents, endian):
        """ Accepts both strings and files implementing ``read()`` and
        decodes them in the specified endian format.

        ``endian`` may be ``'little'``, ``'big'`` (case insensitive),
        ``'<'`` or ``'>'``. Any other value raises ``ValueError``.
        """
        if hasattr(contents, 'read'):
            self._contents = contents.read()
        else:
            self._contents = contents

        self._buffer = StringIO(self._contents)
        self.length = len(self._contents)

        # Expose the basic StringIO interface
        self.read = self._buffer.read
        self.seek = self._buffer.seek
        self.tell = self._buffer.tell

        # Decode the endian value into its struct prefix if necessary
        endian = endian.lower()
        if endian == 'little':
            endian = "<"
        elif endian == 'big':
            endian = ">"
        elif endian not in ('<', '>'):
            raise ValueError("Endian must be one of 'little', '<', 'big', or '>' but was: "+endian)

        # Pre-compile the struct formats so each read is a single call
        self._unpack_int = struct.Struct(endian+'I').unpack
        self._unpack_short = struct.Struct(endian+'H').unpack
        self._unpack_longlong = struct.Struct(endian+'Q').unpack

        # The byte order is fixed for the life of the decoder, so pick
        # the byte string handler once rather than testing on each call.
        if endian == '>':
            self._unpack_bytes = lambda data: data
        else:
            self._unpack_bytes = lambda data: data[::-1]

    def done(self):
        """ Returns true when all bytes have been decoded """
        return self.tell() == self.length

    def read_range(self, start, end):
        """ Returns the raw byte string from the indicated address range """
        return self._contents[start:end]

    def peek(self, count):
        """ Returns the raw byte string for the next ``count`` bytes
        without moving the cursor.
        """
        start = self.tell()
        return self._contents[start:start+count]

    def read_uint8(self):
        """ Returns the next byte as an unsigned integer """
        return ord(self.read(1))

    def read_uint16(self):
        """ Returns the next two bytes as an unsigned integer """
        return self._unpack_short(self.read(2))[0]

    def read_uint32(self):
        """ Returns the next four bytes as an unsigned integer """
        return self._unpack_int(self.read(4))[0]

    def read_uint64(self):
        """ Returns the next eight bytes as an unsigned integer """
        return self._unpack_longlong(self.read(8))[0]

    def read_bytes(self, count):
        """ Returns the next ``count`` bytes as a byte string.

        For little endian decoders the byte string is returned reversed.
        """
        return self._unpack_bytes(self.read(count))
84+
85+
class BitPackedDecoder(object):
    """ Reads values that are not necessarily byte aligned.

    Byte aligned access is delegated to an internal big endian
    ``ByteDecoder``. When a read stops part way through a byte, that
    byte is kept in ``_next_byte`` and ``_bit_shift`` records how many
    of its bits have already been consumed.
    """

    #: The ByteDecoder used internally to read byte
    #: aligned values.
    _buffer = None

    #: Tracks how many bits have already been used
    #: from the current byte.
    _bit_shift = 0

    #: Holds the byte, if any, that hasn't had its bits
    #: fully used yet.
    _next_byte = None

    #: Maps bit shifts to low bit masks used for grabbing
    #: the first bits off of the next byte.
    _lo_masks = [0x00, 0x01, 0x03, 0x07, 0x0F, 0x1F, 0x3F, 0x7F, 0xFF]

    #: Maps bit shifts to high bit masks used for grabbing
    #: the remaining bits off of the previous byte.
    _hi_masks = [0xFF ^ mask for mask in _lo_masks]

    #: Maps bit shifts to high and low bit masks. Used for
    #: joining bytes when we are not byte aligned. Stored as a
    #: list because it is indexed by bit shift on every read.
    _bit_masks = list(zip(_lo_masks, _hi_masks))

    def __init__(self, contents):
        """ Wraps ``contents`` in a big endian ``ByteDecoder`` """
        self._buffer = ByteDecoder(contents, endian='BIG')

        # Partially expose the ByteDecoder interface
        self.length = self._buffer.length
        self.tell = self._buffer.tell
        self.peek = self._buffer.peek
        self.read_range = self._buffer.read_range

        # Reduce the number of lookups required to read
        self._read = self._buffer.read

    def done(self):
        """ Returns true when all bits in the buffer have been used """
        return self.tell() == self.length and self._bit_shift == 0

    def byte_align(self):
        """ Moves cursor to the beginning of the next byte by
        discarding any unused bits of the byte in progress.
        """
        self._next_byte = None
        self._bit_shift = 0

    def read_uint8(self):
        """ Returns the next 8 bits as an unsigned integer """
        data = self._buffer.read_uint8()

        if self._bit_shift != 0:
            # Join the unused high bits of the byte in progress with
            # the low bits of the byte just read; the byte just read
            # becomes the new byte in progress.
            lo_mask, hi_mask = self._bit_masks[self._bit_shift]
            hi_bits = self._next_byte & hi_mask
            lo_bits = data & lo_mask
            self._next_byte = data
            data = hi_bits | lo_bits

        return data

    def read_uint16(self):
        """ Returns the next 16 bits as an unsigned integer """
        data = self._buffer.read_uint16()

        if self._bit_shift != 0:
            lo_mask, hi_mask = self._bit_masks[self._bit_shift]
            hi_bits = (self._next_byte & hi_mask) << 8
            mi_bits = (data & 0xFF00) >> (8-self._bit_shift)
            lo_bits = (data & lo_mask)
            # The last byte read becomes the new byte in progress
            self._next_byte = data & 0xFF
            data = hi_bits | mi_bits | lo_bits

        return data

    def read_uint32(self):
        """ Returns the next 32 bits as an unsigned integer """
        data = self._buffer.read_uint32()

        if self._bit_shift != 0:
            lo_mask, hi_mask = self._bit_masks[self._bit_shift]
            hi_bits = (self._next_byte & hi_mask) << 24
            mi_bits = (data & 0xFFFFFF00) >> (8-self._bit_shift)
            lo_bits = (data & lo_mask)
            # The last byte read becomes the new byte in progress
            self._next_byte = data & 0xFF
            data = hi_bits | mi_bits | lo_bits

        return data

    def read_uint64(self):
        """ Returns the next 64 bits as an unsigned integer """
        data = self._buffer.read_uint64()

        if self._bit_shift != 0:
            lo_mask, hi_mask = self._bit_masks[self._bit_shift]
            hi_bits = (self._next_byte & hi_mask) << 56
            mi_bits = (data & 0xFFFFFFFFFFFFFF00) >> (8-self._bit_shift)
            lo_bits = (data & lo_mask)
            # The last byte read becomes the new byte in progress
            self._next_byte = data & 0xFF
            data = hi_bits | mi_bits | lo_bits

        return data

    def read_vint(self):
        """ Reads a signed integer of variable length.

        In the first byte bit 0 is the sign flag and bits 1-6 hold the
        lowest six value bits. The high bit of every byte signals that
        another byte follows, each contributing seven more value bits.
        """
        byte = self.read_uint8()
        negative = byte & 0x01
        result = (byte & 0x7F) >> 1
        bits = 6
        while byte & 0x80:
            byte = self.read_uint8()
            result |= (byte & 0x7F) << bits
            bits += 7
        return -result if negative else result

    def read_aligned_bytes(self, count):
        """ Skips to the beginning of the next byte and returns the next ``count`` bytes as a byte string """
        self.byte_align()
        return self._buffer.read_bytes(count)

    def read_bytes(self, count):
        """ Returns the next ``count*8`` bits as a byte string """
        data = self._buffer.read_bytes(count)

        if self._bit_shift != 0:
            joined = bytearray()
            prev_byte = self._next_byte
            lo_mask, hi_mask = self._bit_masks[self._bit_shift]
            for next_byte in struct.unpack("B"*count, data):
                # High bits come from the previous byte, low bits from
                # the byte that follows it.
                joined.append((prev_byte & hi_mask) | (next_byte & lo_mask))
                prev_byte = next_byte

            self._next_byte = prev_byte
            data = bytes(joined)

        return data

    def read_bits(self, count):
        """ Returns the next ``count`` bits as an unsigned integer """
        result = 0
        bits = count
        bit_shift = self._bit_shift

        # If we've got a byte in progress use it first
        if bit_shift != 0:
            bits_left = 8-bit_shift

            if bits_left < bits:
                # Use up the byte in progress and keep reading
                bits -= bits_left
                result = (self._next_byte >> bit_shift) << bits
            elif bits_left > bits:
                # The request fits inside the byte in progress
                self._bit_shift += bits
                return (self._next_byte >> bit_shift) & self._lo_masks[bits]
            else:
                # The request uses exactly the rest of the byte
                self._bit_shift = 0
                return self._next_byte >> bit_shift

        # Then grab any additional whole bytes as needed
        if bits >= 8:
            # Floor division: a partial byte is handled separately below
            whole_bytes = bits // 8

            if whole_bytes == 1:
                bits -= 8
                result |= self._buffer.read_uint8() << bits

            elif whole_bytes == 2:
                bits -= 16
                result |= self._buffer.read_uint16() << bits

            elif whole_bytes == 4:
                bits -= 32
                result |= self._buffer.read_uint32() << bits

            else:
                for byte in struct.unpack("B"*whole_bytes, self._read(whole_bytes)):
                    bits -= 8
                    result |= byte << bits

        # Grab any trailing bits from the next byte
        if bits != 0:
            self._next_byte = ord(self._read(1))
            result |= self._next_byte & self._lo_masks[bits]

        # Whatever is left over is the number of bits used from _next_byte
        self._bit_shift = bits
        return result

    def read_frames(self):
        """ Reads a frame count as an unsigned integer.

        The low two bits of the first byte give the number of extra
        bytes to read; its upper six bits are the most significant
        bits of the result.
        """
        byte = self.read_uint8()
        time, additional_bytes = byte >> 2, byte & 0x03
        if additional_bytes == 0:
            return time
        elif additional_bytes == 1:
            return time << 8 | self.read_uint8()
        elif additional_bytes == 2:
            return time << 16 | self.read_uint16()
        elif additional_bytes == 3:
            return time << 24 | self.read_uint16() << 8 | self.read_uint8()

    def read_struct(self, datatype=None):
        """ Reads a nested data structure. If the type is not specified the
        first byte is used as the type identifier.

        Raises ``TypeError`` when the type identifier is not recognized.
        """
        self.byte_align()  # I think this is true
        if datatype is None:
            datatype = self.read_uint8()

        if datatype == 0x00:  # array
            data = [self.read_struct() for i in range(self.read_vint())]

        elif datatype == 0x01:  # bitarray, weird alignment requirements
            bits = self.read_vint()
            data = self.read_bits(bits)

        elif datatype == 0x02:  # blob
            length = self.read_vint()
            data = self.read_bytes(length)

        elif datatype == 0x03:  # choice
            # The flag must be consumed but is not part of the result
            flag = self.read_vint()
            data = self.read_struct()

        elif datatype == 0x04:  # optional
            exists = self.read_uint8() != 0
            data = self.read_struct() if exists else None

        elif datatype == 0x05:  # Struct
            data = dict()
            entries = self.read_vint()
            for i in range(entries):
                key = self.read_vint()  # Must be read first
                data[key] = self.read_struct()

        elif datatype == 0x06:  # u8
            data = self.read_uint8()

        elif datatype == 0x07:  # u32
            # NOTE: returned as the raw 4 byte string, not as an integer
            data = self.read_bytes(4)

        elif datatype == 0x08:  # u64
            data = self.read_uint64()

        elif datatype == 0x09:  # vint
            data = self.read_vint()

        else:
            raise TypeError("Unknown Data Structure: '%s'" % datatype)

        return data

sc2reader/objects.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,16 @@ def hash(self):
7474

7575
class Attribute(object):
7676

77-
def __init__(self, data):
78-
#Unpack the data values and add a default name of unknown to be
79-
#overridden by known attributes; acts as a flag for exclusion
80-
self.header, self.id, self.player, self.value, self.name = tuple(data+["Unknown"])
77+
def __init__(self, header, attr_id, player, value):
78+
self.header = header
79+
self.id = attr_id
80+
self.player = player
8181

82-
if self.id in LOBBY_PROPERTIES:
82+
if self.id not in LOBBY_PROPERTIES:
83+
raise ValueError("Unknown attribute id: "+self.id)
84+
else:
8385
self.name, lookup = LOBBY_PROPERTIES[self.id]
84-
if lookup:
85-
if callable(lookup):
86-
self.value = lookup(self.value)
87-
else:
88-
self.value = lookup[self.value]
86+
self.value = lookup[value.strip("\x00 ")[::-1]]
8987

9088
def __repr__(self):
9189
return str(self)

0 commit comments

Comments
 (0)