[Zope-Checkins] CVS: ZODB3/ZEO - ClientStorage.py:1.73.2.24

Jeremy Hylton jeremy@zope.com
Wed, 11 Jun 2003 19:12:03 -0400


Update of /cvs-repository/ZODB3/ZEO
In directory cvs.zope.org:/tmp/cvs-serv6722/ZEO

Modified Files:
      Tag: ZODB3-3_1-branch
	ClientStorage.py 
Log Message:
Merge tim-loading_oids_status-branch to ZODB 3.1 release branch.

Includes four critical bug fixes.


=== ZODB3/ZEO/ClientStorage.py 1.73.2.23 => 1.73.2.24 ===
--- ZODB3/ZEO/ClientStorage.py:1.73.2.23	Fri Jun  6 13:58:10 2003
+++ ZODB3/ZEO/ClientStorage.py	Wed Jun 11 19:11:32 2003
@@ -262,6 +262,21 @@
         self._oid_lock = threading.Lock()
         self._oids = [] # Object ids retrieved from new_oids()
 
+        # There's a nasty race.  The ZRPC layer can deliver invalidations
+        # out of order (i.e., the server sends the result of a load, then
+        # sends an invalidation for that object, but we see the invalidation
+        # first).  To worm around this, load() creates a dict in
+        # _loading_oid_invs and invalidateTrans() stores any arriving
+        # invalidations in that dict.  When the zeoLoad() call returns
+        # from the server, load() handles the invalidations.
+        # It's possible for different threads to attempt to load the same
+        # oid at the same time.  To account for this, we keep a counter
+        # in _loading_oid_count.  Invalidations are only handled when
+        # the count reaches zero.  Mutations of these dicts are protected
+        # by _lock.
+        self._loading_oid_count = {}  # oid -> count of pending load()s
+        self._loading_oid_invs = {}   # oid -> set of invalidated versions
+
         # Can't read data in one thread while writing data
         # (tpc_finish) in another thread.  In general, the lock
         # must prevent access to the cache while _update_cache
@@ -614,13 +629,31 @@
         if self._server is None:
             raise ClientDisconnected()
 
-        # If an invalidation for oid comes in during zeoLoad, that's OK
-        # because we'll get oid's new state.
+        self._incLoadStatus(oid)
+
+        try:
+            p, s, v, pv, sv = self._server.zeoLoad(oid)
+        except:
+            self._lock.acquire()
+            try:
+                self._decLoadStatus(oid)
+            finally:
+                self._lock.release()
+            raise
+
+        self._lock.acquire()
+        try:
+            is_last, invs = self._decLoadStatus(oid)
+            if is_last:
+                if invs:
+                    for v in invs:
+                        self._cache.invalidate(oid, v)
+                else:
+                    self._cache.checkSize(0)
+                    self._cache.store(oid, p, s, v, pv, sv)
+        finally:
+            self._lock.release()
 
-        # XXX Race condition among load / invalid / store in cache
-        p, s, v, pv, sv = self._server.zeoLoad(oid)
-        self._cache.checkSize(0)
-        self._cache.store(oid, p, s, v, pv, sv)
         if v and version and v == version:
             return pv, sv
         else:
@@ -628,6 +661,42 @@
                 return p, s
             raise KeyError, oid # no non-version data for this
 
+    def _incLoadStatus(self, oid):
+        """Increment the count of pending load() calls for oid.
+
+        Does its own locking.
+        """
+        self._lock.acquire()
+        try:
+            count = self._loading_oid_count.get(oid)
+            if count is None:
+                count = 0
+                self._loading_oid_invs[oid] = {}
+            count += 1
+            self._loading_oid_count[oid] = count
+        finally:
+            self._lock.release()
+
+    def _decLoadStatus(self, oid):
+        """Decrement load count.
+
+        Return boolean indicating whether this was the last load and
+        a list of versions to invalidate.  The list is empty unless
+        the boolean is True.
+
+        Caller must hold self._lock.
+        """
+        count = self._loading_oid_count[oid]
+        count -= 1
+        if count:
+            self._loading_oid_count[oid] = count
+            return 0, []
+        else:
+            del self._loading_oid_count[oid]
+            d = self._loading_oid_invs[oid]
+            del self._loading_oid_invs[oid]
+            return 1, d.keys()
+
     def modifiedInVersion(self, oid):
         """Storage API: return the version, if any, that modfied an object.
 
@@ -919,10 +988,23 @@
         try:
             # versions maps version names to dictionary of invalidations
             versions = {}
-            for oid, version in invs:
+            for pair in invs:
+                oid, version = pair
                 d = versions.setdefault(version, {})
-                self._cache.invalidate(oid, version=version)
                 d[oid] = 1
+
+                # Update the _loading_oid_invs dict for this oid,
+                # if necessary.
+                d = self._loading_oid_invs.get(oid)
+                if d is not None:
+                    # load() is waiting for this.  Mark the version
+                    # as invalidated, so that load can invalidate it
+                    # later.
+                    d[version] = 1
+                else:
+                    # load() isn't waiting for this.  Simply invalidate it.
+                    self._cache.invalidate(oid, version=version)
+
             if self._db is not None:
                 for v, d in versions.items():
                     self._db.invalidate(d, version=v)