[Zodb-checkins] CVS: Packages/ZEO - ClientStorage.py:1.26.4.16

jeremy@digicool.com jeremy@digicool.com
Tue, 1 May 2001 14:39:29 -0400 (EDT)


Update of /cvs-repository/Packages/ZEO
In directory korak:/tmp/cvs-serv2484

Modified Files:
      Tag: ZEO-ZRPC-Dev
	ClientStorage.py 
Log Message:
Tentative change: Get rid of the ThreadLock. There are no re-entrant
calls.  Furthermore, the condition variable around self._transaction
plus a reasonable client should provide everything that is needed.



--- Updated File ClientStorage.py in package Packages/ZEO --
--- ClientStorage.py	2001/04/27 20:57:18	1.26.4.15
+++ ClientStorage.py	2001/05/01 18:39:27	1.26.4.16
@@ -97,6 +97,7 @@
 import sys
 import tempfile
 import thread
+import threading
 import time
 from types import TupleType, StringType
 from struct import pack, unpack
@@ -183,20 +184,30 @@
 
         self.__name__ = name
 
+        # A ClientStorage only allows one client to commit at a time.
+        # A client enters the commit state by finding tpc_tid set to
+        # None and updating it to the new transaction's id.  The
+        # tpc_tid variable is protected by tpc_cond.
+        self.tpc_cond = threading.Condition()
+        self._transaction = None
+
+        # Cache synchronization
+        # We need to make sure the cache isn't accessed in an
+        # inconsistent state.  If one client is doing a load while
+        # another is committing a transaction, the cache could contain
+        # partial results for the committing transaction.  Thus, there
+        # needs to be locking so that only one thread is
+        # reading/writing the cache at a time.
+        self.cache_lock = threading.Lock()
+
         commit_lock = thread.allocate_lock()
         self._commit_lock_acquire = commit_lock.acquire
         self._commit_lock_release = commit_lock.release
 
-        # What's the difference between thread and ThreadLock?
-        l = ThreadLock.allocate_lock()
-        self._lock_acquire = l.acquire
-        self._lock_release = l.release
-
         t = time.time()
         t = self._ts = apply(TimeStamp,(time.gmtime(t)[:5]+(t%60,)))
         self._serial = `t`
         self._oid='\0\0\0\0\0\0\0\0'
-        self._transaction = None
 
     def registerDB(self, db, limit):
         """Register that the storage is controlled by the given DB."""
@@ -204,28 +215,19 @@
         self._db = db
 
     def is_connected(self):
-        self._lock_acquire()
-        try:
-            if self._server:
-                return 1
-            else:
-                return 0
-        finally:
-            self._lock_release()
+        if self._server:
+            return 1
+        else:
+            return 0
 
     def notifyConnected(self, c):
         log2(INFO, "Connected to storage")
-        self._lock_acquire()
-        try:
-            self._server = ServerStub.StorageServer(c)
+        self._server = ServerStub.StorageServer(c)
 
-            self._oids = []
-
-            self._server.register(str(self._storage))
-            self.verify_cache()
+        self._oids = []
 
-        finally:
-            self._lock_release()
+        self._server.register(str(self._storage))
+        self.verify_cache()
 
     def verify_cache(self):
         self._server.beginZeoVerify()
@@ -247,113 +249,134 @@
     def notifyDisconnected(self, ignored):
         log2(PROBLEM, "Disconnected from storage")
         self._transaction = None
-        try:
-            self._commit_lock_release()
-        except:
-            pass
+        if self._transaction:
+            self._transaction = None
+            self.tpc_cond.notifyAll()
+            self.tpc_cond.release()
 
     def __len__(self):
         return self._info['length']
 
-    def abortVersion(self, src, transaction):
-        if transaction is not self._transaction:
-            raise POSException.StorageTransactionError(self, transaction)
-        self._lock_acquire()
-        try:
-            oids = self._server.abortVersion(src, self._serial)
-            invalidate = self._cache.invalidate
-            for oid in oids:
-                invalidate(oid, src)
-            return oids
-        finally:
-            self._lock_release()
-
-    def close(self):
-        self._lock_acquire()
-        try:
-            # Close the manager first, so that it doesn't attempt to
-            # re-open the connection. 
-            self._rpc_mgr.close()
-            self._server.rpc.close()
-        finally:
-            self._lock_release()
-        
-    def commitVersion(self, src, dest, transaction):
-        if transaction is not self._transaction:
-            raise POSException.StorageTransactionError(self, transaction)
-        self._lock_acquire()
-        try:
-            oids = self._server.commitVersion(src, dest, self._serial)
-            invalidate = self._cache.invalidate
-            if dest:
-                # just invalidate our version data
-                for oid in oids:
-                    invalidate(oid, src)
-            else:
-                # dest is '', so invalidate version and non-version
-                for oid in oids:
-                    invalidate(oid, dest)
-            return oids
-        finally:
-            self._lock_release()
-
     def getName(self):
         return "%s (%s)" % (self.__name__, "XXX")
 
     def getSize(self):
         return self._info['size']
                   
-    def history(self, oid, version, length=1):
-        self._lock_acquire()
+    def supportsUndo(self):
+        return self._info['supportsUndo']
+    
+    def supportsVersions(self):
+        return self._info['supportsVersions']
+
+    def supportsTransactionalUndo(self):
+        return self._info['supportsTransactionalUndo']
+
+    def _check_trans(self, trans, exc=None):
+        if self._transaction is not trans:
+            if exc is None:
+                return 0
+            else:
+                raise exc(self._transaction, trans)
+        return 1
+        
+    def _check_tid(self, tid, exc=None):
+        # XXX Is all this locking unnecessary?  The only way to
+        # begin a transaction is to call tpc_begin().  If we assume
+        # clients are single-threaded and well-behaved, i.e. they call
+        # tpc_begin() first, then there appears to be no need for
+        # locking.  If _check_tid() is called and self.tpc_tid != tid,
+        # then there is no way it can become equal during the call.
+        # Thus, there should be no race.
+        
+        if self.tpc_tid != tid:
+            if exc is None:
+                return 0
+            else:
+                raise exc(self.tpc_tid, tid)
+        return 1
+
+        # XXX But I'm not sure
+        
+        self.tpc_cond.acquire()
         try:
-            return self._server.history(oid, version, length)     
+            if self.tpc_tid != tid:
+                if exc is None:
+                    return 0
+                else:
+                    raise exc(self.tpc_tid, tid)
+            return 1
         finally:
-            self._lock_release()       
+            self.tpc_cond.release()
+
+    def abortVersion(self, src, transaction):
+        self._check_trans(transaction,
+                          POSException.StorageTransactionError)
+        oids = self._server.abortVersion(src, self._serial)
+        invalidate = self._cache.invalidate
+        for oid in oids:
+            invalidate(oid, src)
+        return oids
+
+    def close(self):
+        # Close the manager first, so that it doesn't attempt to
+        # re-open the connection. 
+        self._rpc_mgr.close()
+        self._server.rpc.close()
+        
+    def commitVersion(self, src, dest, transaction):
+        self._check_trans(transaction,
+                          POSException.StorageTransactionError)
+        oids = self._server.commitVersion(src, dest, self._serial)
+        invalidate = self._cache.invalidate
+        if dest:
+            # just invalidate our version data
+            for oid in oids:
+                invalidate(oid, src)
+        else:
+            # dest is '', so invalidate version and non-version
+            for oid in oids:
+                invalidate(oid, dest)
+        return oids
+
+    def history(self, oid, version, length=1):
+        # XXX sync
+        return self._server.history(oid, version, length)     
                   
     def loadSerial(self, oid, serial):
-        self._lock_acquire()
-        try:
-            return self._server.loadSerial(oid, serial)     
-        finally:
-            self._lock_release()       
+        # XXX sync
+        return self._server.loadSerial(oid, serial)     
 
     def load(self, oid, version, _stuff=None):
-        self._lock_acquire()
-        try:
-            p = self._cache.load(oid, version)
-            if p:
-                return p
-            p, s, v, pv, sv = self._server.zeoLoad(oid)
-            self._cache.checkSize(0)
-            self._cache.store(oid, p, s, v, pv, sv)
-            if v and version and v == version:
-                return pv, sv
-            else:
-                if s:
-                    return p, s
-                raise KeyError, oid # no non-version data for this
-        finally:
-            self._lock_release()
+        # XXX sync
+        p = self._cache.load(oid, version)
+        if p:
+            return p
+        p, s, v, pv, sv = self._server.zeoLoad(oid)
+        self._cache.checkSize(0)
+        self._cache.store(oid, p, s, v, pv, sv)
+        if v and version and v == version:
+            return pv, sv
+        else:
+            if s:
+                return p, s
+            raise KeyError, oid # no non-version data for this
                     
     def modifiedInVersion(self, oid):
-        self._lock_acquire()
-        try:
-            v = self._cache.modifiedInVersion(oid)
-            if v is not None:
-                return v
-            return self._server.modifiedInVersion(oid)
-        finally:
-            self._lock_release()
+        # XXX sync
+        v = self._cache.modifiedInVersion(oid)
+        if v is not None:
+            return v
+        return self._server.modifiedInVersion(oid)
 
     def new_oid(self, last=None):
-        self._lock_acquire()
-        try:
-            if not self._oids:
-                self._oids = self._server.new_oids()
-                self._oids.reverse()
-            return self._oids.pop()
-        finally:
-            self._lock_release()
+##        if self._transaction is None:
+##            # XXX What exception?
+##            raise POSException.StorageTransactionError()
+        if not self._oids:
+            self._oids = self._server.new_oids()
+            self._oids.reverse()
+        return self._oids.pop()
         
     def pack(self, t=None, rf=None, wait=0, days=0):
         # Note that we ignore the rf argument.  The server
@@ -361,11 +384,7 @@
         if t is None:
             t = time.time()
         t = t - (days * 86400)
-        self._lock_acquire()
-        try:
-            return self._server.pack(t, wait)
-        finally:
-            self._lock_release()
+        return self._server.pack(t, wait)
 
     def _check_serials(self):
         if self._serials:
@@ -379,181 +398,124 @@
             return r
 
     def store(self, oid, serial, data, version, transaction):
-        if transaction is not self._transaction:
-            raise POSException.StorageTransactionError(self, transaction)
-        self._lock_acquire()
-        try:
-            self._server.storea(oid, serial, data, version, self._serial) 
-            self._tbuf.store(oid, version, data)
-            return self._check_serials()
-        finally:
-            self._lock_release()
+        self._check_trans(transaction, POSException.StorageTransactionError)
+        self._server.storea(oid, serial, data, version, self._serial) 
+        self._tbuf.store(oid, version, data)
+        return self._check_serials()
 
     def tpc_vote(self, transaction):
         if transaction is not self._transaction:
-            raise POSException.StorageTransactionError(self, transaction)
-        self._lock_acquire()
-        try:
-            self._server.vote(self._serial)
-            return self._check_serials()
-        finally:
-            self._lock_release()
+            return
+        self._server.vote(self._serial)
+        return self._check_serials()
             
-    def supportsUndo(self):
-        return self._info['supportsUndo']
-    
-    def supportsVersions(self):
-        return self._info['supportsVersions']
-
-    def supportsTransactionalUndo(self):
-        return self._info['supportsTransactionalUndo']
-        
     def tpc_abort(self, transaction):
-        self._lock_acquire()
-        try:
-            if transaction is not self._transaction:
-                return
-            self._server.tpc_abort(self._serial)
-            self._transaction = None
-            self._tbuf.clear()
-            self._seriald.clear()
-            del self._serials[:]
-            self._commit_lock_release()
-        finally:
-            self._lock_release()
+        if transaction is not self._transaction:
+            return
+        self._server.tpc_abort(self._serial)
+        self._tbuf.clear()
+        self._seriald.clear()
+        del self._serials[:]
+        self._transaction = None
+        self.tpc_cond.notify()
+        self.tpc_cond.release()
 
     def tpc_begin(self, transaction):
         # XXX plan is to have begin be a local operation until the
-        # vote stage.  
-        self._lock_acquire()
-        try:
-            if self._transaction is transaction:
-                return # can start the same transaction many times
-            self._lock_release()
-            self._commit_lock_acquire()
-            self._lock_acquire()
-
-            self._ts = get_timestamp(self._ts)
-            id = `self._ts`
+        # vote stage.
+        self.tpc_cond.acquire()
+        while self._transaction is not None:
+            if self._transaction == transaction:
+                self.tpc_cond.release()
+                return
+            self.tpc_cond.wait()
+            
+        self._ts = get_timestamp(self._ts)
+        id = `self._ts`
+        self._transaction = transaction
 
-            try:
-                r = self._server.tpc_begin(id,
-                                           transaction.user,
-                                           transaction.description,
-                                           transaction._extension)
-            except:
-                self._commit_lock_release()
-                raise
-
-            assert r is None
-
-            # We have *BOTH* the local and distributed commit
-            # lock, now we can actually get ready to get started.
-            self._serial = id
-            self._seriald.clear()
-            del self._serials[:]
+        try:
+            r = self._server.tpc_begin(id,
+                                       transaction.user,
+                                       transaction.description,
+                                       transaction._extension)
+        except:
+            self.tpc_cond.release()
+            raise
 
-            self._transaction = transaction
-        finally:
-            self._lock_release()
+        self._serial = id
+        self._seriald.clear()
+        del self._serials[:]
 
     def tpc_finish(self, transaction, f=None):
-        self._lock_acquire()
-        try:
-            if transaction is not self._transaction:
-                return
-            if f is not None: # XXX what is f()?
-                f()
+        if transaction is not self._transaction:
+            return
+        if f is not None: # XXX what is f()?
+            f()
 
-            self._server.tpc_finish(self._serial)
+        self._server.tpc_finish(self._serial)
 
-            r = self._check_serials()
-            assert r is None or len(r) == 0, "unhandled serialnos: %s" % r
+        r = self._check_serials()
+        assert r is None or len(r) == 0, "unhandled serialnos: %s" % r
 
-            self._cache.checkSize(self._tbuf.get_size())
-
-            # Iterate over the objects in the transaction buffer and
-            # update or invalidate the cache. 
-            self._tbuf.begin_iterate()
-            while 1:
-                try:
-                    t = self._tbuf.next()
-                except ValueError, msg:
-                    raise ClientStorageError, (
-                        "Unexpected error reading temporary file in "
-                        "client storage: %s" % msg)
-                if t is None:
-                    break
-                oid, v, p = t
-                s = self._seriald[oid]
-                if type(s) != StringType:
-                    log2(INFO, "bad serialno: %s for %s" % \
-                        (repr(s), repr(oid)))
-                assert type(s) == StringType, "bad serialno: %s" % repr(s)
-                if s == ResolvedSerial:
-                    self._cache.invalidate(oid, v)
-                else:
-                    self._cache.update(oid, s, v, p)
-            self._tbuf.clear()
+        self._cache.checkSize(self._tbuf.get_size())
 
-            self._transaction=None
-            self._commit_lock_release()
-        finally: self._lock_release()
+        # Iterate over the objects in the transaction buffer and
+        # update or invalidate the cache. 
+        self._tbuf.begin_iterate()
+        while 1:
+            try:
+                t = self._tbuf.next()
+            except ValueError, msg:
+                raise ClientStorageError, (
+                    "Unexpected error reading temporary file in "
+                    "client storage: %s" % msg)
+            if t is None:
+                break
+            oid, v, p = t
+            s = self._seriald[oid]
+            if type(s) != StringType:
+                log2(INFO, "bad serialno: %s for %s" % \
+                    (repr(s), repr(oid)))
+            assert type(s) == StringType, "bad serialno: %s" % repr(s)
+            if s == ResolvedSerial:
+                self._cache.invalidate(oid, v)
+            else:
+                self._cache.update(oid, s, v, p)
+        self._tbuf.clear()
 
+        self._transaction = None
+        self.tpc_cond.notify()
+        self.tpc_cond.release()
+
     def transactionalUndo(self, trans_id, trans):
-        if trans is not self._transaction:
-            raise POSException.StorageTransactionError(self._transaction,
-                                                       transaction)
-        self._lock_acquire()
-        try:
-            oids = self._server.transactionalUndo(trans_id, self._serial)
-            for oid in oids:
-                self._cache.invalidate(oid, '')
-            return oids
-        finally:
-            self._lock_release()
+        self._check_trans(trans, POSException.StorageTransactionError)
+        oids = self._server.transactionalUndo(trans_id, self._serial)
+        for oid in oids:
+            self._cache.invalidate(oid, '')
+        return oids
 
     def undo(self, transaction_id):
-        self._lock_acquire()
-        try:
-            oids = self._server.undo(transaction_id)
-            cinvalidate = self._cache.invalidate
-            for oid in oids:
-                cinvalidate(oid,'')                
-            return oids
-        finally: self._lock_release()
+        oids = self._server.undo(transaction_id)
+        cinvalidate = self._cache.invalidate
+        for oid in oids:
+            cinvalidate(oid, '')                
+        return oids
 
-
     def undoInfo(self, first=0, last=-20, specification=None):
-        self._lock_acquire()
-        try:
-            return self._server.undoInfo(first, last, specification)
-        finally:
-            self._lock_release()
+        return self._server.undoInfo(first, last, specification)
 
     def undoLog(self, first, last, filter=None):
         if filter is not None:
             return () # XXX can't pass a filter to server
         
-        self._lock_acquire()
-        try:
-            return self._server.undoLog(first, last) # Eek!
-        finally:
-            self._lock_release()
+        return self._server.undoLog(first, last) # Eek!
 
     def versionEmpty(self, version):
-        self._lock_acquire()
-        try:
-            return self._server.versionEmpty(version)
-        finally:
-            self._lock_release()
+        return self._server.versionEmpty(version)
 
     def versions(self, max=None):
-        self._lock_acquire()
-        try:
-            return self._server.versions(max)
-        finally:
-            self._lock_release()
+        return self._server.versions(max)
 
     # below are methods invoked by the StorageServer