@@ -100,20 +100,26 @@ def set_state(self, state):
100100 self .states .remove (* others )
101101 if state not in already_set :
102102 self .states .add (state )
103+ self .state_cache = None # invalidate cache
103104
104105 def unset_state (self , state_type ):
105106 """Unset state of type so no state of that type is any longer set."""
106107 self .states .remove (* self .states .filter (type = state_type ))
108+ self .state_cache = None # invalidate cache
107109
108110 def get_state (self , state_type = None ):
109- """Get state of type, or default state for document type if not specified."""
111+ """Get state of type, or default state for document type if
112+ not specified. Uses a local cache to speed multiple state
113+ reads up."""
110114 if state_type == None :
111115 state_type = self .type_id
112116
113- try :
114- return self .states .get (type = state_type )
115- except State .DoesNotExist :
116- return None
117+ if not hasattr (self , "state_cache" ) or self .state_cache == None :
118+ self .state_cache = {}
119+ for s in self .states .all ().select_related ():
120+ self .state_cache [s .type_id ] = s
121+
122+ return self .state_cache .get (state_type , None )
117123
118124 def get_state_slug (self , state_type = None ):
119125 """Get state of type, or default if not specified, returning
0 commit comments