# Source code for instatools.cache
"""Response caching for testing integration with instatools"""
import os
import pickle
import re
import sqlite3
from contextlib import contextmanager
import requests
import instatools.api
import instatools.instagram.feeds
import instatools.session
def clear(data_dir):
    """Delete the cached-response database in *data_dir*, if present.

    Args:
        data_dir: Directory containing ``instagram_cache.db``.
    """
    db_path = os.path.join(data_dir, 'instagram_cache.db')
    # EAFP: remove directly and ignore a missing file, avoiding the
    # exists()/remove() race of the check-then-act version.
    try:
        os.remove(db_path)
    except FileNotFoundError:
        pass
@contextmanager
def read(data_dir):
    """Context manager that replays cached API responses.

    Patches the session class's ``request`` method so responses recorded
    earlier are served from ``instagram_cache.db`` instead of the
    network, and zeroes the inter-request sleeps for fast tests.

    Args:
        data_dir: Directory holding ``instagram_cache.db``.

    Raises:
        FileNotFoundError: If no cache database has been recorded yet.
    """
    db_path = os.path.join(data_dir, 'instagram_cache.db')
    if not os.path.exists(data_dir) or not os.path.exists(db_path):
        raise FileNotFoundError(
            'You must record a request before reading from the cache')
    old_request = instatools.session.Session._session_class.request
    instatools.session.Session._session_class.request = \
        _handle_request(data_dir, old_request)
    instatools.api.sleep_between_pages = 0
    instatools.instagram.feeds.FeedReader._sleep_between_reads = 0
    try:
        yield
    finally:
        # Always restore the original request method, even when the
        # body of the ``with`` block raises — otherwise the patch
        # leaks into every subsequent request.
        instatools.session.Session._session_class.request = old_request
@contextmanager
def record(data_dir):
    """Context manager that records live API responses into the cache.

    Wraps the underlying ``requests.Session.request`` so every response
    is pickled into ``instagram_cache.db`` keyed by URL, and zeroes the
    inter-request sleeps.

    Args:
        data_dir: Directory for the database; created if missing.
    """
    # exist_ok avoids the exists()/mkdir() race and, unlike os.mkdir,
    # also creates intermediate directories.
    os.makedirs(data_dir, exist_ok=True)
    cache = DataBaseCache(os.path.join(data_dir, 'instagram_cache.db'))
    old_request = requests.Session.request

    def _request(self, method, url, **kwargs):
        # Perform the real request, then persist the pickled response.
        resp = old_request(self, method, url, **kwargs)
        cache.set(url, pickle.dumps(resp))
        return resp

    instatools.session.Session._session_class.request = _request
    instatools.api.sleep_between_pages = 0
    instatools.instagram.feeds.FeedReader._sleep_between_reads = 0
    try:
        yield
    finally:
        # Always restore the original request method, even when the
        # body of the ``with`` block raises.
        instatools.session.Session._session_class.request = old_request
class DataBaseCache(object):
    """Simple key/value cache backed by a SQLite database.

    All rows are mirrored into an in-memory dict on construction, so
    reads never touch the database after ``__init__``.

    Args:
        db_path: Path to the SQLite file (created if missing).
        default_factories: Optional mapping of regex pattern -> value or
            zero-argument callable, consulted by :meth:`get` when no
            stored key matches.
    """

    def __init__(self, db_path, default_factories=None):
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute('create table if not exists cache(key, value)')
        self.conn.commit()
        self._cache = {}
        self._default_factories = default_factories or {}
        # Mirror existing rows; if legacy duplicate rows exist, the
        # last one read wins.
        for k, v in self.cursor.execute('select * from cache').fetchall():
            self._cache[k] = v

    def __del__(self):
        # Best-effort flush and close; ``conn`` may be missing if
        # __init__ failed part-way, and __del__ must never raise.
        try:
            self.conn.commit()
            self.conn.close()
        except Exception:
            pass

    def get(self, key):
        """Return the value for *key*.

        Lookup order: exact key, then the first stored key that starts
        with *key* (prefix match), then the first default factory whose
        pattern matches *key*. Returns ``None`` when nothing matches.
        """
        result = self._cache.get(key, None)
        if result is not None:
            return result
        for k, v in self._cache.items():
            if k.startswith(key):
                return v
        for pattern, factory in self._default_factories.items():
            if re.search(pattern, key) is not None:
                return factory() if callable(factory) else factory
        return None

    def set(self, key, value):
        """Store *value* under *key*, replacing any previous value.

        Deletes existing rows first: a blind INSERT accumulates one
        duplicate row per call, growing the database without bound.
        """
        self.cursor.execute('delete from cache where key=?', (key,))
        self.cursor.execute('insert into cache values (?, ?)', (key, value))
        self.conn.commit()
        self._cache[key] = value

    def delete(self, key):
        """Remove *key*; raises ``KeyError`` if it was never stored."""
        self.cursor.execute('delete from cache where key=?', (key,))
        self.conn.commit()
        del self._cache[key]

    def clear(self):
        """Drop every entry from the database and the in-memory mirror."""
        self.cursor.execute('drop table cache')
        self.cursor.execute('create table cache(key, value)')
        self.conn.commit()
        # Bug fix: without this, get() keeps serving stale values from
        # the in-memory mirror after a clear().
        self._cache.clear()
def _handle_request(data_dir, request_method):
    """Build a replacement ``request`` method that replays cached responses.

    Args:
        data_dir: Directory holding ``instagram_cache.db``.
        request_method: Original bound-style request function, used for
            any URL that is not an Instagram API endpoint.

    Returns:
        A function with the ``Session.request`` signature.
    """
    # Canned 200/"ok" response served for any API URL absent from the cache
    # (the catch-all '.*' default factory below).
    fallback = requests.Response()
    fallback.status_code = 200
    fallback._content = b'{"status":"ok"}'
    db_path = os.path.join(data_dir, 'instagram_cache.db')
    cache = DataBaseCache(
        db_path, default_factories={'.*': pickle.dumps(fallback)})

    def _request(self, method, url, **kwargs):
        if 'i.instagram.com/api' not in url:
            # Non-API traffic goes straight to the real implementation.
            return request_method(self, method, url, **kwargs)
        pickled = cache.get(url)
        if pickled is None:
            return request_method(self, method, url, **kwargs)
        response = pickle.loads(pickled)
        # Callers expect a CSRF token cookie on every response.
        response.cookies.update({'csrftoken': 'token'})
        return response

    return _request