# git.fiddlerwoaroof.com
# Raw Blame History
def force_unicode(blarg):
	"""Coerce *blarg* to a text (unicode) string.

	Byte strings are decoded as UTF-8; every other object is converted with
	the interpreter's text type (``unicode`` on Python 2, ``str`` on
	Python 3), so text input is returned unchanged.

	Raises:
		UnicodeDecodeError: if a byte string is not valid UTF-8.
	"""
	# ``bytes`` is an alias of ``str`` on Python 2, so this one check covers
	# the byte-string type of both major versions.
	if isinstance(blarg, bytes):
		return blarg.decode('utf-8')

	try:
		text_type = unicode  # Python 2
	except NameError:
		text_type = str  # Python 3 has no ``unicode`` builtin
	return text_type(blarg)

class Tag(object):
	"""A minimal markup element builder.

	A tag has a name (``tag``), a list of ``contents`` (text and/or nested
	``Tag`` instances) and a dict of ``attributes``.  ``unicode(tag)`` renders
	it with ``template``; ``str(tag)`` returns the same markup encoded as
	UTF-8.

	NOTE(review): text contents are emitted verbatim and attribute values only
	get their double quotes backslash-escaped (see ``quote_value``).  Nothing
	is entity-escaped, so untrusted input must be escaped by the caller.
	"""

	template = u'<%(tagname)s%(attributes)s>%(contents)s</%(tagname)s>'
	attribute_template = u'%(attrname)s="%(attrvalue)s"'

	def __init__(self, tag, contents=None, attributes=None):
		"""Create a tag.

		tag -- the element name.
		contents -- one item, or an iterable of items (text or ``Tag``
			instances).  A value without ``__iter__`` is treated as a single
			item.  Byte strings are decoded to unicode.
		attributes -- mapping of attribute names to values; keys and values
			are both converted to unicode and stored in a fresh dict.
		"""
		self.children = []  # not read anywhere in this class
		self.tag = tag

		if contents is None: contents = []
		elif not hasattr(contents, '__iter__'):
			contents = [contents]
		self.contents = [(force_unicode(item) if isinstance(item, str) else item) for item in contents]

		if attributes is None: attributes = {}
		self.attributes = {force_unicode(key):force_unicode(attr) for key, attr in attributes.items()}

	def _matches(self, tag, attributes):
		"""Return True if this tag is named *tag* and carries every attribute in *attributes*."""
		if self.tag != tag:
			return False

		missing = object()
		for key in attributes:
			actual = self.attributes.get(key, missing)
			# A tag that lacks the attribute is simply not a match.
			if actual is missing:
				return False
			# Compare as text: values set through ``attr()`` are stored
			# un-normalised (e.g. the int 20), while values passed to the
			# constructor are already unicode.
			if force_unicode(actual) != force_unicode(attributes[key]):
				return False
		return True

	def navigate(self, *args):
		"""Return the tags reached by following a path of selectors.

		Each positional argument selects among the direct child tags of the
		tags matched by the previous selector.  A selector is either a tag
		name or a ``(tagname, attributes)`` pair; with a pair, every listed
		attribute must be present on the child with an equal value (compared
		as text).  With no arguments the result is ``[self]``.
		"""
		if not args:
			return [self]

		head, tail = args[0], args[1:]
		if hasattr(head, '__iter__'):
			wanted_tag, wanted_attributes = head
		else:
			wanted_tag, wanted_attributes = head, {}

		results = []
		for el in self.contents:
			# Only nested tags can match; text items are skipped.
			if isinstance(el, self.__class__) and el._matches(wanted_tag, wanted_attributes):
				results.extend(el.navigate(*tail))
		return [found for found in results if found]

	def quote_value(self, value):
		"""Return *value* as text with each ``"`` replaced by ``\\"``.

		Non-string values are converted with ``unicode`` first.

		NOTE(review): a backslash is not an escape character in HTML attribute
		values (``&quot;`` would be); the behaviour is kept because
		``test_quoteattrvalue`` in this file pins it.
		"""
		if not isinstance(value, (str, unicode)):
			value = unicode(value)
		return value.replace('"', '\\"')

	def _make_attributes(self):
		"""Render the attributes as `` name="value" ...`` (leading space), or ``u''`` if there are none."""
		if not self.attributes:
			return u''

		rendered = []
		for attr, value in self.attributes.items():
			rendered.append(self.attribute_template % dict(attrname=attr, attrvalue=self.quote_value(value)))
		return u' %s' % u' '.join(rendered)

	def _make_contents(self):
		"""Render every content item to unicode and concatenate the results."""
		return u''.join([force_unicode(item) for item in self.contents])

	def __repr__(self):
		return '<Tag %r attrs=%r>' % (self.tag, self.attributes)

	def __str__(self):
		"""Return the rendered markup encoded as UTF-8."""
		return unicode(self).encode('utf-8')

	def __unicode__(self):
		"""Return the rendered markup as unicode."""
		attributes = self._make_attributes()
		contents = self._make_contents()
		return self.template % dict(tagname=self.tag, contents=contents, attributes=attributes)

	def _add(self, item):
		"""Return a new ``Tag`` equal to this one with *item* appended to its contents."""
		return Tag(self.tag, self.contents+[item], self.attributes)

	def __rshift__(self, other):
		"""``self >> other``: return a new tag with *other* nested inside.

		If *other* has ``__iter__`` each of its items is appended in order.
		This tag itself is not modified.
		"""
		result = self
		if not hasattr(other, '__iter__'):
			other = [other]

		for i in other:
			result = result._add(i)

		return result

	def __rlshift__(self, other):
		"""``other << self`` when *other* cannot handle it itself (e.g. text).

		Returns a new tag with *other* appended to this tag's contents.
		"""
		return self._add(other)

	def __lshift__(self, other):
		"""``self << other``: return a copy of *other* with this tag nested inside.

		Raises:
			TypeError: if *other* has no ``_add`` method (e.g. plain text).
		"""
		if hasattr(other, '_add'):
			return other._add(self)
		else:
			raise TypeError('Cannot surround a tag with text')

	def __enter__(self, *a, **kw):
		"""Allow ``with tag as t:`` purely as a visual nesting aid; returns self."""
		return self

	def __exit__(self, *a, **kw):
		"""Do nothing on exit; exceptions raised in the block propagate."""
		pass

	def child(self, tag, **attrs):
		"""Append a new empty child tag in place and return it.

		Keyword names are lower-cased, so ``Class='x'`` yields a ``class``
		attribute.
		"""
		attrs = { key.lower(): value for (key,value) in attrs.items() }
		self.contents.append(Tag(tag, None, attrs))
		return self.contents[-1]

	def text(self, text):
		"""Append *text* to the contents in place and return self."""
		self.contents.append(text)
		return self

	def attr(self, attr, value):
		"""Set attribute *attr* to *value* in place (stored as given) and return self."""
		self.attributes[attr] = value
		return self

	# ``pretty`` exists only when the optional BeautifulSoup package can be
	# imported.  The import sits in the class body, which binds the imported
	# class as the ``BeautifulSoup`` class attribute used below.
	try:
		from BeautifulSoup import BeautifulSoup
		def pretty(self):
			"""Return the rendered markup passed through BeautifulSoup's ``prettify()``."""
			return self.BeautifulSoup(str(self)).prettify()
	except ImportError:
		pass


import unittest

class TagTest(unittest.TestCase):
	"""Unit tests for ``Tag``: quoting, rendering, composition operators and navigation."""

	def test_quoteattrvalue(self):
		a = Tag('test')
		self.assertEqual(a.quote_value('asd"asd'), 'asd\\"asd')
		self.assertEqual(a.quote_value('asd"a"sd'), 'asd\\"a\\"sd')

	def test_print1(self):
		a = Tag('test')
		self.assertEqual(str(a), '<test></test>')

	def test_print2(self):
		a = Tag('test', 'asd')
		self.assertEqual(str(a), '<test>asd</test>')

	def test_print3(self):
		a = Tag('test', None, {'a':1, 'b':'a'})
		self.assertEqual(str(a), '<test a="1" b="a"></test>')

	def test_print4(self):
		a = Tag('test', 'asd', {'a':1, 'b':'a'})
		self.assertEqual(str(a), '<test a="1" b="a">asd</test>')

	# Previously also named ``test_print4``, which shadowed the test above so
	# that it never ran.
	def test_print4_nested(self):
		a = Tag('test', ['asd', Tag('test1')], {'a':1, 'b':'a'})
		self.assertEqual(str(a), '<test a="1" b="a">asd<test1></test1></test>')

	def test_print5(self):
		test = str(Tag('a') >> [Tag('b') >> Tag('c'), Tag('d') >> [Tag('e') >> 'asd']])
		model = '<a><b><c></c></b><d><e>asd</e></d></a>'
		self.assertEqual(test, model)

	def test_contextmanager1(self):
		with Tag('a') as a:
			with a.child('b') as b:
				b.attr('height', 20)
				b.text('asd')
			with a.child('c'): pass
			with a.child('d', Class='dog') as d: pass
		self.assertEqual(str(a), '<a><b height="20">asd</b><c></c><d class="dog"></d></a>')

	def test_compose1(self):
		a = Tag('test')
		b = Tag('test1')
		self.assertEqual(str(a >> b), '<test><test1></test1></test>')

	def test_compose2(self):
		a = Tag('test')
		self.assertEqual(str(a >> 'asd'), '<test>asd</test>')

	def test_compose3(self):
		a = Tag('test')
		b = Tag('test1')
		self.assertEqual(str(a << b), '<test1><test></test></test1>')

	def test_compose4(self):
		a = Tag('test')
		self.assertEqual(str('asd' << a), '<test>asd</test>')

	# Previously also named ``test_compose4``, which shadowed the test above
	# so that it never ran.
	def test_compose5(self):
		a = Tag('test')
		self.assertRaises(TypeError, lambda: str(a << 'asd'))

	def test_unicode(self):
		# Non-ASCII tag name (U+00FC U+00F8 U+0153 U+00E5), spelled with
		# escapes so this source stays pure ASCII.  The test passes if
		# neither conversion raises.
		a = Tag(u'\xfc\xf8\u0153\xe5')
		unicode(a)
		str(a)

	def nav_tests(assertion):
		"""Build a test method that creates a fixed tag tree and hands it to *assertion*.

		*assertion* is called as ``assertion(self, a, b, c, d, e, f, g, h)``
		with the test case and the tags of the tree built below.
		"""
		def test_navigate(self):
			with Tag('a') as a:
				with a.child('b') as b:
					b.attr('height', 20)
					b.text('asd')
				with a.child('b') as c:
					with c.child('b') as e: pass
					with c.child('d', height=20) as f: pass
					with c.child('d', height=50) as g: pass
					with c.child('d', height=20, width=10) as h: pass
				with a.child('d', Class='dog') as d: pass

			assertion(self, a,b,c,d,e,f,g,h)
		return test_navigate

	test_navigate1 = nav_tests(lambda self, a,b,c,d,e,f,g,h: self.assertEqual(a.navigate('b'), [b,c]))
	test_navigate2 = nav_tests(lambda self, a,b,c,d,e,f,g,h: self.assertEqual(a.navigate('b', 'b'), [e]))
	test_navigate3 = nav_tests(lambda self, a,b,c,d,e,f,g,h: self.assertEqual(a.navigate('b', 'd'), [f, g, h]))
	test_navigate4 = nav_tests(lambda self, a,b,c,d,e,f,g,h: self.assertEqual(a.navigate('b', ['d', {'height':20}]), [f,h]))




if __name__ == '__main__':
	# Run this module's unit tests when it is executed as a script.
	unittest.main()
	# Disabled usage example: build a <dl> from input lines of the form
	# "term | definition" (one <dt>/<dd> pair per line) and print it, wrapped
	# in <html><body>, via Tag.pretty().
	#import fileinput
	#result = Tag('dl')
	#for line in fileinput.input():
	#	line = line.strip().partition(' | ')
	#	result.child('dt').text(line[0])
	#	result.child('dd').text(line[-1])
	#print (Tag('html') >> Tag('body') >> result).pretty()