Files
kiwi-budapest/regiojet.py
2022-03-06 10:34:18 +01:00

251 lines
7.9 KiB
Python

import requests
import json
import pprint
import argparse
from datetime import date, datetime, time, timedelta
import redis
from slugify import slugify
from typing import Optional
from redis import Redis
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Sequence, Column, Integer, String, TEXT, FLOAT
from sqlalchemy.dialects.postgresql import TIMESTAMP
# Base URL of the REST API queried by this script; the paths below are appended to it.
url = "https://brn-ybus-pubapi.sa.cz"
# Endpoint paths: tariff constants, location constants, and the simple route search.
tarrifs_path = "/restapi/consts/tariffs"
location_path = "/restapi/consts/locations"
route_path = "/restapi/routes/search/simple"
# Host of the Redis instance used to cache API responses.
redis_host = "redis.pythonweekend.skypicker.com"
# Alternative host for a locally running Redis:
#redis_host = "localhost"
def store_dict_in_redis(redis: Redis, key: str, value: dict) -> None:
    """Serialize *value* to JSON and write it to Redis under *key*.

    Objects the default JSON encoder cannot handle are converted by
    ``json_serial``.
    """
    payload = json.dumps(value, default=json_serial)
    redis.set(key, payload)
def retrieve_dict(redis: Redis, key: str) -> Optional[dict]:
    """Return the JSON-decoded value stored under *key*, or ``None`` if the key is absent."""
    raw = redis.get(name=key)
    return None if raw is None else json.loads(raw)
def json_serial(obj):
    """Fallback encoder for ``json.dumps``: handles dates, datetimes and timedeltas.

    Dates and datetimes become ISO 8601 strings, timedeltas become their
    ``str()`` form. Any other type raises ``TypeError``.
    """
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    if isinstance(obj, timedelta):
        return str(obj)
    raise TypeError("Type %s not serializable" % type(obj))
def search_locations(country, city):
    """Look up a city record in the module-level ``locations`` list.

    Returns the first city dict whose ``name`` equals *city* within a
    location whose ``country`` equals *country*, or ``None`` when nothing
    matches.
    """
    matches = (
        candidate
        for location in locations
        if country == location['country']
        for candidate in location['cities']
        if city == candidate['name']
    )
    return next(matches, None)
def _route_to_journey(route, from_station, to_station):
    """Convert one API route record into the journey dict shape used by this script."""
    vehicle_types = route["vehicleTypes"]
    return {
        "departure_datetime": datetime.fromisoformat(route["departureTime"]),
        "arrival_datetime": datetime.fromisoformat(route["arrivalTime"]),
        "source": from_station["name"],
        "destination": to_station["name"],
        "source_id": from_station["id"],
        "destination_id": to_station["id"],
        "free_seats": route["freeSeatsCount"],
        "carrier": "REGIOJET",
        # Tolerate a route without vehicle types instead of raising IndexError.
        "type": vehicle_types[0] if vehicle_types else None,
        "fare": {"amount": route["priceFrom"], "currency": "EUR"},
    }


def search_connection_regiojet(from_station, to_station, tariff_type, to_location_type, from_location_type, departure):
    """Return journeys between two stations, using Redis, then the database, then the API.

    Args:
        from_station: dict with at least ``"name"`` and ``"id"`` for the origin.
        to_station: dict with at least ``"name"`` and ``"id"`` for the destination.
        tariff_type: tariff identifier sent to the route search endpoint.
        to_location_type: location type of the destination sent to the endpoint.
        from_location_type: location type of the origin sent to the endpoint.
        departure: departure date; used in the cache key and sent as ``departureDate``.

    Returns:
        A list of journey dicts. On a Redis hit this is the JSON-decoded cached
        list, on a database hit the rows as produced by ``get_from_db``,
        otherwise the freshly fetched and converted API routes.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.Timeout: if the API does not answer in time.
    """
    surname = "jaro"
    source = from_station["name"]
    destination = to_station["name"]
    key = f"{surname}:journey:{slugify(source)}_{slugify(destination)}_{departure}"

    journey = retrieve_dict(redisdb, key)
    if journey is not None:
        # Redis hit: make sure the database has a copy as well.
        if get_from_db(source, destination, departure) is None:
            store_in_db(journey)
        return journey

    journey = get_from_db(source, destination, departure)
    if journey is not None:
        return journey

    # NOTE(review): confirm "tariffs_type" is the query parameter name the API expects.
    response = requests.get(
        url + route_path,
        params={
            "tariffs_type": tariff_type,
            "toLocationId": to_station["id"],
            "fromLocationId": from_station["id"],
            "fromLocationType": from_location_type,
            "toLocationType": to_location_type,
            "departureDate": departure,
        },
        # Without a timeout a stalled connection would block the script forever.
        timeout=30,
    )
    # Fail loudly on an error status rather than trying to parse an error body.
    response.raise_for_status()
    routes = json.loads(response.content)

    routes_ret = [
        _route_to_journey(route, from_station, to_station)
        for route in routes['routes']
    ]
    store_dict_in_redis(redisdb, key, routes_ret)
    store_in_db(routes_ret)
    return routes_ret
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Sequence, Column, Integer, String, TEXT, FLOAT
from sqlalchemy.dialects.postgresql import TIMESTAMP
# Declarative base class for the ORM models in this file; its metadata is what
# create_all() uses to create the tables.
Base = declarative_base()
class Journey(Base):
    """ORM model for one journey stored in the ``journeys_jaro`` table."""

    # name of the table
    __tablename__ = "journeys_jaro"
    # integer primary key
    id = Column(Integer, primary_key=True)
    # origin and destination names
    source = Column(TEXT)
    destination = Column(TEXT)
    # departure and arrival timestamps
    departure_datetime = Column(TIMESTAMP)
    arrival_datetime = Column(TIMESTAMP)
    carrier = Column(TEXT)
    vehicle_type = Column(TEXT)
    # fare amount and its currency code (at most 3 characters)
    price = Column(FLOAT)
    currency = Column(String(3))
def transform_data(data):
    """Convert a database row mapping into the journey dict shape returned to callers.

    Args:
        data: mapping with the ``Journey`` column names as keys (for example a
            row object's ``__dict__``). Extra keys are ignored.

    Returns:
        A dict with ``source``, ``destination``, ``departure_datetime``,
        ``arrival_datetime``, ``carrier``, ``type`` and a nested ``fare`` dict
        holding ``amount`` and ``currency``.

    Raises:
        KeyError: if one of the expected column keys is missing.
    """
    # The earlier debug dump of every row was removed: it wrote to stdout and
    # got mixed into the JSON the script prints as its result.
    return {
        "source": data["source"],
        "destination": data["destination"],
        "departure_datetime": data["departure_datetime"],
        "arrival_datetime": data["arrival_datetime"],
        "carrier": data["carrier"],
        "type": data["vehicle_type"],
        "fare": {
            "amount": data["price"],
            "currency": data["currency"],
        },
    }
def _departure_day_bounds(departure):
    """Return ``(start, end)`` datetimes spanning the calendar day of *departure*, or ``None`` if unparseable."""
    if isinstance(departure, datetime):
        day = departure.date()
    elif isinstance(departure, date):
        day = departure
    else:
        try:
            day = datetime.fromisoformat(str(departure)).date()
        except ValueError:
            return None
    start = datetime(day.year, day.month, day.day)
    return start, start + timedelta(days=1)


def get_from_db(source, destination, departure):
    """Return stored journeys for a route on the given departure day.

    The original equality test against ``departure`` only matched journeys
    leaving exactly at midnight when given a plain date, so the stored rows
    were never found. Journeys are now matched when their departure timestamp
    falls anywhere within the calendar day of *departure*.

    Args:
        source: origin name as stored in ``Journey.source``.
        destination: destination name as stored in ``Journey.destination``.
        departure: a ``date``, ``datetime`` or ISO-format string. If it cannot
            be parsed, the filter falls back to plain equality.

    Returns:
        A non-empty list of journey dicts (see ``transform_data``), or ``None``
        when no row matches.
    """
    bounds = _departure_day_bounds(departure)
    if bounds is None:
        # Unparseable value: keep the plain equality comparison.
        departure_filter = (Journey.departure_datetime == departure,)
    else:
        day_start, day_end = bounds
        departure_filter = (
            Journey.departure_datetime >= day_start,
            Journey.departure_datetime < day_end,
        )

    Session = sessionmaker(engine)
    with Session() as session:
        rows = session.query(Journey).filter(
            Journey.source == source,
            Journey.destination == destination,
            *departure_filter,
        ).all()
        # Convert while the session is still open.
        cached_data = [transform_data(row.__dict__) for row in rows]
    return cached_data or None
def store_in_db(journeys):
    """Persist journeys to the database, skipping ones that are already stored.

    A journey counts as already stored when a row with the same source,
    destination, departure and arrival exists. The original returned from
    inside the loop on the first such duplicate, which silently dropped every
    journey after it; duplicates are now skipped and the loop continues.

    Args:
        journeys: iterable of journey dicts with ``source``, ``destination``,
            ``departure_datetime``, ``arrival_datetime``, ``carrier``, ``type``
            and a ``fare`` dict holding ``amount`` and ``currency``.

    Returns:
        None.
    """
    Session = sessionmaker(engine)
    # One session for the whole batch; the connection is opened and closed automatically.
    with Session() as session:
        for item in journeys:
            data = {
                "source": item["source"],
                "destination": item["destination"],
                "departure_datetime": item["departure_datetime"],
                "arrival_datetime": item["arrival_datetime"],
                "carrier": item["carrier"],
                "vehicle_type": item["type"],
                "price": item["fare"]["amount"],
                "currency": item["fare"]["currency"],
            }
            existing = session.query(Journey.id).filter(
                Journey.source == data["source"],
                Journey.destination == data["destination"],
                Journey.departure_datetime == data["departure_datetime"],
                Journey.arrival_datetime == data["arrival_datetime"],
            ).first()
            if existing is not None:
                continue
            # Pending rows are flushed before the next query, so duplicates
            # inside the same batch are detected too.
            session.add(Journey(**data))
        session.commit()
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
import os

# Connection string for the journeys database. It can be overridden with the
# DATABASE_URL environment variable so credentials need not live in the source.
# NOTE(review): the fallback below embeds a password; it is kept only so
# existing runs keep working and should be rotated and removed.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql://jaroslav_drzik:7290b08ac9ca41dba97b02e356fda738@sql.pythonweekend.skypicker.com/pythonweekend"
    "?application_name=jaroslav_drzik_local_dev",
)
# echo=True shows debug information
# NullPool closes unused connections immediately
engine = create_engine(
    DATABASE_URL,
    echo=True,
    poolclass=NullPool
)
# Create any tables declared on Base that do not exist yet.
Base.metadata.create_all(engine)
# Command-line interface: three positional arguments, built with ArgumentParser
# rather than the simpler OptionParser.
parser = argparse.ArgumentParser(description='Search some connection')
for positional in ("origin", "destination", "departure"):
    parser.add_argument(positional)
args = parser.parse_args()
def _load_cached_const(cache_key, path):
    """Return the API constant cached under *cache_key*, fetching and caching it on a miss."""
    cached = retrieve_dict(redisdb, cache_key)
    if cached is not None:
        return cached
    # A timeout keeps a stalled connection from blocking the script forever.
    response = requests.get(url + path, timeout=30)
    # Never cache an error body: fail before anything is written to Redis.
    response.raise_for_status()
    fresh = json.loads(response.content)
    store_dict_in_redis(redisdb, cache_key, fresh)
    return fresh


# Redis client shared by the helpers in this file; responses are decoded to str.
redisdb = Redis(host=redis_host, port=6379, db=0, decode_responses=True)
# Tariff and location constants, read from Redis when present, otherwise from the API.
tariffs = _load_cached_const('jaro:REGIOJET:tariffs', tarrifs_path)
locations = _load_cached_const('jaro:REGIOJET:locations', location_path)
# Resolve the city names given on the command line to location records.
city_from = search_locations('Czech Republic', args.origin)
city_to = search_locations('Czech Republic', args.destination)
# search_locations returns None for an unknown city; report that clearly
# instead of failing later with a TypeError on None.
if city_from is None:
    parser.error(f"unknown origin city: {args.origin!r}")
if city_to is None:
    parser.error(f"unknown destination city: {args.destination!r}")
ret = search_connection_regiojet(city_from, city_to, 'REGULAR', 'CITY', 'CITY', args.departure)
# Print the result as JSON; json_serial handles the datetime values.
print(json.dumps(ret, indent=4, default=json_serial, sort_keys=False))