[Zodb-checkins] CVS: ZODB3/ZODB - Connection.py:1.76.4.4 DB.py:1.43.8.2

Jeremy Hylton jeremy at zope.com
Wed Apr 30 17:37:06 EDT 2003


Update of /cvs-repository/ZODB3/ZODB
In directory cvs.zope.org:/tmp/cvs-serv784/ZODB

Modified Files:
      Tag: ZODB3-3_1-branch
	Connection.py DB.py 
Log Message:
Backport atomic invalidations code.


=== ZODB3/ZODB/Connection.py 1.76.4.3 => 1.76.4.4 ===
--- ZODB3/ZODB/Connection.py:1.76.4.3	Tue Jan 14 11:16:17 2003
+++ ZODB3/ZODB/Connection.py	Wed Apr 30 16:37:05 2003
@@ -26,6 +26,7 @@
 from cPickle import Unpickler, Pickler
 from cStringIO import StringIO
 import sys
+import threading
 from time import time
 from types import StringType, ClassType
 
@@ -79,14 +80,28 @@
             # XXX Why do we want version caches to behave this way?
 
             self._cache.cache_drain_resistance = 100
-        self._incrgc=self.cacheGC=cache.incrgc
-        self._invalidated=d={}
-        self._invalid=d.has_key
-        self._committed=[]
+        self._incrgc = self.cacheGC = cache.incrgc
+        self._committed = []
         self._code_timestamp = global_code_timestamp
         self._load_count = 0   # Number of objects unghosted
         self._store_count = 0  # Number of objects stored
 
+        # _invalidated queues invalidation messages delivered from the DB.
+        # _inv_lock prevents one thread from modifying the set while
+        # another is processing invalidations.  All the invalidations
+        # from a single transaction should be applied atomically, so
+        # the lock must be held when reading _invalidated.
+
+        # XXX Unfortunately, we have to hold the lock to read
+        # _invalidated.  Normally, _invalidated is written by calling
+        # dict.update, which executes atomically by virtue of the
+        # GIL.  But some storages might generate oids whose hash or
+        # compare invokes Python code.  In that case, the GIL can't
+        # save us.
+        self._inv_lock = threading.Lock()
+        self._invalidated = d = {}
+        self._invalid = d.has_key
+
     def _cache_items(self):
         # find all items on the lru list
         items = self._cache.lru_items()
@@ -141,7 +156,7 @@
             not args and not hasattr(klass,'__getinitargs__')):
             object=klass.__basicnew__()
         else:
-            object=apply(klass,args)
+            object = klass(*args)
             if klass is not ExtensionKlass:
                 object.__dict__.clear()
 
@@ -212,7 +227,7 @@
             # New code is in place.  Start a new cache.
             self._resetCache()
         else:
-            self._cache.invalidate(self._invalidated)
+            self._flush_invalidations()
         self._opened=time()
 
         return self
@@ -233,7 +248,7 @@
         This just deactivates the thing.
         """
         if object is self:
-            self._cache.invalidate(self._invalidated)
+            self._flush_invalidations()
         else:
             assert object._p_oid is not None
             self._cache.invalidate(object._p_oid)
@@ -254,7 +269,6 @@
 
     def close(self):
         self._incrgc() # This is a good time to do some GC
-        db=self._db
 
         # Call the close callbacks.
         if self.__onCloseCallbacks is not None:
@@ -265,10 +279,10 @@
                     LOG('ZODB',ERROR, 'Close callback failed for %s' % f,
                         error=sys.exc_info())
             self.__onCloseCallbacks = None
-        self._db=self._storage=self._tmp=self.new_oid=self._opened=None
-        self._debug_info=()
+        self._storage = self._tmp = self.new_oid = self._opened = None
+        self._debug_info = ()
         # Return the connection to the pool.
-        db._closeConnection(self)
+        self._db._closeConnection(self)
 
     __onCommitActions = None
 
@@ -298,7 +312,7 @@
         elif object._p_changed:
             if invalid(oid) and not hasattr(object, '_p_resolveConflict'):
                 raise ConflictError(object=object)
-            self._invalidating.append(oid)
+            self._modified.append(oid)
 
         else:
             # Nothing to do
@@ -360,7 +374,7 @@
                 #XXX We should never get here
                 if invalid(oid) and not hasattr(object, '_p_resolveConflict'):
                     raise ConflictError(object=object)
-                self._invalidating.append(oid)
+                self._modified.append(oid)
 
             klass = object.__class__
 
@@ -424,9 +438,9 @@
         oids=src._index.keys()
 
         # Copy invalidating and creating info from temporary storage:
-        invalidating=self._invalidating
-        invalidating[len(invalidating):]=oids
-        creating=self._creating
+        modified = self._modified
+        modified[len(modified):] = oids
+        creating = self._creating
         creating[len(creating):]=src._creating
 
         for oid in oids:
@@ -467,15 +481,31 @@
 
     def getVersion(self): return self._version
 
-    def invalidate(self, oid):
-        """Invalidate a particular oid
+    def isReadOnly(self):
+        return self._storage.isReadOnly()
+
+    def invalidate(self, oids):
+        """Invalidate a set of oids.
 
         This marks the oid as invalid, but doesn't actually invalidate
         it.  The object data will be actually invalidated at certain
         transaction boundaries.
         """
-        assert oid is not None
-        self._invalidated[oid] = 1
+        self._inv_lock.acquire()
+        try:
+            self._invalidated.update(oids)
+        finally:
+            self._inv_lock.release()
+
+    def _flush_invalidations(self):
+        self._inv_lock.acquire()
+        try:
+            self._cache.invalidate(self._invalidated)
+            self._invalidated.clear()
+        finally:
+            self._inv_lock.release()
+        # Now is a good time to collect some garbage
+        self._cache.incrgc()
 
     def modifiedInVersion(self, oid):
         try: return self._db.modifiedInVersion(oid)
@@ -496,8 +526,8 @@
     def root(self):
         return self['\0\0\0\0\0\0\0\0']
 
-    def setstate(self, object):
-        oid = object._p_oid
+    def setstate(self, obj):
+        oid = obj._p_oid
 
         if self._storage is None:
             msg = ("Shouldn't load state for %s "
@@ -506,52 +536,20 @@
             raise RuntimeError(msg)
 
         try:
+            # Avoid reading data from a transaction that committed
+            # after the current transaction started, as that might
+            # lead to mixing of cached data from earlier transactions
+            # and new inconsistent data.
+            #
+            # Wait for check until after data is loaded from storage
+            # to avoid time-of-check to time-of-use race.
             p, serial = self._storage.load(oid, self._version)
             self._load_count = self._load_count + 1
-
-            # XXX this is quite conservative!
-            # We need, however, to avoid reading data from a transaction
-            # that committed after the current "session" started, as
-            # that might lead to mixing of cached data from earlier
-            # transactions and new inconsistent data.
-            #
-            # Note that we (carefully) wait until after we call the
-            # storage to make sure that we don't miss an invaildation
-            # notifications between the time we check and the time we
-            # read.
-            if self._invalid(oid):
-                if not hasattr(object.__class__, '_p_independent'):
-                    get_transaction().register(self)
-                    raise ReadConflictError(object=object)
-                invalid = 1
-            else:
-                invalid = 0
-
-            file = StringIO(p)
-            unpickler = Unpickler(file)
-            unpickler.persistent_load = self._persistent_load
-            unpickler.load()
-            state = unpickler.load()
-
-            if hasattr(object, '__setstate__'):
-                object.__setstate__(state)
-            else:
-                d = object.__dict__
-                for k, v in state.items():
-                    d[k] = v
-
-            object._p_serial = serial
-
+            invalid = self._is_invalidated(obj)
+            self._set_ghost_state(obj, p)
+            obj._p_serial = serial
             if invalid:
-                if object._p_independent():
-                    try:
-                        del self._invalidated[oid]
-                    except KeyError:
-                        pass
-                else:
-                    get_transaction().register(self)
-                    raise ConflictError(object=object)
-
+                self._handle_independent(obj)
         except ConflictError:
             raise
         except:
@@ -559,6 +557,56 @@
                 error=sys.exc_info())
             raise
 
+    def _is_invalidated(self, obj):
+        # Helper method for setstate() covers three cases:
+        # returns false if obj is valid
+        # returns true if obj was invalidated, but may be independent
+        # otherwise, raises ReadConflictError for invalidated objects
+        self._inv_lock.acquire()
+        try:
+            if not self._invalidated.has_key(obj._p_oid):
+                return 0
+            if getattr(obj, "_p_independent", None) is not None:
+                # Defer _p_independent() call until state is loaded.
+                return 1
+            # Join the transaction, as the pre-patch code did, so that
+            # aborting it flushes _invalidated; otherwise every retry
+            # would hit the same ReadConflictError again.
+            get_transaction().register(self)
+            raise ReadConflictError(object=obj)
+        finally:
+            self._inv_lock.release()
+
+    def _set_ghost_state(self, obj, p):
+        unpickler = Unpickler(StringIO(p))
+        unpickler.persistent_load = self._persistent_load
+        unpickler.load()  # first pickle is not the state; discard it
+        state = unpickler.load()
+        # Objects without __setstate__ get the state merged directly into
+        # their __dict__ (they need not have an update() method).
+        setstate = getattr(obj, "__setstate__", None)
+        if setstate is None:
+            obj.__dict__.update(state)
+        else:
+            setstate(state)
+
+    def _handle_independent(self, obj):
+        # Helper method for setstate() handles possibly independent objects
+        # Call _p_independent(), if it returns True, setstate() wins.
+        # Otherwise, join the transaction and raise ReadConflictError.
+
+        if not obj._p_independent():
+            get_transaction().register(self)
+            raise ReadConflictError(object=obj)
+        self._inv_lock.acquire()
+        try:
+            try:
+                del self._invalidated[obj._p_oid]
+            except KeyError:
+                pass
+        finally:
+            self._inv_lock.release()
+
     def oldstate(self, object, serial):
         oid=object._p_oid
         p = self._storage.loadSerial(oid, serial)
@@ -587,7 +635,7 @@
                     % getattr(object,'__name__','(?)'))
                 return
 
-            copy=apply(klass,args)
+            copy = klass(*args)
             object.__dict__.clear()
             object.__dict__.update(copy.__dict__)
 
@@ -603,14 +651,13 @@
         if self.__onCommitActions is not None:
             del self.__onCommitActions
         self._storage.tpc_abort(transaction)
-        self._cache.invalidate(self._invalidated)
-        self._cache.invalidate(self._invalidating)
+        self._cache.invalidate(self._modified)
+        self._flush_invalidations()
         self._invalidate_creating()
 
     def tpc_begin(self, transaction, sub=None):
-        self._invalidating = []
+        self._modified = []
         self._creating = []
-
         if sub:
             # Sub-transaction!
             if self._tmp is None:
@@ -675,10 +722,10 @@
 
     def tpc_finish(self, transaction):
         # It's important that the storage call the function we pass
-        # (self._invalidate_invalidating) while it still has it's
-        # lock.  We don't want another thread to be able to read any
-        # updated data until we've had a chance to send an
-        # invalidation message to all of the other connections!
+        # while it still has its lock.  We don't want another thread
+        # to be able to read any updated data until we've had a chance
+        # to send an invalidation message to all of the other
+        # connections!
 
         if self._tmp is not None:
             # Commiting a subtransaction!
@@ -687,25 +734,20 @@
             self._storage._creating[:0]=self._creating
             del self._creating[:]
         else:
-            self._db.begin_invalidation()
-            self._storage.tpc_finish(transaction,
-                                     self._invalidate_invalidating)
+            def callback():
+                d = {}
+                for oid in self._modified:
+                    d[oid] = 1 
+                self._db.invalidate(d, self)
+            self._storage.tpc_finish(transaction, callback)
 
-        self._cache.invalidate(self._invalidated)
-        self._incrgc() # This is a good time to do some GC
-
-    def _invalidate_invalidating(self):
-        for oid in self._invalidating:
-            assert oid is not None
-            self._db.invalidate(oid, self)
-        self._db.finish_invalidation()
+        self._flush_invalidations()
 
     def sync(self):
         get_transaction().abort()
         sync=getattr(self._storage, 'sync', 0)
         if sync != 0: sync()
-        self._cache.invalidate(self._invalidated)
-        self._incrgc() # This is a good time to do some GC
+        self._flush_invalidations()
 
     def getDebugInfo(self):
         return self._debug_info


=== ZODB3/ZODB/DB.py 1.43.8.1 => 1.43.8.2 ===
--- ZODB3/ZODB/DB.py:1.43.8.1	Tue Nov 12 15:18:09 2002
+++ ZODB3/ZODB/DB.py	Wed Apr 30 16:37:05 2003
@@ -19,13 +19,19 @@
 import cPickle, cStringIO, sys, POSException, UndoLogCompatible
 from Connection import Connection
 from bpthread import allocate_lock
-from Transaction import Transaction
+from Transaction import Transaction, get_transaction
 from referencesf import referencesf
 from time import time, ctime
 from zLOG import LOG, ERROR
 
 from types import StringType
 
+def list2dict(L):
+    d = {}
+    for elt in L:
+        d[elt] = 1
+    return d
+
 class DB(UndoLogCompatible.UndoLogCompatible):
     """The Object Database
 
@@ -153,8 +159,10 @@
                 self._temps=t
         finally: self._r()
 
-    def abortVersion(self, version):
-        AbortVersion(self, version)
+    def abortVersion(self, version, transaction=None):
+        if transaction is None:
+            transaction = get_transaction()
+        transaction.register(AbortVersion(self, version))
 
     def cacheDetail(self):
         """Return information on objects in the various caches
@@ -245,10 +253,16 @@
         m.sort()
         return m
 
-    def close(self): self._storage.close()
+    def close(self):
+        self._storage.close()
+        for x, allocated in self._pools[1]:
+            for c in allocated:
+                c._breakcr()
 
-    def commitVersion(self, source, destination=''):
-        CommitVersion(self, source, destination)
+    def commitVersion(self, source, destination='', transaction=None):
+        if transaction is None:
+            transaction = get_transaction()
+        transaction.register(CommitVersion(self, source, destination))
 
     def exportFile(self, oid, file=None):
         raise 'Not yet implemented'
@@ -277,17 +291,7 @@
     def importFile(self, file):
         raise 'Not yet implemented'
 
-    def begin_invalidation(self):
-        # Must be called before first call to invalidate and before
-        # the storage lock is held.
-        self._a()
-
-    def finish_invalidation(self):
-        # Must be called after begin_invalidation() and after final
-        # invalidate() call.
-        self._r()
-
-    def invalidate(self, oid, connection=None, version='',
+    def invalidate(self, oids, connection=None, version='',
                    rc=sys.getrefcount):
         """Invalidate references to a given oid.
 
@@ -299,9 +303,11 @@
         if connection is not None:
             version=connection._version
         # Update modified in version cache
-        h=hash(oid)%131
-        o=self._miv_cache.get(h, None)
-        if o is not None and o[0]==oid: del self._miv_cache[h]
+        # XXX must make this work with list or dict to backport to 2.6
+        for oid in oids.keys():
+            h=hash(oid)%131
+            o=self._miv_cache.get(h, None)
+            if o is not None and o[0]==oid: del self._miv_cache[h]
 
         # Notify connections
         for pool, allocated in self._pools[1]:
@@ -310,7 +316,7 @@
                     (not version or cc._version==version)):
                     if rc(cc) <= 3:
                         cc.close()
-                    cc.invalidate(oid)
+                    cc.invalidate(oids)
 
         temps=self._temps
         if temps:
@@ -319,7 +325,7 @@
                 if rc(cc) > 3:
                     if (cc is not connection and
                         (not version or cc._version==version)):
-                        cc.invalidate(oid)
+                        cc.invalidate(oids)
                     t.append(cc)
                 else: cc.close()
             self._temps=t
@@ -352,9 +358,6 @@
         Note that the connection pool is managed as a stack, to increate the
         likelihood that the connection's stack will include useful objects.
         """
-        if type(version) is not StringType:
-            raise POSException.Unimplemented, 'temporary versions'
-
         self._a()
         try:
 
@@ -544,7 +547,7 @@
 
     def cacheStatistics(self): return () # :(
 
-    def undo(self, id):
+    def undo(self, id, transaction=None):
         storage=self._storage
         try: supportsTransactionalUndo = storage.supportsTransactionalUndo
         except AttributeError:
@@ -554,11 +557,15 @@
 
         if supportsTransactionalUndo:
             # new style undo
-            TransactionalUndo(self, id)
+            if transaction is None:
+                transaction = get_transaction()
+            transaction.register(TransactionalUndo(self, id))
         else:
             # fall back to old undo
+            d = {}
             for oid in storage.undo(id):
-                self.invalidate(oid)
+                d[oid] = 1
+            self.invalidate(d)
 
     def versionEmpty(self, version):
         return self._storage.versionEmpty(version)
@@ -578,7 +585,6 @@
         self.tpc_vote=s.tpc_vote
         self.tpc_finish=s.tpc_finish
         self._sortKey=s.sortKey
-        get_transaction().register(self)
 
     def sortKey(self):
         return "%s:%s" % (self._sortKey(), id(self))
@@ -586,14 +592,14 @@
     def abort(self, reallyme, t): pass
 
     def commit(self, reallyme, t):
-        db=self._db
         dest=self._dest
-        oids=db._storage.commitVersion(self._version, dest, t)
-        for oid in oids: db.invalidate(oid, version=dest)
+        oids = self._db._storage.commitVersion(self._version, dest, t)
+        oids = list2dict(oids)
+        self._db.invalidate(oids, version=dest)
         if dest:
             # the code above just invalidated the dest version.
             # now we need to invalidate the source!
-            for oid in oids: db.invalidate(oid, version=self._version)
+            self._db.invalidate(oids, version=self._version)
 
 class AbortVersion(CommitVersion):
     """An object that will see to version abortion
@@ -602,11 +608,9 @@
     """
 
     def commit(self, reallyme, t):
-        db=self._db
         version=self._version
-        oids = db._storage.abortVersion(version, t)
-        for oid in oids:
-            db.invalidate(oid, version=version)
+        oids = self._db._storage.abortVersion(version, t)
+        self._db.invalidate(list2dict(oids), version=version)
 
 
 class TransactionalUndo(CommitVersion):
@@ -615,12 +619,10 @@
     in cooperation with a transaction manager.
     """
 
-    # I'm lazy. I'm reusing __init__ and abort and reusing the
-    # version attr for the transavtion id. There's such a strong
-    # similarity of rythm, that I think it's justified.
+    # I (Jim) am lazy.  I'm reusing __init__ and abort and reusing the
+    # version attr for the transaction id.  There's such a strong
+    # similarity of rhythm that I think it's justified.
 
     def commit(self, reallyme, t):
-        db=self._db
-        oids=db._storage.transactionalUndo(self._version, t)
-        for oid in oids:
-            db.invalidate(oid)
+        oids = self._db._storage.transactionalUndo(self._version, t)
+        self._db.invalidate(list2dict(oids))




More information about the Zodb-checkins mailing list