[Zodb-checkins] SVN: ZODB/branches/3.8/src/ZEO/ Refactored cache verification to fix threading bugs during connection.

Jim Fulton jim at zope.com
Mon Jul 14 10:59:14 EDT 2008


Log message for revision 88353:
  Refactored cache verification to fix threading bugs during connection.
  
  Changed connections to work with unset (None) clients.  Messages
  aren't forwarded until the client is set.  This is to prevent sending
  spurious invalidation messages until a client is ready to receive them.
  

Changed:
  U   ZODB/branches/3.8/src/ZEO/ClientStorage.py
  A   ZODB/branches/3.8/src/ZEO/tests/invalidations_while_connecting.test
  U   ZODB/branches/3.8/src/ZEO/tests/testConnection.py
  U   ZODB/branches/3.8/src/ZEO/zrpc/client.py
  U   ZODB/branches/3.8/src/ZEO/zrpc/connection.py

-=-
Modified: ZODB/branches/3.8/src/ZEO/ClientStorage.py
===================================================================
--- ZODB/branches/3.8/src/ZEO/ClientStorage.py	2008-07-14 14:59:06 UTC (rev 88352)
+++ ZODB/branches/3.8/src/ZEO/ClientStorage.py	2008-07-14 14:59:11 UTC (rev 88353)
@@ -491,6 +491,7 @@
             # If we are upgrading from a read-only fallback connection,
             # we must close the old connection to prevent it from being
             # used while the cache is verified against the new connection.
+            self._connection.register_object(None) # Don't call me!
             self._connection.close()
             self._connection = None
             self._ready.clear()
@@ -558,54 +559,6 @@
         else:
             return '%s:%s' % (self._storage, self._server_addr)
 
-    def verify_cache(self, server):
-        """Internal routine called to verify the cache.
-
-        The return value (indicating which path we took) is used by
-        the test suite.
-        """
-
-        # If verify_cache() finishes the cache verification process,
-        # it should set self._server.  If it goes through full cache
-        # verification, then endVerify() should self._server.
-
-        last_inval_tid = self._cache.getLastTid()
-        if last_inval_tid is not None:
-            ltid = server.lastTransaction()
-            if ltid == last_inval_tid:
-                log2("No verification necessary (last_inval_tid up-to-date)")
-                self._server = server
-                self._ready.set()
-                return "no verification"
-
-            # log some hints about last transaction
-            log2("last inval tid: %r %s\n"
-                 % (last_inval_tid, tid2time(last_inval_tid)))
-            log2("last transaction: %r %s" %
-                 (ltid, ltid and tid2time(ltid)))
-
-            pair = server.getInvalidations(last_inval_tid)
-            if pair is not None:
-                log2("Recovering %d invalidations" % len(pair[1]))
-                self.invalidateTransaction(*pair)
-                self._server = server
-                self._ready.set()
-                return "quick verification"
-
-        log2("Verifying cache")
-        # setup tempfile to hold zeoVerify results
-        self._tfile = tempfile.TemporaryFile(suffix=".inv")
-        self._pickler = cPickle.Pickler(self._tfile, 1)
-        self._pickler.fast = 1 # Don't use the memo
-
-        # TODO:  should batch these operations for efficiency; would need
-        # to acquire lock ...
-        for oid, tid, version in self._cache.contents():
-            server.verify(oid, version, tid)
-        self._pending_server = server
-        server.endZeoVerify()
-        return "full verification"
-
     ### Is there a race condition between notifyConnected and
     ### notifyDisconnected? In Particular, what if we get
     ### notifyDisconnected in the middle of notifyConnected?
@@ -1162,7 +1115,7 @@
             return
 
         for oid, version, data in self._tbuf:
-            self._cache.invalidate(oid, version, tid)
+            self._cache.invalidate(oid, version, tid, False)
             # If data is None, we just invalidate.
             if data is not None:
                 s = self._seriald[oid]
@@ -1224,8 +1177,6 @@
         """Storage API: return a sequence of versions in the storage."""
         return self._server.versions(max)
 
-    # Below are methods invoked by the StorageServer
-
     def serialnos(self, args):
         """Server callback to pass a list of changed (oid, serial) pairs."""
         self._serials.extend(args)
@@ -1234,6 +1185,57 @@
         """Server callback to update the info dictionary."""
         self._info.update(dict)
 
+    def verify_cache(self, server):
+        """Internal routine called to verify the cache.
+
+        The return value (indicating which path we took) is used by
+        the test suite.
+        """
+
+        self._pending_server = server
+
+        # setup tempfile to hold zeoVerify results and interim
+        # invalidation results
+        self._tfile = tempfile.TemporaryFile(suffix=".inv")
+        self._pickler = cPickle.Pickler(self._tfile, 1)
+        self._pickler.fast = 1 # Don't use the memo
+
+        # allow incoming invalidations:
+        self._connection.register_object(self)
+
+        # If verify_cache() finishes the cache verification process,
+        # it should set self._server.  If it goes through full cache
+        # verification, then endVerify() should set self._server.
+
+        last_inval_tid = self._cache.getLastTid()
+        if last_inval_tid is not None:
+            ltid = server.lastTransaction()
+            if ltid == last_inval_tid:
+                log2("No verification necessary (last_inval_tid up-to-date)")
+                self.finish_verification()
+                return "no verification"
+
+            # log some hints about last transaction
+            log2("last inval tid: %r %s\n"
+                 % (last_inval_tid, tid2time(last_inval_tid)))
+            log2("last transaction: %r %s" %
+                 (ltid, ltid and tid2time(ltid)))
+
+            pair = server.getInvalidations(last_inval_tid)
+            if pair is not None:
+                log2("Recovering %d invalidations" % len(pair[1]))
+                self.finish_verification(pair)
+                return "quick verification"
+
+        log2("Verifying cache")
+
+        # TODO:  should batch these operations for efficiency; would need
+        # to acquire lock ...
+        for oid, tid, version in self._cache.contents():
+            server.verify(oid, version, tid)
+        server.endZeoVerify()
+        return "full verification"
+
     def invalidateVerify(self, args):
         """Server callback to invalidate an (oid, version) pair.
 
@@ -1245,68 +1247,93 @@
             # This should never happen.  TODO:  assert it doesn't, or log
             # if it does.
             return
-        self._pickler.dump(args)
+        oid, version = args
+        self._pickler.dump((oid, version, None))
 
-    def _process_invalidations(self, invs):
-        # Invalidations are sent by the ZEO server as a sequence of
-        # oid, version pairs.  The DB's invalidate() method expects a
-        # dictionary of oids.
+    def endVerify(self):
+        """Server callback to signal end of cache validation."""
 
+        log2("endVerify finishing")
+        self.finish_verification()
+        log2("endVerify finished")
+
+    def finish_verification(self, catch_up=None):
         self._lock.acquire()
         try:
-            # versions maps version names to dictionary of invalidations
-            versions = {}
-            for oid, version, tid in invs:
-                if oid == self._load_oid:
-                    self._load_status = 0
-                self._cache.invalidate(oid, version, tid)
-                oids = versions.get((version, tid))
-                if not oids:
-                    versions[(version, tid)] = [oid]
-                else:
-                    oids.append(oid)
+            if catch_up:
+                # process catch-up invalidations
+                tid, invalidations = catch_up
+                self._process_invalidations(
+                    (oid, version, tid)
+                    for oid, version in invalidations
+                    )
+            
+            if self._pickler is None:
+                return
+            # write end-of-data marker
+            self._pickler.dump((None, None, None))
+            self._pickler = None
+            self._tfile.seek(0)
+            unpickler = cPickle.Unpickler(self._tfile)
+            min_tid = self._cache.getLastTid()
+            def InvalidationLogIterator():
+                while 1:
+                    oid, version, tid = unpickler.load()
+                    if oid is None:
+                        break
+                    if ((tid is None)
+                        or (min_tid is None)
+                        or (tid > min_tid)
+                        ):
+                        yield oid, version, tid
 
-            if self._db is not None:
-                for (version, tid), d in versions.items():
-                    self._db.invalidate(tid, d, version=version)
+            self._process_invalidations(InvalidationLogIterator())
+            self._tfile.close()
+            self._tfile = None
         finally:
             self._lock.release()
 
-    def endVerify(self):
-        """Server callback to signal end of cache validation."""
-        if self._pickler is None:
-            return
-        # write end-of-data marker
-        self._pickler.dump((None, None))
-        self._pickler = None
-        self._tfile.seek(0)
-        f = self._tfile
-        self._tfile = None
-        self._process_invalidations(InvalidationLogIterator(f))
-        f.close()
-
-        log2("endVerify finishing")
         self._server = self._pending_server
         self._ready.set()
-        self._pending_conn = None
-        log2("endVerify finished")
+        self._pending_server = None
 
+
     def invalidateTransaction(self, tid, args):
-        """Invalidate objects modified by tid."""
+        """Server callback: Invalidate objects modified by tid."""
         self._lock.acquire()
         try:
-            self._cache.setLastTid(tid)
+            if self._pickler is not None:
+                log2("Transactional invalidation during cache verification",
+                     level=BLATHER)
+                for oid, version in args:
+                    self._pickler.dump((oid, version, tid))
+                return
+            self._process_invalidations([(oid, version, tid)
+                                         for oid, version in args])
         finally:
             self._lock.release()
-        if self._pickler is not None:
-            log2("Transactional invalidation during cache verification",
-                 level=BLATHER)
-            for t in args:
-                self._pickler.dump(t)
-            return
-        self._process_invalidations([(oid, version, tid)
-                                     for oid, version in args])
 
+    def _process_invalidations(self, invs):
+        # Invalidations are sent by the ZEO server as a sequence of
+        # oid, version, tid triples.  The DB's invalidate() method expects a
+        # dictionary of oids.
+
+        # versions maps version names to dictionary of invalidations
+        versions = {}
+        for oid, version, tid in invs:
+            if oid == self._load_oid:
+                self._load_status = 0
+            self._cache.invalidate(oid, version, tid)
+            oids = versions.get((version, tid))
+            if not oids:
+                versions[(version, tid)] = [oid]
+            else:
+                oids.append(oid)
+
+        if self._db is not None:
+            for (version, tid), d in versions.items():
+                self._db.invalidate(tid, d, version=version)
+
     # The following are for compatibility with protocol version 2.0.0
 
     def invalidateTrans(self, args):
@@ -1315,11 +1342,3 @@
     invalidate = invalidateVerify
     end = endVerify
     Invalidate = invalidateTrans
-
-def InvalidationLogIterator(fileobj):
-    unpickler = cPickle.Unpickler(fileobj)
-    while 1:
-        oid, version = unpickler.load()
-        if oid is None:
-            break
-        yield oid, version, None

Copied: ZODB/branches/3.8/src/ZEO/tests/invalidations_while_connecting.test (from rev 88351, ZODB/branches/jim-3.8-connection/src/ZEO/tests/invalidations_while_connecting.test)
===================================================================
--- ZODB/branches/3.8/src/ZEO/tests/invalidations_while_connecting.test	                        (rev 0)
+++ ZODB/branches/3.8/src/ZEO/tests/invalidations_while_connecting.test	2008-07-14 14:59:11 UTC (rev 88353)
@@ -0,0 +1,102 @@
+Invalidations while connecting
+==============================
+
+As soon as a client registers with a server, it will receive
+invalidations from the server.  The client must be careful to queue
+these invalidations until it is ready to deal with them.  At the time
+of the writing of this test, clients weren't careful enough about
+queuing invalidations.  This led to cache corruption in the form of
+both low-level file corruption as well as out-of-date records marked
+as current.
+
+This test tries to provoke this bug by:
+
+- starting a server
+
+    >>> import ZEO.tests.testZEO, ZEO.tests.forker
+    >>> addr = 'localhost', ZEO.tests.testZEO.get_port()
+    >>> zconf = ZEO.tests.forker.ZEOConfig(addr)
+    >>> sconf = '<filestorage 1>\npath Data.fs\n</filestorage>\n'
+    >>> _, adminaddr, pid, conf_path = ZEO.tests.forker.start_zeo_server(
+    ...     sconf, zconf, addr[1])
+
+- opening a client to the server that writes some objects, filling
+  its cache at the same time,
+
+    >>> import ZEO.ClientStorage, ZODB.tests.MinPO, transaction
+    >>> db = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr, client='x'))
+    >>> conn = db.open()
+    >>> nobs = 1000
+    >>> for i in range(nobs):
+    ...     conn.root()[i] = ZODB.tests.MinPO.MinPO(0)
+    >>> transaction.commit()
+
+- disconnecting the first client (closing it with a persistent cache),
+
+    >>> db.close()
+
+- starting a second client that writes objects more or less
+  constantly,
+
+    >>> import random, threading
+    >>> stop = False
+    >>> db2 = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr))
+    >>> tm = transaction.TransactionManager()
+    >>> conn2 = db2.open(transaction_manager=tm)
+    >>> random = random.Random(0)
+    >>> lock = threading.Lock()
+    >>> def run():
+    ...     while 1:
+    ...         i = random.randint(0, nobs-1)
+    ...         if stop:
+    ...             return
+    ...         lock.acquire()
+    ...         try:
+    ...             conn2.root()[i].value += 1
+    ...             tm.commit()
+    ...         finally:
+    ...             lock.release()
+    ...             time.sleep(0)
+    >>> thread = threading.Thread(target=run)
+    >>> thread.start()
+
+- restarting the first client, and 
+- testing for cache validity.
+
+    >>> import zope.testing.loggingsupport, logging
+    >>> handler = zope.testing.loggingsupport.InstalledHandler(
+    ...    'ZEO', level=logging.ERROR)
+
+    >>> import time
+    >>> for c in range(10):
+    ...    time.sleep(.1)
+    ...    db = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr, client='x'))
+    ...    _ = lock.acquire()
+    ...    try:
+    ...      time.sleep(.1)
+    ...      assert (db._storage.lastTransaction()
+    ...              == db._storage._server.lastTransaction()), (
+    ...                  db._storage.lastTransaction(),
+    ...                  db._storage._server.lastTransaction())
+    ...      conn = db.open()
+    ...      for i in range(1000):
+    ...        if conn.root()[i].value != conn2.root()[i].value:
+    ...            print 'bad', c, i, conn.root()[i].value,
+    ...            print  conn2.root()[i].value
+    ...    finally:
+    ...      _ = lock.release()
+    ...    db.close()
+
+    >>> stop = True
+    >>> thread.join(10)
+    >>> thread.isAlive()
+    False
+
+    >>> for record in handler.records:
+    ...     print record.name, record.levelname
+    ...     print handler.format(record)
+
+    >>> handler.uninstall()
+
+    >>> db.close()
+    >>> db2.close()

Modified: ZODB/branches/3.8/src/ZEO/tests/testConnection.py
===================================================================
--- ZODB/branches/3.8/src/ZEO/tests/testConnection.py	2008-07-14 14:59:06 UTC (rev 88352)
+++ ZODB/branches/3.8/src/ZEO/tests/testConnection.py	2008-07-14 14:59:11 UTC (rev 88353)
@@ -21,8 +21,8 @@
 import unittest
 # Import the actual test class
 from ZEO.tests import ConnectionTests, InvalidationTests
+from zope.testing import doctest, setupstack
 
-
 class FileStorageConfig:
     def getConfig(self, path, create, read_only):
         return """\
@@ -135,6 +135,10 @@
     for klass in test_classes:
         sub = unittest.makeSuite(klass, 'check')
         suite.addTest(sub)
+    suite.addTest(doctest.DocFileSuite(
+        'invalidations_while_connecting.test',
+        setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown,
+        ))
     return suite
 
 

Modified: ZODB/branches/3.8/src/ZEO/zrpc/client.py
===================================================================
--- ZODB/branches/3.8/src/ZEO/zrpc/client.py	2008-07-14 14:59:06 UTC (rev 88352)
+++ ZODB/branches/3.8/src/ZEO/zrpc/client.py	2008-07-14 14:59:11 UTC (rev 88353)
@@ -447,8 +447,7 @@
         Call the client's testConnection(), giving the client a chance
         to do app-level check of the connection.
         """
-        self.conn = ManagedClientConnection(self.sock, self.addr,
-                                            self.client, self.mgr)
+        self.conn = ManagedClientConnection(self.sock, self.addr, self.mgr)
         self.sock = None # The socket is now owned by the connection
         try:
             self.preferred = self.client.testConnection(self.conn)

Modified: ZODB/branches/3.8/src/ZEO/zrpc/connection.py
===================================================================
--- ZODB/branches/3.8/src/ZEO/zrpc/connection.py	2008-07-14 14:59:06 UTC (rev 88352)
+++ ZODB/branches/3.8/src/ZEO/zrpc/connection.py	2008-07-14 14:59:11 UTC (rev 88353)
@@ -555,14 +555,23 @@
             self.replies_cond.release()
 
     def handle_request(self, msgid, flags, name, args):
-        if not self.check_method(name):
-            msg = "Invalid method name: %s on %s" % (name, repr(self.obj))
+        obj = self.obj
+        
+        if name.startswith('_') or not hasattr(obj, name):
+            if obj is None:
+                if __debug__:
+                    self.log("no object calling %s%s"
+                             % (name, short_repr(args)),
+                             level=logging.DEBUG)
+                return
+                
+            msg = "Invalid method name: %s on %s" % (name, repr(obj))
             raise ZRPCError(msg)
         if __debug__:
             self.log("calling %s%s" % (name, short_repr(args)),
                      level=logging.DEBUG)
 
-        meth = getattr(self.obj, name)
+        meth = getattr(obj, name)
         try:
             self.waiting_for_reply = True
             try:
@@ -601,12 +610,6 @@
                  level=logging.ERROR, exc_info=True)
         self.close()
 
-    def check_method(self, name):
-        # TODO:  This is hardly "secure".
-        if name.startswith('_'):
-            return None
-        return hasattr(self.obj, name)
-
     def send_reply(self, msgid, ret):
         # encode() can pass on a wide variety of exceptions from cPickle.
         # While a bare `except` is generally poor practice, in this case
@@ -897,7 +900,7 @@
     __super_close = Connection.close
     base_message_output = Connection.message_output
 
-    def __init__(self, sock, addr, obj, mgr):
+    def __init__(self, sock, addr, mgr):
         self.mgr = mgr
 
         # We can't use the base smac's message_output directly because the
@@ -914,7 +917,7 @@
         self.queue_output = True
         self.queued_messages = []
 
-        self.__super_init(sock, addr, obj, tag='C', map=client_map)
+        self.__super_init(sock, addr, None, tag='C', map=client_map)
         self.thr_async = True
         self.trigger = client_trigger
         client_trigger.pull_trigger()



More information about the Zodb-checkins mailing list