[Zodb-checkins] CVS: ZODB4/src/zodb/zeo - schema.xml:1.1.2.1 runzeo.py:1.1.2.1 component.xml:1.1.2.1 stubs.py:1.7.6.1 server.py:1.12.2.1 interfaces.py:1.3.30.1 client.py:1.13.2.1

Jeremy Hylton jeremy at zope.com
Tue Jun 17 18:59:56 EDT 2003


Update of /cvs-repository/ZODB4/src/zodb/zeo
In directory cvs.zope.org:/tmp/cvs-serv10995/src/zodb/zeo

Modified Files:
      Tag: ZODB3-2-merge
	stubs.py server.py interfaces.py client.py 
Added Files:
      Tag: ZODB3-2-merge
	schema.xml runzeo.py component.xml 
Log Message:
Checkpoint progress merging ZODB 3.2 features and fixes into ZODB4.


=== Added File ZODB4/src/zodb/zeo/schema.xml ===
<schema>

  <description>
    This schema describes the configuration of the ZEO storage server
    process.
  </description>

  <!-- Use the storage types defined by ZODB. -->
  <import package="zodb"/>

  <!-- Use the ZEO server information structure. -->
  <import package="zodb/zeo"/>

  <section type="zeo" name="*" required="yes" attribute="zeo" />

  <multisection name="+" type="ZODB.storage"
                attribute="storages"
                required="yes">
    <description>
      One or more storages that are provided by the ZEO server.  The
      section names are used as the storage names, and must be unique
      within each ZEO storage server.  Traditionally, these names
      represent small integers starting at '1'.
    </description>
  </multisection>

  <section name="*" type="eventlog" attribute="eventlog" required="no" />

</schema>


=== Added File ZODB4/src/zodb/zeo/runzeo.py ===
#!python
##############################################################################
#
# Copyright (c) 2001, 2002, 2003 Zope Corporation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
"""Start the ZEO storage server.

Usage: %s [-C URL] [-a ADDRESS] [-f FILENAME] [-t TIMEOUT] [-m ADDRESS] [-h]

Options:
-C/--configuration URL -- configuration file or URL
-a/--address ADDRESS -- server address of the form PORT, HOST:PORT, or PATH
                        (a PATH must contain at least one "/")
-f/--filename FILENAME -- filename for FileStorage
-t/--timeout TIMEOUT -- transaction timeout in seconds (default no timeout)
-h/--help -- print this usage message and exit
-m/--monitor ADDRESS -- address of monitor server ([HOST:]PORT or PATH)

Unless -C is specified, -a and -f are required.
"""

# The code here is designed to be reused by other, similar servers.
# For the foreseeable future, it must work under Python 2.1 as well as
# 2.2 and above.

import os
import sys
import getopt
import signal
import socket
import logging

import ZConfig
from zdaemon.zdoptions import ZDOptions
from zodb import zeo

def parse_address(arg):
    """Parse a socket address string into a (family, address) pair.

    arg may be PORT, HOST:PORT, or a Unix-domain socket PATH; the
    parsing itself is delegated to ZConfig's SocketAddress datatype.
    """
    # XXX Not part of the official ZConfig API
    sockaddr = ZConfig.datatypes.SocketAddress(arg)
    family = sockaddr.family
    address = sockaddr.address
    return family, address

class ZEOOptionsMixin:
    """ZEO server option definitions and command-line handlers.

    Intended to be mixed into a ZDOptions subclass: add_zeo_options()
    relies on a self.add() method that this mixin does not define.
    """

    # Storage openers collected from -f options.  Stays None until the
    # first -f is handled.
    storages = None

    def handle_address(self, arg):
        """Handle -a/--address: store the parsed family and address."""
        family, address = parse_address(arg)
        self.family = family
        self.address = address

    def handle_monitor_address(self, arg):
        """Handle -m/--monitor: store the monitor's family and address."""
        family, address = parse_address(arg)
        self.monitor_family = family
        self.monitor_address = address

    def handle_filename(self, arg):
        """Handle -f/--filename: append a FileStorage opener for path arg.

        Each new storage is named one more than the number of storages
        already collected ("1", "2", ...).
        """
        # NB: this FileStorage is the ZConfig storage *opener*, not the
        # storage class itself.
        # NOTE(review): ZODB4 lives under the "zodb" package; confirm that
        # "ZODB.config" is importable in this tree.
        from ZODB.config import FileStorage

        class FSConfig:
            # Just enough of a ZConfig section object for the opener.
            def __init__(self, name, path):
                self._name = name
                self.path = path
                self.create = 0
                self.read_only = 0
                self.stop = None
                self.quota = None

            def getSectionName(self):
                return self._name

        if not self.storages:
            self.storages = []
        section = FSConfig(str(len(self.storages) + 1), arg)
        self.storages.append(FileStorage(section))

    def add_zeo_options(self):
        """Register the ZEO-specific options through self.add()."""
        # One (positional args, keyword args) pair per self.add() call,
        # listed in registration order.
        table = [
            ((None, None, "a:", "address=", self.handle_address), {}),
            ((None, None, "f:", "filename=", self.handle_filename), {}),
            (("family", "zeo.address.family"), {}),
            (("address", "zeo.address.address"),
             {"required": "no server address specified; use -a or -C"}),
            (("read_only", "zeo.read_only"), {"default": 0}),
            (("invalidation_queue_size", "zeo.invalidation_queue_size"),
             {"default": 100}),
            (("transaction_timeout", "zeo.transaction_timeout",
              "t:", "timeout=", float), {}),
            (("monitor_address", "zeo.monitor_address", "m:", "monitor=",
              self.handle_monitor_address), {}),
            (('auth_protocol', 'zeo.authentication_protocol',
              None, 'auth-protocol='), {"default": None}),
            (('auth_database', 'zeo.authentication_database',
              None, 'auth-database='), {}),
            (('auth_realm', 'zeo.authentication_realm',
              None, 'auth-realm='), {}),
        ]
        for args, kwargs in table:
            self.add(*args, **kwargs)

class ZEOOptions(ZDOptions, ZEOOptionsMixin):
    """Command-line and configuration-file options for the ZEO server."""

    # Name of the schema section holding logging settings (the "eventlog"
    # section in schema.xml); presumably consumed by ZDOptions.
    logsectionname = "eventlog"

    def __init__(self):
        # Point ZDOptions at the directory of the zeo package, which is
        # where schema.xml lives.  Set before ZDOptions.__init__ runs.
        zeo_dir = os.path.dirname(zeo.__file__)
        self.schemadir = zeo_dir
        ZDOptions.__init__(self)
        self.add_zeo_options()
        no_storages = "no storages specified; use -f or -C"
        self.add("storages", "storages", required=no_storages)


class ZEOServer:
    """Run a ZEO StorageServer configured by a realized ZEOOptions.

    main() drives the whole lifecycle; each step is its own method.
    """

    def __init__(self, options):
        # options: a realized ZEOOptions (or compatible) instance.
        self.options = options

    def main(self):
        """Check the address, open storages and serve until exit.

        Storages are closed and any Unix-domain socket file is removed
        however the loop ends, including when a signal handler below
        calls sys.exit().
        """
        self.setup_default_logging()
        self.check_socket()
        self.clear_socket()
        try:
            self.open_storages()
            self.setup_signals()
            self.create_server()
            self.loop_forever()
        finally:
            self.close_storages()
            self.clear_socket()

    def setup_default_logging(self):
        """Default log output to stderr when nothing else configures it.

        Does nothing if options.config_logger is set, or if either
        EVENT_LOG_FILE or STUPID_LOG_FILE is already in the environment.
        """
        if self.options.config_logger is not None:
            return
        if os.getenv("EVENT_LOG_FILE") is not None:
            return
        if os.getenv("STUPID_LOG_FILE") is not None:
            return
        # No log file is configured; default to stderr.  The logging
        # level can still be controlled by {STUPID,EVENT}_LOG_SEVERITY.
        os.environ["EVENT_LOG_FILE"] = ""

    def check_socket(self):
        """Complain via options.usage() if the address is already in use."""
        if self.can_connect(self.options.family, self.options.address):
            self.options.usage("address %s already in use" %
                               repr(self.options.address))

    def can_connect(self, family, address):
        """Return 1 if something accepts connections at address, else 0.

        The probe socket is always closed, whether or not the connection
        attempt succeeds.
        """
        s = socket.socket(family, socket.SOCK_STREAM)
        # Nested try blocks rather than try/except/finally, so this
        # stays valid on Python 2.1.
        try:
            try:
                s.connect(address)
            except socket.error:
                return 0
            else:
                return 1
        finally:
            s.close()

    def clear_socket(self):
        """Remove a stale Unix-domain socket file at the server address.

        Only string addresses are filesystem paths; (host, port) tuples
        are left alone.  A missing file is not an error.
        """
        if isinstance(self.options.address, type("")):
            try:
                os.unlink(self.options.address)
            except os.error:
                pass

    def open_storages(self):
        """Open every configured storage into self.storages (name -> storage)."""
        self.storages = {}
        for opener in self.options.storages:
            _logger.info("opening storage %r using %s"
                 % (opener.name, opener.__class__.__name__))
            self.storages[opener.name] = opener.open()

    def setup_signals(self):
        """Set up signal handlers.

        The signal handler for SIGFOO is a method handle_sigfoo().
        If no handler method is defined for a signal, the signal
        action is not changed from its initial value.  The handler
        method is called without additional arguments.
        """
        if os.name != "posix":
            return
        if hasattr(signal, 'SIGXFSZ'):
            signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
        init_signames()
        for sig, name in signames.items():
            method = getattr(self, "handle_" + name.lower(), None)
            if method is not None:
                # Bind method as a default argument so each wrapper keeps
                # its own handler (avoids late binding in the loop).
                def wrapper(sig_dummy, frame_dummy, method=method):
                    method()
                signal.signal(sig, wrapper)

    def create_server(self):
        """Create the StorageServer that serves self.storages."""
        from zodb.zeo.server import StorageServer
        self.server = StorageServer(
            self.options.address,
            self.storages,
            read_only=self.options.read_only,
            invalidation_queue_size=self.options.invalidation_queue_size,
            transaction_timeout=self.options.transaction_timeout,
            monitor_address=self.options.monitor_address,
            auth_protocol=self.options.auth_protocol,
            # StorageServer calls this parameter auth_filename; passing
            # auth_database= would raise TypeError.
            auth_filename=self.options.auth_database,
            auth_realm=self.options.auth_realm)

    def loop_forever(self):
        """Run the asyncore loop until it exits."""
        import ThreadedAsync.LoopCallback
        ThreadedAsync.LoopCallback.loop()

    def handle_sigterm(self):
        _logger.info("terminated by SIGTERM")
        sys.exit(0)

    def handle_sigint(self):
        _logger.info("terminated by SIGINT")
        sys.exit(0)

    def handle_sighup(self):
        _logger.info("restarted by SIGHUP")
        sys.exit(1)

    def handle_sigusr2(self):
        # How should this work with new logging?

        # This requires a modern zLOG (from Zope 2.6 or later); older
        # zLOG packages don't have the initialize() method
        # NOTE(review): an ImportError here propagates out of the signal
        # handler; confirm zLOG is available alongside ZODB4.
        _logger.info("reinitializing zLOG")
        # XXX Shouldn't this be below with _log()?
        import zLOG
        zLOG.initialize()
        _logger.info("reinitialized zLOG")

    def close_storages(self):
        """Close every open storage, logging (not raising) any failure."""
        for name, storage in self.storages.items():
            _logger.info("closing storage %r" % name)
            try:
                storage.close()
            except: # Keep going
                # Deliberately broad: one storage failing to close must
                # not stop the others from being closed.
                _logger.exception("failed to close storage %r" % name)


# Signal names

# Maps signal numbers to "SIGxxx" names; built lazily by init_signames().
signames = None

# Module logger.  main() rebinds it, but defining it here lets ZEOServer
# log (instead of raising NameError) when it is reused without main().
_logger = logging.getLogger("runzeo")

def signame(sig):
    """Map a signal number to its symbolic SIG* name.

    The name table is built on first use.  A number with no SIG* name in
    the signal module is reported as "signal NNN".
    """
    if signames is None:
        init_signames()
    name = signames.get(sig)
    if not name:
        name = "signal %d" % sig
    return name

def init_signames():
    """Rebuild the module-level signames table from the signal module.

    Each attribute whose name starts with "SIG" is recorded as
    value -> name, except the "SIG_" constants such as SIG_IGN.  When
    several names share a value, the last one iterated wins.
    """
    global signames
    signames = {}
    for attr, value in signal.__dict__.items():
        starts = getattr(attr, "startswith", None)
        if starts is None:
            # Key is not string-like; nothing to match against.
            continue
        if not starts("SIG") or starts("SIG_"):
            continue
        signames[value] = attr


# Main program

def main(args=None):
    """Entry point: realize the ZEO options from args and run the server.

    args is handed to ZEOOptions.realize(); presumably None means "use
    the process command line" (TODO confirm against ZDOptions).
    """
    global _logger
    _logger = logging.getLogger("runzeo")

    opts = ZEOOptions()
    opts.realize(args)
    ZEOServer(opts).main()

if __name__ == "__main__":
    main()


=== Added File ZODB4/src/zodb/zeo/component.xml ===
<component>

  <!-- stub out the type until we figure out how to zconfig logging -->
  <sectiontype name="eventlog" />

  <sectiontype name="zeo">

    <description>
      The content of a ZEO section describes the operational parameters
      of a ZEO server except for the storage(s) to be served.
    </description>

    <key name="address" datatype="socket-address"
         required="yes">
      <description>
        The address at which the server should listen.  This can be in
        the form 'host:port' to signify a TCP/IP connection or a
        pathname string to signify a Unix domain socket connection (at
        least one '/' is required).  A hostname may be a DNS name or a
        dotted IP address.  If the hostname is omitted, the platform's
        default behavior is used when binding the listening socket (''
        is passed to socket.bind() as the hostname portion of the
        address).
      </description>
    </key>

    <key name="read-only" datatype="boolean"
         required="no"
         default="false">
      <description>
        Flag indicating whether the server should operate in read-only
        mode.  Defaults to false.  Note that even if the server is
        operating in writable mode, individual storages may still be
        read-only.  But if the server is in read-only mode, no write
        operations are allowed, even if the storages are writable.  Note
        that pack() is considered a read-only operation.
      </description>
    </key>

    <key name="invalidation-queue-size" datatype="integer"
         required="no"
         default="100">
      <description>
        The storage server keeps a queue of the objects modified by the
        last N transactions, where N == invalidation_queue_size.  This
        queue is used to speed client cache verification when a client
        disconnects for a short period of time.
      </description>
    </key>

    <key name="monitor-address" datatype="socket-address"
         required="no">
      <description>
        The address at which the monitor server should listen.  If
        specified, a monitor server is started.  The monitor server
        provides server statistics in a simple text format.  This can
        be in the form 'host:port' to signify a TCP/IP connection or a
        pathname string to signify a Unix domain socket connection (at
        least one '/' is required).  A hostname may be a DNS name or a
        dotted IP address.  If the hostname is omitted, the platform's
        default behavior is used when binding the listening socket (''
        is passed to socket.bind() as the hostname portion of the
        address).
      </description>
    </key>

    <key name="transaction-timeout" datatype="integer"
         required="no">
      <description>
        The maximum amount of time to wait for a transaction to commit
        after acquiring the storage lock, specified in seconds.  If the
        transaction takes too long, the client connection will be closed
        and the transaction aborted.
      </description>
    </key>

    <key name="authentication-protocol" required="no">
      <description>
        The name of the protocol used for authentication.  The
        only protocol provided with ZEO is "digest," but extensions
        may provide other protocols.
      </description>
    </key>

    <key name="authentication-database" required="no">
      <description>
        The path of the database containing authentication credentials.
      </description>
    </key>

    <key name="authentication-realm" required="no">
      <description>
        The authentication realm of the server.  Some authentication
        schemes use a realm to identify the logical set of usernames
        that are accepted by this server.
      </description>
    </key>

  </sectiontype>

</component>


=== ZODB4/src/zodb/zeo/stubs.py 1.7 => 1.7.6.1 ===
--- ZODB4/src/zodb/zeo/stubs.py:1.7	Mon May 19 11:02:51 2003
+++ ZODB4/src/zodb/zeo/stubs.py	Tue Jun 17 17:59:24 2003
@@ -52,7 +52,7 @@
         self.rpc.callAsync('endVerify')
 
     def invalidateTransaction(self, tid, invlist):
-        self.rpc.callAsync('invalidateTransaction', tid, invlist)
+        self.rpc.callAsyncNoPoll('invalidateTransaction', tid, invlist)
 
     def serialnos(self, arg):
         self.rpc.callAsync('serialnos', arg)
@@ -102,6 +102,12 @@
 
     def get_info(self):
         return self.rpc.call('get_info')
+
+    def getAuthProtocol(self):
+        return self.rpc.call('getAuthProtocol')
+    
+    def lastTransaction(self):
+        return self.rpc.call('lastTransaction')
 
     def getInvalidations(self, tid):
         return self.rpc.call('getInvalidations', tid)


=== ZODB4/src/zodb/zeo/server.py 1.12 => 1.12.2.1 ===
--- ZODB4/src/zodb/zeo/server.py:1.12	Sat Jun  7 02:54:23 2003
+++ ZODB4/src/zodb/zeo/server.py	Tue Jun 17 17:59:24 2003
@@ -58,7 +58,11 @@
 
     ClientStorageStubClass = ClientStorageStub
 
-    def __init__(self, server, read_only=0):
+    # A list of extension methods.  A subclass with extra methods
+    # should override.
+    extensions = []
+
+    def __init__(self, server, read_only=0, auth_realm=None):
         self.server = server
         # timeout and stats will be initialized in register()
         self.timeout = None
@@ -73,7 +77,22 @@
         self.verifying = 0
         self.logger = logging.getLogger("ZSS.%d.ZEO" % os.getpid())
         self.log_label = ""
+        self.authenticated = 0
+        self.auth_realm = auth_realm
+        # The authentication protocol may define extra methods.
+        self._extensions = {}
+        for func in self.extensions:
+            self._extensions[func.func_name] = None
+
+    def finish_auth(self, authenticated):
+        if not self.auth_realm:
+            return 1
+        self.authenticated = authenticated
+        return authenticated
 
+    def set_database(self, database):
+        self.database = database
+        
     def notifyConnected(self, conn):
         self.connection = conn # For restart_other() below
         self.client = self.ClientStorageStubClass(conn)
@@ -110,6 +129,7 @@
         """Delegate several methods to the storage"""
         self.versionEmpty = self.storage.versionEmpty
         self.versions = self.storage.versions
+        self.getSerial = self.storage.getSerial
         self.load = self.storage.load
         self.modifiedInVersion = self.storage.modifiedInVersion
         self.getVersion = self.storage.getVersion
@@ -125,9 +145,11 @@
             # can be removed
             pass
         else:
-            for name in fn().keys():
-                if not hasattr(self,name):
-                    setattr(self, name, getattr(self.storage, name))
+            d = fn()
+            self._extensions.update(d)
+            for name in d.keys():
+                assert not hasattr(self, name)
+                setattr(self, name, getattr(self.storage, name))
         self.lastTransaction = self.storage.lastTransaction
 
     def _check_tid(self, tid, exc=None):
@@ -149,6 +171,15 @@
                 return 0
         return 1
 
+    def getAuthProtocol(self):
+        """Return string specifying name of authentication module to use.
+
+        The module name should be auth_%s where %s is auth_protocol."""
+        protocol = self.server.auth_protocol
+        if not protocol or protocol == 'none':
+            return None
+        return protocol
+    
     def register(self, storage_id, read_only):
         """Select the storage that this client will use
 
@@ -173,19 +204,14 @@
                                                                    self)
 
     def get_info(self):
-        return {'name': self.storage.getName(),
-                'extensionMethods': self.getExtensionMethods(),
+        return {"name": self.storage.getName(),
+                "extensionMethods": self.getExtensionMethods(),
                  "implements": [iface.__name__
                                 for iface in providedBy(self.storage)],
                 }
 
     def getExtensionMethods(self):
-        try:
-            e = self.storage.getExtensionMethods
-        except AttributeError:
-            return {}
-        else:
-            return e()
+        return self._extensions
 
     def zeoLoad(self, oid):
         self.stats.loads += 1
@@ -564,7 +590,10 @@
     def __init__(self, addr, storages, read_only=0,
                  invalidation_queue_size=100,
                  transaction_timeout=None,
-                 monitor_address=None):
+                 monitor_address=None,
+                 auth_protocol=None,
+                 auth_filename=None,
+                 auth_realm=None):
         """StorageServer constructor.
 
         This is typically invoked from the start.py script.
@@ -606,6 +635,21 @@
             should listen.  If specified, a monitor server is started.
             The monitor server provides server statistics in a simple
             text format.
+
+        auth_protocol -- The name of the authentication protocol to use.
+            Examples are "digest" and "srp".
+            
+        auth_filename -- The name of the password database filename.
+            It should be in a format compatible with the authentication
+            protocol used; for instance, "sha" and "srp" require different
+            formats.
+            
+            Note that to implement an authentication protocol, a server
+            and client authentication mechanism must be implemented in a
+            auth_* module, which should be stored inside the "auth"
+            subdirectory. This module may also define a DatabaseClass
+            variable that should indicate what database should be used
+            by the authenticator.
         """
 
         self.addr = addr
@@ -621,6 +665,12 @@
         for s in storages.values():
             s._waiting = []
         self.read_only = read_only
+        self.auth_protocol = auth_protocol
+        self.auth_filename = auth_filename
+        self.auth_realm = auth_realm
+        self.database = None
+        if auth_protocol:
+            self._setup_auth(auth_protocol)
         # A list of at most invalidation_queue_size invalidations
         self.invq = []
         self.invq_bound = invalidation_queue_size
@@ -643,6 +693,39 @@
         else:
             self.monitor = None
 
+    def _setup_auth(self, protocol):
+        # Can't be done in global scope, because of cyclic references
+        from ZEO.auth import get_module
+
+        name = self.__class__.__name__
+
+        module = get_module(protocol)
+        if not module:
+            log("%s: no such an auth protocol: %s" % (name, protocol))
+            return
+        
+        storage_class, client, db_class = module
+        
+        if not storage_class or not issubclass(storage_class, ZEOStorage):
+            log(("%s: %s isn't a valid protocol, must have a StorageClass" %
+                 (name, protocol)))
+            self.auth_protocol = None
+            return
+        self.ZEOStorageClass = storage_class
+
+        log("%s: using auth protocol: %s" % (name, protocol))
+        
+        # We create a Database instance here for use with the authenticator
+        # modules. Having one instance allows it to be shared between multiple
+        # storages, avoiding the need to bloat each with a new authenticator
+        # Database that would contain the same info, and also avoiding any
+        # possibly synchronization issues between them.
+        self.database = db_class(self.auth_filename)
+        if self.database.realm != self.auth_realm:
+            raise ValueError("password database realm %r "
+                             "does not match storage realm %r"
+                             % (self.database.realm, self.auth_realm))
+
     def new_connection(self, sock, addr):
         """Internal: factory to create a new connection.
 
@@ -650,8 +733,13 @@
         whenever accept() returns a socket for a new incoming
         connection.
         """
-        z = self.ZEOStorageClass(self, self.read_only)
-        c = self.ManagedServerConnectionClass(sock, addr, z, self)
+        if self.auth_protocol and self.database:
+            zstorage = self.ZEOStorageClass(self, self.read_only,
+                                            auth_realm=self.auth_realm)
+            zstorage.set_database(self.database)
+        else:
+            zstorage = self.ZEOStorageClass(self, self.read_only)
+        c = self.ManagedServerConnectionClass(sock, addr, zstorage, self)
         self.logger.warn("new connection %s: %s", addr, `c`)
         return c
 


=== ZODB4/src/zodb/zeo/interfaces.py 1.3 => 1.3.30.1 ===
--- ZODB4/src/zodb/zeo/interfaces.py:1.3	Tue Feb 25 13:55:05 2003
+++ ZODB4/src/zodb/zeo/interfaces.py	Tue Jun 17 17:59:24 2003
@@ -27,3 +27,5 @@
 class ClientDisconnected(ClientStorageError):
     """The database storage is disconnected from the storage."""
 
+class AuthError(StorageError):
+    """The client provided invalid authentication credentials."""


=== ZODB4/src/zodb/zeo/client.py 1.13 => 1.13.2.1 ===
--- ZODB4/src/zodb/zeo/client.py:1.13	Fri Jun  6 11:24:21 2003
+++ ZODB4/src/zodb/zeo/client.py	Tue Jun 17 17:59:24 2003
@@ -106,7 +106,8 @@
     def __init__(self, addr, storage='1', cache_size=20 * MB,
                  name='', client=None, var=None,
                  min_disconnect_poll=5, max_disconnect_poll=300,
-                 wait=True, read_only=False, read_only_fallback=False):
+                 wait=True, read_only=False, read_only_fallback=False,
+                 username='', password='', realm=None):
 
         """ClientStorage constructor.
 
@@ -161,6 +162,17 @@
             writable storages are available.  Defaults to false.  At
             most one of read_only and read_only_fallback should be
             true.
+
+        username -- string with username to be used when authenticating.
+            These only need to be provided if you are connecting to an
+            authenticated server storage.
+
+        password -- string with plaintext password to be used
+            when authenticated.
+
+        Note that the authentication protocol is defined by the server
+        and is detected by the ClientStorage upon connecting (see
+        testConnection() and doAuth() for details).
         """
 
         self.logger = logging.getLogger("ZCS.%d" % os.getpid())
@@ -202,6 +214,9 @@
         self._conn_is_read_only = 0
         self._storage = storage
         self._read_only_fallback = read_only_fallback
+        self._username = username
+        self._password = password
+        self._realm = realm
         # _server_addr is used by sortKey()
         self._server_addr = None
         self._tfile = None
@@ -236,6 +251,21 @@
         self._oid_lock = threading.Lock()
         self._oids = [] # Object ids retrieved from newObjectIds()
 
+        # load() and tpc_finish() must be serialized to guarantee
+        # that cache modifications from each occur atomically.
+        # It also prevents multiple load calls occuring simultaneously,
+        # which simplifies the cache logic.
+        self._load_lock = threading.Lock()
+        # _load_oid and _load_status are protected by _lock
+        self._load_oid = None
+        self._load_status = None
+
+        # Can't read data in one thread while writing data
+        # (tpc_finish) in another thread.  In general, the lock
+        # must prevent access to the cache while _update_cache
+        # is executing.
+        self._lock = threading.Lock()
+
         t = self._ts = get_timestamp()
         self._serial = `t`
         self._oid = '\0\0\0\0\0\0\0\0'
@@ -330,6 +360,29 @@
         if cn is not None:
             cn.pending()
 
+    def doAuth(self, protocol, stub):
+        if not (self._username and self._password):
+            raise AuthError, "empty username or password"
+
+        module = get_module(protocol)
+        if not module:
+            log2(PROBLEM, "%s: no such an auth protocol: %s" %
+                 (self.__class__.__name__, protocol))
+            return
+
+        storage_class, client, db_class = module
+
+        if not client:
+            log2(PROBLEM,
+                 "%s: %s isn't a valid protocol, must have a Client class" %
+                 (self.__class__.__name__, protocol))
+            raise AuthError, "invalid protocol"
+
+        c = client(stub)
+
+        # Initiate authentication, returns boolean specifying whether OK
+        return c.start(self._username, self._realm, self._password)
+
     def testConnection(self, conn):
         """Internal: test the given connection.
 
@@ -355,6 +408,16 @@
         # XXX Check the protocol version here?
         self._conn_is_read_only = 0
         stub = self.StorageServerStubClass(conn)
+        
+        auth = stub.getAuthProtocol()
+        self.logger.info("Client authentication successful")
+        if auth:
+            if self.doAuth(auth, stub):
+                self.logger.info("Client authentication successful")
+            else:
+                self.logger.error("Authentication failed")
+                raise AuthError, "Authentication failed"
+
         try:
             stub.register(str(self._storage), self._is_read_only)
             return 1
@@ -406,6 +469,12 @@
         if not conn.is_async():
             self.logger.warn("Waiting for cache verification to finish")
             self._wait_sync()
+        self._handle_extensions()
+
+    def _handle_extensions(self):
+        for name in self.getExtensionMethods().keys():
+            if not hasattr(self, name):
+                setattr(self, name, self._server.extensionMethod(name))
 
     def update_interfaces(self):
         # Update what interfaces the instance provides based on the server.
@@ -600,12 +669,6 @@
         """
         return self._server.history(oid, version, length)
 
-    def __getattr__(self, name):
-        if self.getExtensionMethods().has_key(name):
-            return self._server.extensionMethod(name)
-        else:
-            raise AttributeError(name)
-
     def loadSerial(self, oid, serial):
         """Storage API: load a historical revision of an object."""
         return self._server.loadSerial(oid, serial)
@@ -621,14 +684,39 @@
         specified by the given object id and version, if they exist;
         otherwise a KeyError is raised.
         """
-        p = self._cache.load(oid, version)
-        if p:
-            return p
+        self._lock.acquire()    # for atomic processing of invalidations
+        try:
+            pair = self._cache.load(oid, version)
+            if pair:
+                return pair
+        finally:
+            self._lock.release()
+
         if self._server is None:
             raise ClientDisconnected()
-        p, s, v, pv, sv = self._server.zeoLoad(oid)
-        self._cache.checkSize(0)
-        self._cache.store(oid, p, s, v, pv, sv)
+
+        self._load_lock.acquire()
+        try:
+            self._lock.acquire()
+            try:
+                self._load_oid = oid
+                self._load_status = 1
+            finally:
+                self._lock.release()
+
+            p, s, v, pv, sv = self._server.zeoLoad(oid)
+
+            self._lock.acquire()    # for atomic processing of invalidations
+            try:
+                if self._load_status:
+                    self._cache.checkSize(0)
+                    self._cache.store(oid, p, s, v, pv, sv)
+                self._load_oid = None
+            finally:
+                self._lock.release()
+        finally:
+            self._load_lock.release()
+
         if v and version and v == version:
             return pv, sv
         else:
@@ -641,9 +729,13 @@
 
         If no version modified the object, return an empty string.
         """
-        v = self._cache.modifiedInVersion(oid)
-        if v is not None:
-            return v
+        self._lock.acquire()
+        try:
+            v = self._cache.modifiedInVersion(oid)
+            if v is not None:
+                return v
+        finally:
+            self._lock.release()
         return self._server.modifiedInVersion(oid)
 
     def newObjectId(self):
@@ -740,6 +832,7 @@
 
         self._serial = id
         self._seriald.clear()
+        self._tbuf.clear()
         del self._serials[:]
 
     def end_transaction(self):
@@ -779,18 +872,23 @@
         """Storage API: finish a transaction."""
         if transaction is not self._transaction:
             return
+        self._load_lock.acquire()
         try:
-            if f is not None:
-                f()
+            self._lock.acquire()  # for atomic processing of invalidations
+            try:
+                self._update_cache()
+                if f is not None:
+                    f()
+            finally:
+                self._lock.release()
 
             tid = self._server.tpcFinish(self._serial)
+            self._cache.setLastTid(tid)
 
             r = self._check_serials()
             assert r is None or len(r) == 0, "unhandled serialnos: %s" % r
-
-            self._update_cache()
-            self._cache.setLastTid(tid)
         finally:
+            self._load_lock.release()
             self.end_transaction()
 
     def _update_cache(self):
@@ -799,6 +897,13 @@
         This iterates over the objects in the transaction buffer and
         update or invalidate the cache.
         """
+        # Must be called with _lock already acquired.
+
+        # XXX not sure why _update_cache() would be called on
+        # a closed storage.
+        if self._cache is None:
+            return
+
         self._cache.checkSize(self._tbuf.get_size())
         try:
             self._tbuf.begin_iterate()
@@ -892,15 +997,21 @@
         # oid, version pairs.  The DB's invalidate() method expects a
         # dictionary of oids.
 
-        # versions maps version names to dictionary of invalidations
-        versions = {}
-        for oid, version in invs:
-            d = versions.setdefault(version, {})
-            self._cache.invalidate(oid, version=version)
-            d[oid] = 1
-        if self._db is not None:
-            for v, d in versions.items():
-                self._db.invalidate(d, version=v)
+        self._lock.acquire()
+        try:
+            # versions maps version names to dictionary of invalidations
+            versions = {}
+            for oid, version in invs:
+                if oid == self._load_oid:
+                    self._load_status = 0
+                self._cache.invalidate(oid, version=version)
+                versions.setdefault(version, {})[oid] = 1
+
+            if self._db is not None:
+                for v, d in versions.items():
+                    self._db.invalidate(d, version=v)
+        finally:
+            self._lock.release()
 
     def endVerify(self):
         """Server callback to signal end of cache validation."""
@@ -928,7 +1039,7 @@
             self.logger.debug(
                 "Transactional invalidation during cache verification")
             for t in args:
-                self.self._pickler.dump(t)
+                self._pickler.dump(t)
             return
         self._process_invalidations(args)
 




More information about the Zodb-checkins mailing list