[Zope-Checkins] CVS: ZODB3/ZEO - ClientStorage.py:1.73.2.25 ClientCache.py:1.38.2.3

Jeremy Hylton jeremy@zope.com
Thu, 12 Jun 2003 16:31:21 -0400


Update of /cvs-repository/ZODB3/ZEO
In directory cvs.zope.org:/tmp/cvs-serv25917/ZEO

Modified Files:
      Tag: ZODB3-3_1-branch
	ClientStorage.py ClientCache.py 
Log Message:
Partial replacement fix: Use a lock to serialize load() and tpc_finish().


=== ZODB3/ZEO/ClientStorage.py 1.73.2.24 => 1.73.2.25 ===
--- ZODB3/ZEO/ClientStorage.py:1.73.2.24	Wed Jun 11 19:11:32 2003
+++ ZODB3/ZEO/ClientStorage.py	Thu Jun 12 16:31:21 2003
@@ -262,20 +262,14 @@
         self._oid_lock = threading.Lock()
         self._oids = [] # Object ids retrieved from new_oids()
 
-        # There's a nasty race.  The ZRPC layer can deliver invalidations
-        # out of order (i.e., the server sends the result of a load, then
-        # sends an invalidation for that object, but we see the invalidation
-        # first).  To worm around this, load() creates a dict in
-        # _loading_oid_invs and invalidateTrans() stores any arriving
-        # invalidations in that dict.  When the zeoLoad() call returns
-        # from the server, load() handles the invalidations.
-        # It's possible for different threads to attempt to load the same
-        # oid at the same time.  To account for this, we keep a counter
-        # in _loading_oid_count.  Invalidations are only handled when
-        # the count reaches zero.  Mutations of these dicts are protected
-        # by _lock.
-        self._loading_oid_count = {}  # oid -> count of pending load()s
-        self._loading_oid_invs = {}   # oid -> set of invalidated versions
+        # load() and tpc_finish() must be serialized to guarantee
+        # that cache modifications from each occur atomically.
+        # It also prevents multiple load() calls from occurring
+        # simultaneously, which simplifies the cache logic.
+        self._load_lock = threading.Lock()
+        # _load_oid and _load_status are protected by _lock
+        self._load_oid = None
+        self._load_status = None
 
         # Can't read data in one thread while writing data
         # (tpc_finish) in another thread.  In general, the lock
@@ -620,39 +614,36 @@
         """
         self._lock.acquire()    # for atomic processing of invalidations
         try:
-            p = self._cache.load(oid, version)
-            if p:
-                return p
+            pair = self._cache.load(oid, version)
+            if pair:
+                return pair
         finally:
             self._lock.release()
 
         if self._server is None:
             raise ClientDisconnected()
 
-        self._incLoadStatus(oid)
-
+        self._load_lock.acquire()
         try:
-            p, s, v, pv, sv = self._server.zeoLoad(oid)
-        except:
             self._lock.acquire()
             try:
-                self._decLoadStatus(oid)
+                self._load_oid = oid
+                self._load_status = True
             finally:
                 self._lock.release()
-            raise
 
-        self._lock.acquire()
-        try:
-            is_last, invs = self._decLoadStatus(oid)
-            if is_last:
-                if invs:
-                    for v in invs:
-                        self._cache.invalidate(oid, v)
-                else:
+            p, s, v, pv, sv = self._server.zeoLoad(oid)
+
+            self._lock.acquire()    # for atomic processing of invalidations
+            try:
+                if self._load_status:
                     self._cache.checkSize(0)
                     self._cache.store(oid, p, s, v, pv, sv)
+                self._load_oid = None
+            finally:
+                self._lock.release()
         finally:
-            self._lock.release()
+            self._load_lock.release()
 
         if v and version and v == version:
             return pv, sv
@@ -661,42 +652,6 @@
                 return p, s
             raise KeyError, oid # no non-version data for this
 
-    def _incLoadStatus(self, oid):
-        """Increment the load count for oid, version pair.
-
-        Does its own locking.
-        """
-        self._lock.acquire()
-        try:
-            count = self._loading_oid_count.get(oid)
-            if count is None:
-                count = 0
-                self._loading_oid_invs[oid] = {}
-            count += 1
-            self._loading_oid_count[oid] = count
-        finally:
-            self._lock.release()
-
-    def _decLoadStatus(self, oid):
-        """Decrement load count.
-
-        Return boolean indicating whether this was the last load and
-        a list of versions to invalidate.  The list is empty unless
-        the boolean is True.
-
-        Caller must hold self._lock.
-        """
-        count = self._loading_oid_count[oid]
-        count -= 1
-        if count:
-            self._loading_oid_count[oid] = count
-            return 0, []
-        else:
-            del self._loading_oid_count[oid]
-            d = self._loading_oid_invs[oid]
-            del self._loading_oid_invs[oid]
-            return 1, d.keys()
-
     def modifiedInVersion(self, oid):
         """Storage API: return the version, if any, that modfied an object.
 
@@ -847,22 +802,22 @@
         """Storage API: finish a transaction."""
         if transaction is not self._transaction:
             return
+        self._load_lock.acquire()
         try:
             self._lock.acquire()  # for atomic processing of invalidations
             try:
                 self._update_cache()
+                if f is not None:
+                    f()
             finally:
                 self._lock.release()
 
-            if f is not None:
-                f()
-
             self._server.tpc_finish(self._serial)
 
             r = self._check_serials()
             assert r is None or len(r) == 0, "unhandled serialnos: %s" % r
-
         finally:
+            self._load_lock.release()
             self.end_transaction()
 
     def _update_cache(self):
@@ -988,22 +943,11 @@
         try:
             # versions maps version names to dictionary of invalidations
             versions = {}
-            for pair in invs:
-                oid, version = pair
-                d = versions.setdefault(version, {})
-                d[oid] = 1
-
-                # Update the _loading_oids_invs dict for this oid,
-                # if necessary.
-                d = self._loading_oid_invs.get(oid)
-                if d is not None:
-                    # load() is waiting for this.  Mark the version
-                    # as invalidated, so that load can invalidate it
-                    # later.
-                    d[version] = 1
-                else:
-                    # load() isn't waiting for this.  Simply invalidate it.
-                    self._cache.invalidate(oid, version=version)
+            for oid, version in invs:
+                if oid == self._load_oid:
+                    self._load_status = False
+                self._cache.invalidate(oid, version=version)
+                versions.setdefault(version, {})[oid] = 1
 
             if self._db is not None:
                 for v, d in versions.items():


=== ZODB3/ZEO/ClientCache.py 1.38.2.2 => 1.38.2.3 ===
--- ZODB3/ZEO/ClientCache.py:1.38.2.2	Thu Nov 21 13:58:51 2002
+++ ZODB3/ZEO/ClientCache.py	Thu Jun 12 16:31:21 2003
@@ -516,7 +516,26 @@
         finally:
             self._release()
 
+    def _get_serial(self, oid):
+        pos = self._get(oid, None)
+        if pos is None:
+            return None
+        f = self._f[pos < 0]
+        # The cache header is 27 bytes long.  The last 8 bytes are the
+        # serialno.
+        f.seek(abs(pos) + 19)
+        return f.read(8)
+
     def _store(self, oid, p, s, version, pv, sv):
+        # Caller must acquire lock.
+        
+        # Make sure the serial number we are writing is greater than
+        # a serial number already in the cache.  This is a crucial
+        # invariant for cache consistency.
+        if __debug__:
+            oldserial = self._get_serial(oid)
+            assert oldserial is None or oldserial < (s or sv)
+        
         if not s:
             p = ''
             s = '\0\0\0\0\0\0\0\0'