diff --git a/src/dyns/controllers.py b/src/dyns/controllers.py index bfeb98b..72ca2ef 100644 --- a/src/dyns/controllers.py +++ b/src/dyns/controllers.py @@ -3,91 +3,161 @@ import json import milla +class BaseController(milla.controllers.Controller): + + def __init__(self): + # allowed_methods must be set on the instance rather than the + # class because of how Milla does attribute copying to the + # partial the router creates. + try: + self.allowed_methods = self.__class__.allowed_methods + except AttributeError: + pass + + def __before__(self, request): + super(BaseController, self).__before__(request) + self.session = model.Session() + + def __after__(self, request): + self.session.rollback() + self.session.bind.dispose() + del self.session + + def index(request): raise milla.HTTPMovedPermanently(location=request.create_href('/zones/')) -@milla.allow('GET', 'HEAD', 'POST') -def all_zones(request): - response = request.ResponseClass() - response.content_type = 'application/json' +class ZoneListController(BaseController): - session = model.Session() - if request.method == 'GET': - zones = map(model.Zone.as_dict, session.query(model.Zone)) + allowed_methods = ('GET', 'HEAD', 'POST') + + def __call__(self, request): + return getattr(self, request.method)(request) + + def GET(self, request): + response = request.ResponseClass() + response.content_type = 'application/json' + zones = map(model.Zone.as_dict, self.session.query(model.Zone)) json.dump(list(zones), response.body_file) - session.rollback() - elif request.method == 'POST': + return response + + HEAD = GET + + def POST(self, request): + response = request.ResponseClass() + response.content_type = None data = json.loads(request.text) zone = model.Zone(**data) - session.add(zone) - session.commit() + self.session.add(zone) + self.session.commit() response.status_int = 201 - return response + response.location = request.create_href_full( + '/zones/{}'.format(zone.name) + ) + return response -@milla.allow('GET', 'HEAD', 'POST', 'PUT', 'DELETE') -def zone(request, name): - response = request.ResponseClass() - response.content_type = 'application/json' +class ZoneController(BaseController): - session = model.Session() - zone = session.query(model.Zone).get(name) - if not zone: - raise milla.HTTPNotFound + allowed_methods = ('GET', 'HEAD', 'POST', 'PUT', 'DELETE') - if request.method == 'GET': + def __call__(self, request, name): + return getattr(self, request.method)(request, name) + + def get_zone(self, name): + zone = self.session.query(model.Zone).get(name) + if not zone: + raise milla.HTTPNotFound + return zone + + def GET(self, request, name): + response = request.ResponseClass() + response.content_type = 'application/json' + zone = self.get_zone(name) zone_d = zone.as_dict() zone_d['records'] = list(map(model.Record.as_dict, zone.records)) json.dump(zone_d, response.body_file) - session.rollback() - elif request.method == 'PUT': + return response + + HEAD = GET + + def PUT(self, request, name): + response = request.ResponseClass() + response.content_type = 'application/json' data = json.loads(request.text) for k, v in data.items(): assert k != 'records' assert hasattr(zone, k) setattr(zone, k, v) - session.commit() - elif request.method == 'POST': + self.session.commit() + return response + + def POST(self, request, name): + response = request.ResponseClass() + response.content_type = None + zone = self.get_zone(name) data = json.loads(request.text) zone_name = data.pop('zone', zone.name) assert zone_name == zone.name record = model.Record(zone=zone_name, **data) - session.add(record) - session.commit() + self.session.add(record) + self.session.commit() response.status_int = 201 response.location = request.create_href_full( '/records/{}'.format(record.id) ) - elif request.method == 'DELETE': - session.delete(zone) - session.commit() - else: - session.rollback() - return response + return response -@milla.allow('GET', 'HEAD', 'PUT', 'DELETE') -def record(request, id): - response = request.ResponseClass() - response.content_type = 'application/json' + def DELETE(self, request, name): + response = request.ResponseClass() + response.content_type = None + zone = self.get_zone(name) + self.session.delete(zone) + self.session.commit() + response.status_int = 204 + return response - session = model.Session() - record = session.query(model.Record).get(id) - if not record: - raise milla.HTTPNotFound - if request.method == 'GET': +class RecordController(BaseController): + + allowed_methods = ('GET', 'HEAD', 'PUT', 'DELETE') + + def __call__(self, request, id): + return getattr(self, request.method)(request, id) + + def get_record(self, id): + record = self.session.query(model.Record).get(id) + if not record: + raise milla.HTTPNotFound + return record + + def GET(self, request, id): + response = request.ResponseClass() + response.content_type = 'application/json' + record = self.get_record(id) json.dump(record.as_dict(), response.body_file) - session.rollback() - elif request.method == 'DELETE': - session.delete(record) - session.commit() - elif request.method == 'PUT': - data =json.loads(request.text) + return response + + HEAD = GET + + def PUT(self, request, id): + response = request.ResponseClass() + response.content_type = 'application/json' + record = self.get_record(id) + data = json.loads(request.text) for k, v in data.items(): + assert k != 'zone' assert hasattr(record, k) setattr(record, k, v) - session.commit() - else: - session.rollback() - return response + self.session.commit() + return response + + def DELETE(self, request, id): + response = request.ResponseClass() + response.content_type = None + record = self.get_record(id) + self.session.delete(record) + self.session.commit() + response.status_int = 204 + return response diff --git a/src/dyns/routes.py b/src/dyns/routes.py index fb04ecf..bc9233e 100644 --- a/src/dyns/routes.py +++ b/src/dyns/routes.py @@ -4,6 +4,6 @@ from milla.dispatch import routing router = routing.Router() router.add_route('/', controllers.index) -router.add_route('/zones/', controllers.all_zones) -router.add_route('/zones/{name}', controllers.zone) -router.add_route('/records/{id}', controllers.record) +router.add_route('/zones/', controllers.ZoneListController()) +router.add_route('/zones/{name}', controllers.ZoneController()) +router.add_route('/records/{id}', controllers.RecordController())